iftopt
An Implicit Function Theorem (IFT) optimizer for bi-level optimization problems.
Requirements
- Python 3.7+
- PyTorch 1.x
Installation
$ pip install git+https://github.com/money-shredder/iftopt.git
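As an optional sanity check (not part of the project's documented steps), you can verify that the package is importable after installation:
$ python -c "import iftopt"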
Usage
Assuming a bi-level optimization of the form:
y* = argmin_{y} val_loss(x*, y), where x* = argmin_{x} train_loss(x, y).
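For orientation (this sketch of the math is not quoted from the library's documentation): by the implicit function theorem, the inner stationarity condition ∂train_loss/∂x = 0 defines x* as an implicit function of y, and differentiating through it gives the hyper-gradient
d val_loss / dy = ∂val_loss/∂y - ∂²train_loss/∂y∂x · (∂²train_loss/∂x²)^{-1} · ∂val_loss/∂x.
The inverse-Hessian-vector product is typically approximated iteratively; the vih_lr and vih_iterations arguments below appear to control that approximation.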
To solve for the optimal x* and y* in the optimization problem, we can implement the following with iftopt:
import torch
from iftopt import HyperOptimizer

train_lr = val_lr = 0.1
# parameter to minimize the training loss
x = torch.nn.Parameter(...)
# hyper-parameter to minimize the validation loss
y = torch.nn.Parameter(...)
# training loss optimizer
opt = torch.optim.SGD([x], lr=train_lr)
# validation loss optimizer
hopt = HyperOptimizer(
    [y], torch.optim.SGD([y], lr=val_lr), vih_lr=0.1, vih_iterations=5)
# outer optimization loop for y
for _ in range(...):
    # inner optimization loop for x
    for _ in range(...):
        z = train_loss(x, y)
        # inner optimization step for x
        opt.zero_grad()
        z.backward()
        opt.step()
    # outer optimization step for y
    hopt.set_train_parameters([x])
    z = train_loss(x, y)
    hopt.train_step(z)
    v = val_loss(x, y)
    hopt.val_step(v)
    hopt.grad()
    hopt.step()
For a simple, concrete example, check out and run demo.py, where
train_loss = lambda x, y: (x + y) ** 2
val_loss = lambda x, y: x ** 2
with x = y = 1.0 initially. Running it generates a video, demo.mp4, of the optimization trajectory (the animation below). Note that although the validation loss has no direct gradient with respect to the hyper-parameter y, iftopt can still minimize the validation loss by computing the hyper-gradient via the implicit function theorem.
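To see those moving parts end-to-end without the plotting code, the demo boils down to roughly the sketch below. This is a minimal sketch assuming the API usage shown above; the loop counts, learning rates, and vih_* settings here are illustrative and need not match demo.py.

import torch
from iftopt import HyperOptimizer

train_loss = lambda x, y: (x + y) ** 2
val_loss = lambda x, y: x ** 2

# start from x = y = 1.0
x = torch.nn.Parameter(torch.tensor(1.0))
y = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([x], lr=0.1)
hopt = HyperOptimizer(
    [y], torch.optim.SGD([y], lr=0.1), vih_lr=0.1, vih_iterations=5)

for step in range(50):      # outer loop over the hyper-parameter y
    for _ in range(20):     # inner loop: approximately solve for x*
        z = train_loss(x, y)
        opt.zero_grad()
        z.backward()
        opt.step()
    # hyper-gradient step on y via the implicit function theorem
    hopt.set_train_parameters([x])
    hopt.train_step(train_loss(x, y))
    hopt.val_step(val_loss(x, y))
    hopt.grad()
    hopt.step()
    print(f'step {step}: x = {x.item():.4f}, y = {y.item():.4f}')

Since the inner loop drives x towards -y, the validation loss x ** 2 equals y ** 2 at the inner optimum, so the hyper-gradient steps should push y (and hence x) towards 0.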
