∇-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization
Paper | Tutorials | Examples | Documentation | Citation
∇-Prox is a domain-specific language (DSL) and compiler that transforms optimization problems into differentiable proximal solvers. Instead of handwriting these solvers and differentiating them via autograd, ∇-Prox requires only a few lines of code to define a solver, which can then be specialized to user requirements such as memory constraints or training budget through optimized algorithm unrolling, deep equilibrium learning, or deep reinforcement learning. ∇-Prox makes it easy to prototype learning-based bi-level optimization problems for a diverse range of applications. Compared against naive implementations of existing methods, ∇-Prox is significantly more compact in lines of code and compares favorably in memory consumption across application domains.
- August 2023: ∇-Prox is presented at SIGGRAPH 2023 and its code base is now public.
- May 2023: ∇-Prox is accepted as a journal paper at SIGGRAPH 2023.
We recommend installing ∇-Prox with pip:
pip install dprox
Please refer to the Installation guide for other options.
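A quick smoke test of the install (assuming the package exposes __version__, which is common but not guaranteed):

# Importing dprox should succeed after installation.
import dprox
print(dprox.__version__)  # assumption: a __version__ attribute exists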
Consider a simple image deconvolution problem, where we want to find a clean image $x$ from a blurred observation $b$ by solving

$$\min_x \; \|K x - b\|_2^2 + g(x),$$

where $K$ denotes convolution with a known point spread function and $g$ is a (learned) image prior. In ∇-Prox, this problem can be defined and solved in a few lines of code:
from dprox import *
from dprox.utils import *
from dprox.contrib import *

# Load a sample image and simulate a blurred observation
# with a known point spread function (PSF).
img = sample()
psf = point_spread_function(15, 5)
b = blurring(img, psf)

# Define the objective: a least-squares data term plus a
# learned denoiser prior on the unknown image x.
x = Variable()
data_term = sum_squares(conv(x, psf) - b)
reg_term = deep_prior(x, denoiser='ffdnet_color')
prob = Problem(data_term + reg_term)

# Solve with ADMM, initialized at the blurred observation.
prob.solve(method='admm', x0=b)
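The reconstruction is returned by solve (the specialized solvers below follow the same convention), so it can be inspected directly; a minimal sketch, assuming the imshow helper exported by dprox.utils:

out = prob.solve(method='admm', x0=b)  # returns the reconstructed image
imshow(out)                            # visualize the result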
We can also specialize the solver via bi-level optimization. For example, we can derive a reinforcement learning (RL) solver that automatically tunes the solver's hyperparameters.
# Compile the objective into an ADMM solver, specialize it into
# an RL solver, and train the parameter-tuning policy.
solver = compile(data_term + reg_term, method='admm')
rl_solver = specialize(solver, method='rl')
rl_solver = train(rl_solver, **training_kwargs)
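The same specialization mechanism covers the deep equilibrium learning mentioned above; a minimal sketch, assuming specialize also accepts method='deq':

# Deep equilibrium (DEQ) specialization: differentiate through the
# solver's fixed point rather than a fixed number of unrolled steps.
deq_solver = specialize(solver, method='deq')
deq_solver = train(deq_solver, **training_kwargs)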
Alternatively, we can specialize the solver into an unrolled solver for end-to-end optics optimization, here jointly optimizing a diffractive optical element (DOE) and the solver hyperparameters.
import torch
import torch.nn as nn

# Define the reconstruction problem with placeholders for the
# measurement y and the PSF, which are filled in at each training step.
x = Variable()
y = Placeholder()
PSF = Placeholder()
data_term = sum_squares(conv_doe(x, PSF, circular=True) - y)
reg_term = deep_prior(x, denoiser='ffdnet_color')
solver = compile(data_term + reg_term, method='admm')
unrolled_solver = specialize(solver, method='unroll', max_iter=10)

# Build the trainable DOE model and register the solver
# hyperparameters as learnable parameters; rhos and lams
# hold their initial values (defined elsewhere).
doe_model = build_doe_model()
doe_model.rhos = nn.parameter.Parameter(rhos)
doe_model.lams = nn.parameter.Parameter(lams)

def step_fn(gt):
    # Simulate the measurement: convolve the ground truth with the
    # current PSF and add Gaussian noise with standard deviation sigma.
    psf = doe_model.get_psf()
    inp = img_psf_conv(gt, psf, circular=True)
    inp = inp + torch.randn(*inp.shape) * sigma
    y.value = inp
    PSF.value = psf
    # Reconstruct with the unrolled solver and the learned hyperparameters.
    out = unrolled_solver.solve(x0=inp, rhos=doe_model.rhos, lams={reg_term: doe_model.lams})
    return gt, inp, out

train(doe_model, step_fn, dataset)
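After training, the optimized DOE and solver hyperparameters can be used for inference; a minimal sketch mirroring step_fn on a single image, where test_img and sigma are hypothetical stand-ins for a test image and the noise level:

# Inference with the trained DOE model (test_img and sigma are
# hypothetical stand-ins, not part of the library).
psf = doe_model.get_psf()
measurement = img_psf_conv(test_img, psf, circular=True)
measurement = measurement + torch.randn(*measurement.shape) * sigma
y.value = measurement
PSF.value = psf
restored = unrolled_solver.solve(x0=measurement, rhos=doe_model.rhos,
                                 lams={reg_term: doe_model.lams})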
Want to learn more? Check out the step-by-step tutorials for the framework and its applications.
@article{deltaprox2023,
  title     = {$\nabla$-Prox: Differentiable Proximal Algorithm Modeling for Large-Scale Optimization},
  author    = {Lai, Zeqiang and Wei, Kaixuan and Fu, Ying and H\"{a}rtel, Philipp and Heide, Felix},
  journal   = {ACM Transactions on Graphics (TOG)},
  volume    = {42},
  number    = {4},
  articleno = {105},
  pages     = {1--19},
  year      = {2023},
  publisher = {Association for Computing Machinery},
  address   = {New York, NY, USA},
  doi       = {10.1145/3592144},
}