Issues with Custom FISTA Optimizer and Model State Rollback in PyTorch

I've developed a custom FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) optimizer in PyTorch for a project I'm working on. The optimizer works perfectly under normal circumstances. However, I've encountered a problem when trying to rollback the model parameters to a previous state during training.

import torch
import torch.nn as nn

class FISTA(torch.optim.Optimizer):
    def __init__(self, params, lr, lambda_):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if lambda_ < 0.0:
            raise ValueError(f"Invalid lambda: {lambda_} - should be >= 0.0")
        defaults = dict(lr=lr, lambda_=lambda_)
        super(FISTA, self).__init__(params, defaults)

    def shrinkage_operator(self, u, thresh):
        # Soft-thresholding: the proximal operator of the L1 penalty.
        return torch.sign(u) * torch.maximum(torch.abs(u) - thresh, torch.tensor(0.0, device=u.device))

    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                lr = group['lr']
                lambda_ = group['lambda_']
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['x_prev'] = p.data
                    state['y_prev'] = p.data.clone()
                    state['t_prev'] = torch.tensor(1., device=p.device)
                x_prev, y_prev, t_prev = state['x_prev'], state['y_prev'], state['t_prev']
                # FISTA update: proximal gradient step plus Nesterov-style momentum.
                x_next = self.shrinkage_operator(y_prev - lr * grad, lambda_)
                t_next = (1. + torch.sqrt(1. + 4. * t_prev ** 2)) / 2.
                y_next = x_next + ((t_prev - 1) / t_next) * (x_next - x_prev)
                state['x_prev'], state['y_prev'], state['t_prev'] = x_next, y_next, t_next
                p.data.copy_(x_next)
        return loss
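For intuition, the shrinkage operator above is ordinary soft-thresholding, the proximal map of the L1 penalty. A quick standalone sanity check of the formula (independent of the class; `soft_threshold` here is just a free-function restatement of it):

```python
import torch

# Soft-thresholding: sign(u) * max(|u| - thresh, 0).
# Entries with |u| <= thresh collapse to zero; the rest move toward zero.
def soft_threshold(u, thresh):
    return torch.sign(u) * torch.clamp(torch.abs(u) - thresh, min=0.0)

u = torch.tensor([-2.0, -0.5, 0.5, 2.0])
print(soft_threshold(u, 1.0))  # only the +-2.0 entries survive, shrunk to +-1.0
```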

This optimizer is then used in my model train loop:

optimizer_penalized = FISTA(params=layer1.parameters(), lambda_=lambda_, lr=lr)
optimizer_unpenalized = FISTA(params=layer2.parameters(), lambda_=0.0, lr=lr)

I need two optimizers because I only want the first layer of my model to be shrunk. The optimizer works perfectly fine until I modify layer1 and layer2. Say I train my model, and at every epoch that yields a lower cost I save the layers like so:
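(As an aside, the built-in optimizers express this "penalize layer1 only" setup with per-parameter-group options, and since step() already reads lambda_ from group['lambda_'], one FISTA instance could carry a per-group lambda_ the same way. A sketch of the pattern using built-in SGD and weight_decay as a stand-in; the layer sizes are arbitrary:)

```python
import torch
import torch.nn as nn

layer1 = nn.Linear(10, 5)  # weights to be penalized
layer2 = nn.Linear(5, 1)   # weights left unpenalized

# Per-group settings override the top-level defaults for that group only.
opt = torch.optim.SGD(
    [{"params": layer1.parameters(), "weight_decay": 1e-2},
     {"params": layer2.parameters(), "weight_decay": 0.0}],
    lr=0.1)
print(opt.param_groups[0]["weight_decay"])  # 0.01
print(opt.param_groups[1]["weight_decay"])  # 0.0
```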

layer1_before_dict = (layer1.weight.data.clone().detach(), layer1.bias.data.clone().detach())
layer2_before_dict = (layer2.weight.data.clone().detach(), layer2.bias.data.clone().detach())

and when an epoch increases the cost, I change the learning rate and call

with torch.no_grad():
    layer1.weight.copy_(layer1_before_dict[0])
    layer1.bias.copy_(layer1_before_dict[1])
    layer2.weight.copy_(layer2_before_dict[0])
    layer2.bias.copy_(layer2_before_dict[1])

The issue arises here: after rolling back the layer parameters, the change is not reflected in the optimizer's internal state. The model layers hold the right values, but calling the optimizer's step method does not return the expected values.

I tried recreating the optimizer every time the loss increases, but that doesn't work either.
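One thing worth noting: the FISTA state (x_prev, y_prev, t_prev) lives in self.state and is untouched when only the weights are copied back, so a consistent rollback presumably needs to snapshot and restore the optimizer state alongside the weights via `state_dict()` / `load_state_dict()`. A minimal sketch of that pattern, using built-in SGD with momentum as a stand-in (its momentum buffers sit in the optimizer state the same way x_prev/y_prev/t_prev do):

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

# One step so the optimizer has internal state (momentum buffers).
layer(torch.randn(3, 4)).sum().backward()
opt.step()

# Snapshot BOTH the weights and the optimizer state.
weights_before = copy.deepcopy(layer.state_dict())
opt_before = copy.deepcopy(opt.state_dict())

# ... further training raises the cost ...
layer(torch.randn(3, 4)).sum().backward()
opt.step()

# Rollback: restore weights AND optimizer state together.
layer.load_state_dict(weights_before)
opt.load_state_dict(opt_before)
```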

