Commit 1f469b4: update docs & add a training notebook

Zeqiang-Lai committed Aug 6, 2023
1 parent 65ebbdf
Showing 35 changed files with 463 additions and 203 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -35,7 +35,7 @@ Paper <https://dl.acm.org/doi/abs/10.1145/3592144>

🎉 ∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers.
<br/>
-🎉 ∇-Prox allows for rapid prototyping of learning-based bi-level optimization problems for a diverse range of applications, by [optimized algorithm unrolling](https://pypi.org/project/dprox/), [deep equilibrium learning](https://pypi.org/project/dprox/), and [deep reinforcement learning](https://pypi.org/project/dprox/).
+🎉 ∇-Prox allows for rapid prototyping of learning-based bi-level optimization problems for a diverse range of applications, by [optimized algorithm unrolling](api/primitive), [deep equilibrium learning](api/primitive), and [deep reinforcement learning](api/primitive).

The library includes the following major components:

…
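A minimal sketch of the rapid-prototyping workflow that the changed line advertises (loosely adapted from the project README; the blur setup, denoiser name, and solve arguments are illustrative assumptions, not part of this commit):

```python
import numpy as np

from dprox import Variable, conv, deep_prior, sum_squares
from dprox.algo import compile

# Placeholder observation and blur kernel (illustrative only).
b = np.random.rand(256, 256, 3).astype('float32')
psf = np.ones((15, 15), dtype='float32') / 15 ** 2

# Objective: least-squares data fidelity plus a learned (plug-and-play) prior.
x = Variable()
objective = sum_squares(conv(x, psf) - b) + deep_prior(x, denoiser='ffdnet_color')

# Compile the objective into a differentiable ADMM solver and run it.
solver = compile(objective, method='admm')
out = solver.solve(x0=b, rhos=0.1, max_iter=24)
```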
8 changes: 4 additions & 4 deletions docs/source/started/install.md
@@ -13,10 +13,10 @@ pip install dprox
**Install from source**

```bash
-pip install git+https://github.com/Zeqiang-Lai/DeltaProx.git
+pip install git+https://github.com/princeton-computational-imaging/Delta-Prox.git
```

-**Editable install**
+**Editable installation**

You will need an editable install if you would like to:

@@ -26,11 +26,11 @@ You will need an editable install if you would like to:
To do so, clone the repository and install 🎉 Delta Prox with the following commands:

```
-git clone git+https://github.com/Zeqiang-Lai/DeltaProx.git
+git clone https://github.com/princeton-computational-imaging/Delta-Prox.git
cd Delta-Prox
pip install -e .
```

```{caution}
-Note that you must keep the DeltaProx folder if you want to keep using the library.
+Note that for an editable installation, you must keep the Delta-Prox folder if you want to keep using the library.
```
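A quick sanity check for an editable install (a hypothetical snippet, not from the docs): the package should resolve to the cloned folder rather than site-packages.

```python
import dprox

# For an editable install, this path points inside the cloned Delta-Prox folder.
print(dprox.__file__)
```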
5 changes: 3 additions & 2 deletions dprox/algo/__init__.py
@@ -1,8 +1,9 @@
-from .problem import Problem, compile, specialize, visualize, optimize
+from .problem import Problem
from .admm import ADMM, ADMM_vxu, LinearizedADMM
from .hqs import HQS
from .pc import PockChambolle
from .pgd import ProximalGradientDescent
from .base import Algorithm
from .tune.dpir import log_descent
-from .special import AutoTuneSolver, DEQSolver, UnrolledSolver
+from .specialization import AutoTuneSolver, DEQSolver, UnrolledSolver
+from .primitives import compile, specialize, visualize, optimize, train
191 changes: 191 additions & 0 deletions dprox/algo/primitives.py
@@ -0,0 +1,191 @@
from pathlib import Path
from typing import List, Union

import torch
import torch.nn.functional as F
import torchlight as tl
import torchlight.nn as tlnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from dprox import *
from dprox.contrib.optic import Dataset
from dprox.proxfn import ProxFn
from dprox.utils import *

from . import opt
from .admm import ADMM, ADMM_vxu, LinearizedADMM
from .base import Algorithm
from .hqs import HQS
from .pc import PockChambolle
from .pgd import ProximalGradientDescent
from .specialization import DEQSolver, UnrolledSolver, AutoTuneSolver, build_unrolled_solver

SOLVERS = {
'admm': ADMM,
'admm_vxu': ADMM_vxu,
'ladmm': LinearizedADMM,
'hqs': HQS,
'pc': PockChambolle,
'pgd': ProximalGradientDescent,
}

SPECAILIZATIONS = {
'deq': DEQSolver,
'rl': AutoTuneSolver,
'unroll': build_unrolled_solver,
}

def compile(
prox_fns: List[ProxFn],
method: str = 'admm',
device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
**kwargs
):
"""
Compile the given objective (in terms of a list of proxable functions) into a proximal solver.
>>> solver = compile(data_term+reg_term, method='admm')
Args:
prox_fns (List[ProxFn]): A list or the sum of proxable functions.
method (str): A string that specifies the name of the optimization method to use. Defaults to `admm`.
Valid methods include [`admm`, `admm_vxu`, `ladmm`, `hqs`, `pc`, `pgd`].
device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if available.
Returns:
An instance of a solver object that is created using the specified algorithm and proximal functions.
"""
algorithm: Algorithm = SOLVERS[method]
device = torch.device(device) if isinstance(device, str) else device

psi_fns, omega_fns = algorithm.partition(prox_fns)
solver = algorithm.create(psi_fns, omega_fns, **kwargs)
solver = solver.to(device)
return solver


def specialize(
solver: Algorithm,
method: str = 'deq',
device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
**kwargs
):
"""
Specialize the given solver based on the given method.
>>> deq_solver = specialize(solver, method='deq')
>>> rl_solver = specialize(solver, method='rl')
>>> unroll_solver = specialize(solver, method='unroll')
Args:
solver (Algorithm): the proximal solver that needs to be specialized.
method (str): the strategy for the specialization. Choose from [`deq`, `rl`, `unroll`].
device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if available.
Returns:
The specialized solver.
"""
solver = SPECAILIZATIONS[method](solver, **kwargs)
device = torch.device(device) if isinstance(device, str) else device
solver = solver.to(device)
return solver


def optimize(
prox_fns: List[ProxFn],
merge=False,
absorb=False
):
if absorb:
prox_fns = opt.absorb.absorb_all_linops(prox_fns)
return prox_fns


def visualize():
pass


def train(
model,
step_fn,
dataset='BSD500',
savedir='saved',
epochs=10,
bs=2,
lr=1e-4,
resume=None,
):
savedir = Path(savedir)
savedir.mkdir(exist_ok=True, parents=True)
logger = tl.logging.Logger(savedir)

# ----------------- Start Training ------------------------ #
root = hf.download_dataset(dataset, force_download=False)
dataset = Dataset(root)
loader = DataLoader(dataset, batch_size=bs, shuffle=True)

optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4, weight_decay=1e-3
)
tlnn.utils.adjust_learning_rate(optimizer, lr)

epoch = 0
gstep = 0
best_psnr = 0
imgdir = savedir / 'imgs'
imgdir.mkdir(exist_ok=True, parents=True)

if resume:
ckpt = torch.load(savedir / resume)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
epoch = ckpt['epoch'] + 1
gstep = ckpt['gstep'] + 1
best_psnr = ckpt['best_psnr']

def save_ckpt(name, psnr=0):
ckpt = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'gstep': gstep,
'psnr': psnr,
'best_psnr': best_psnr,
}
torch.save(ckpt, savedir / name)

save_ckpt('last.pth')
while epoch < epochs:
tracker = tl.trainer.util.MetricTracker()
pbar = tqdm(total=len(loader), dynamic_ncols=True, desc=f'Epoch[{epoch}]')

for i, batch in enumerate(loader):

gt, inp, pred = step_fn(batch)

loss = F.mse_loss(gt, pred)
loss.backward()

optimizer.step()
optimizer.zero_grad()

psnr = tl.metrics.psnr(pred, gt)
loss = loss.item()
tracker.update('loss', loss)
tracker.update('psnr', psnr)

pbar.set_postfix({'loss': f'{tracker["loss"]:.4f}',
'psnr': f'{tracker["psnr"]:.4f}'})
pbar.update()

gstep += 1

logger.info('Epoch {} Loss={} LR={}'.format(epoch, tracker['loss'], tlnn.utils.get_learning_rate(optimizer)))

save_ckpt('last.pth', tracker['psnr'])
pbar.close()
epoch += 1
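Taken together, `compile`, `specialize`, and `train` compose into the workflow this file introduces. A minimal end-to-end sketch (hedged: the toy objective, denoiser name, batch layout, and `solve` call are illustrative assumptions; the `(gt, inp, pred)` contract for `step_fn` mirrors the training loop above):

```python
from dprox import Variable, deep_prior, sum_squares
from dprox.algo import compile, specialize, train

# Compile a toy objective into an ADMM solver.
x = Variable()
objective = sum_squares(x) + deep_prior(x, denoiser='ffdnet_color')
solver = compile(objective, method='admm')

# Specialize it, e.g. as a weight-shared unrolled network.
unrolled = specialize(solver, method='unroll', max_iter=10)

# train() drives step_fn with dataset batches and expects the
# (gt, inp, pred) triple consumed by the training loop above.
def step_fn(batch):
    gt, inp = batch                # assumed batch layout
    pred = unrolled.solve(x0=inp)  # run the specialized solver
    return gt, inp, pred

train(unrolled, step_fn, dataset='BSD500', epochs=10, bs=2, lr=1e-4)
```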
96 changes: 2 additions & 94 deletions dprox/algo/problem.py
@@ -6,100 +6,8 @@
from dprox.linop.constaints import equality, less, matmul
from dprox.proxfn import ProxFn

-from . import lp, opt
-from .admm import ADMM, ADMM_vxu, LinearizedADMM
-from .base import Algorithm
-from .hqs import HQS
-from .pc import PockChambolle
-from .pgd import ProximalGradientDescent
-from .special import DEQSolver, UnrolledSolver
-
-SOLVERS = {
-    'admm': ADMM,
-    'admm_vxu': ADMM_vxu,
-    'ladmm': LinearizedADMM,
-    'hqs': HQS,
-    'pc': PockChambolle,
-    'pgd': ProximalGradientDescent,
-}
-
-SPECAILIZATIONS = {
-    'deq': DEQSolver,
-    'rl': UnrolledSolver,
-    'unroll': UnrolledSolver,
-}
-
-
-def compile(
-    prox_fns: List[ProxFn],
-    method: str = 'admm',
-    device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
-    **kwargs
-):
-    """
-    Compile the given objective (in terms of a list of proxable functions) into a proximal solver.
-
-    >>> solver = compile(data_term+reg_term, method='admm')
-
-    Args:
-        prox_fns (List[ProxFn]): A list or the sum of proxable functions.
-        method (str): A string that specifies the name of the optimization method to use. Defaults to `admm`.
-            Valid methods include [`admm`, `admm_vxu`, `ladmm`, `hqs`, `pc`, `pgd`].
-        device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
-            It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if avaliable.
-
-    Returns:
-        An instance of a solver object that is created using the specified algorithm and proximal functions.
-    """
-    algorithm: Algorithm = SOLVERS[method]
-    device = torch.device(device) if isinstance(device, str) else device
-
-    psi_fns, omega_fns = algorithm.partition(prox_fns)
-    solver = algorithm.create(psi_fns, omega_fns, **kwargs)
-    solver = solver.to(device)
-    return solver
-
-
-def specialize(
-    solver: Algorithm,
-    method: str = 'deq',
-    device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
-    **kwargs
-):
-    """
-    Specialize the given solver based on the given method.
-
-    >>> deq_solver = specialize(solver, method='deq')
-    >>> rl_solver = specialize(solver, method='rl')
-    >>> unroll_solver = specialize(solver, method='unroll')
-
-    Args:
-        solver (Algorithm): the proximal solver that need to be specialized.
-        method (str): the strategy for the specialization. Choose from [`deq`, `rl`, `unroll`].
-        device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
-            It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if avaliable
-
-    Returns:
-        The specialized solver.
-    """
-    solver = SPECAILIZATIONS[method](solver, **kwargs)
-    device = torch.device(device) if isinstance(device, str) else device
-    solver = solver.to(device)
-    return solver
-
-
-def optimize(
-    prox_fns: List[ProxFn],
-    merge=False,
-    absorb=False
-):
-    if absorb:
-        prox_fns = opt.absorb.absorb_all_linops(prox_fns)
-    return prox_fns
-
-
-def visualize():
-    pass
+from . import lp
+from .primitives import compile


class Problem:
…
2 changes: 1 addition & 1 deletion dprox/algo/specialization/__init__.py
@@ -1,3 +1,3 @@
from .deq import DEQSolver, train_deq
-from .unroll import UnrolledSolver
+from .unroll import UnrolledSolver, build_unrolled_solver
from .rl import AutoTuneSolver
10 files renamed without changes.
dprox/algo/specialization/unroll.py
@@ -1,25 +1,34 @@
-import torch.nn as nn
 import copy
+from functools import partial
+
+import torch
+import torch.nn as nn

-from ..base import move, auto_convert_to_tensor
+from ..base import auto_convert_to_tensor, move


-def clone(x, nums, shared):
-    return [x if shared else copy.deepcopy(x) for _ in range(nums)]
+def clone(x, nums, share):
+    return [x if share else copy.deepcopy(x) for _ in range(nums)]
+
+
+def build_unrolled_solver(solver, share=True, **kwargs):
+    if share == True:
+        solver.solve = partial(solver.solve, **kwargs)
+        return solver
+    return UnrolledSolver(solver, share=share, **kwargs)


 class UnrolledSolver(nn.Module):
-    def __init__(self, solver, max_iter, shared=False, learned_params=False):
+    def __init__(self, solver, max_iter, share=False, learned_params=False):
         super().__init__()
-        if shared == False:
-            self.solvers = nn.ModuleList(clone(solver, max_iter, shared=shared))
+        if share == False:
+            self.solvers = nn.ModuleList(clone(solver, max_iter, share=share))
         else:
             self.solver = solver
             self.solvers = [self.solver for _ in range(max_iter)]

         self.max_iter = max_iter
-        self.shared = shared
+        self.share = share

         self.learned_params = learned_params
         if learned_params:
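The new `build_unrolled_solver` helper makes weight sharing the default. A short sketch of both paths (hedged: the toy solver is illustrative, and the `max_iter` keyword is taken from the class above):

```python
from dprox import Variable, sum_squares
from dprox.algo import compile
from dprox.algo.specialization import UnrolledSolver, build_unrolled_solver

solver = compile([sum_squares(Variable())], method='admm')  # toy solver

# share=True (default): unrolling kwargs are bound into the existing
# solver's solve() via functools.partial; one set of weights is reused.
shared = build_unrolled_solver(solver, share=True, max_iter=10)

# share=False: the solver is deep-copied into max_iter independent
# stages, each with its own parameters.
per_stage = build_unrolled_solver(solver, share=False, max_iter=10)
assert isinstance(per_stage, UnrolledSolver)
```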