I've developed a custom FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for a project I'm working on. It works perfectly under normal circumstances, but I've run into a problem when trying to roll back the model parameters to a previous state during training.
```python
import torch
import torch.nn as nn


class FISTA(torch.optim.Optimizer):
    def __init__(self, params, lr, lambda_):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if lambda_ < 0.0:
            raise ValueError(f"Invalid lambda: {lambda_} - should be >= 0.0")
        defaults = dict(lr=lr, lambda_=lambda_)
        super().__init__(params, defaults)

    def shrinkage_operator(self, u, thresh):
        # Soft-thresholding: shrink each entry toward zero by `thresh`.
        return torch.sign(u) * torch.maximum(torch.abs(u) - thresh,
                                             torch.tensor(0.0, device=u.device))

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                lr = group['lr']
                lambda_ = group['lambda_']
                state = self.state[p]

                # State initialization (cloned so the state does not alias p.data)
                if len(state) == 0:
                    state['x_prev'] = p.data.clone()
                    state['y_prev'] = p.data.clone()
                    state['t_prev'] = torch.tensor(1., device=p.device)

                x_prev, y_prev, t_prev = state['x_prev'], state['y_prev'], state['t_prev']

                # FISTA update: proximal gradient step plus momentum extrapolation
                x_next = self.shrinkage_operator(y_prev - lr * grad, lambda_)
                t_next = (1. + torch.sqrt(1. + 4. * t_prev ** 2)) / 2.
                y_next = x_next + ((t_prev - 1) / t_next) * (x_next - x_prev)

                state['x_prev'], state['y_prev'], state['t_prev'] = x_next, y_next, t_next
                p.data.copy_(x_next)

        return loss
```

This optimizer is then used in my training loop:
```python
optimizer_penalized = FISTA(params=layer1.parameters(), lambda_=lambda_, lr=lr)
optimizer_unpenalized = FISTA(params=layer2.parameters(), lambda_=0.0, lr=lr)
```

I need two of them because I only want the first layer of my model to be shrunk. This setup works fine until I modify layer1 and layer2 directly. During training, whenever an epoch lowers the cost, I save the layers like so:
```python
layer1_before_dict = (layer1.weight.data.clone().detach(),
                      layer1.bias.data.clone().detach())
layer2_before_dict = (layer2.weight.data.clone().detach(),
                      layer2.bias.data.clone().detach())
```

and when an epoch increases the cost, I change the learning rate and call:
```python
with torch.no_grad():
    layer1.weight.copy_(layer1_before_dict[0])
    layer1.bias.copy_(layer1_before_dict[1])
    layer2.weight.copy_(layer2_before_dict[0])
    layer2.bias.copy_(layer2_before_dict[1])
```

The issue arises here: after rolling back the layer parameters, the change is not reflected in the optimizer's internal state. The layers hold the right values, but calling the optimizer's `step` method doesn't produce the expected updates.
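The same symptom is easy to reproduce with any stateful built-in optimizer, e.g. `torch.optim.SGD` with momentum, which suggests it isn't specific to my FISTA implementation. The toy layer and loss below are just placeholders to force a single step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

# Snapshot the parameters, as in the training loop above.
before = (layer.weight.data.clone(), layer.bias.data.clone())

# One dummy step so the optimizer populates its internal state.
loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

# Roll the parameters back.
with torch.no_grad():
    layer.weight.copy_(before[0])
    layer.bias.copy_(before[1])

# The parameters match the snapshot again...
print(torch.equal(layer.weight.data, before[0]))  # True
# ...but the optimizer's state still reflects the step that was "undone".
buf = opt.state[layer.weight]['momentum_buffer']
print(bool(buf.abs().sum() > 0))  # True: the momentum buffer was not rolled back
```

In my FISTA optimizer the analogous stale entries are `x_prev`, `y_prev`, and `t_prev` in `self.state[p]`.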
I also tried recreating the optimizers every time the loss increases, but that doesn't work either.
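The behaviour I'm after would look something like the sketch below, which snapshots and restores the optimizer state together with the parameters. It uses `torch.optim.SGD` as a self-contained stand-in; `state_dict()` and `load_state_dict()` are inherited from the `torch.optim.Optimizer` base class, so I assume they should also be available on my FISTA class:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

# Snapshot the parameters AND the optimizer state together.
snap_params = copy.deepcopy(layer.state_dict())
snap_opt = copy.deepcopy(opt.state_dict())

# One step that we then want to undo.
loss = layer(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()

# Roll back both the layer and the optimizer.
layer.load_state_dict(snap_params)
opt.load_state_dict(snap_opt)

print(torch.equal(layer.weight.data, snap_params['weight']))  # True
print(layer.weight in opt.state)  # False: state is back to the (empty) snapshot
```

Is something like this the right approach for a custom optimizer, or is there a reason the rolled-back state would still diverge?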