
Is it possible to statically wrap/derive a parameter from another parameter in PyTorch?


I often run into situations where I want to train a parameter, but the model only uses that parameter in a transformed form. For example, think of a simple scalar parameter defined on (-∞, +∞) that the model only wants to use wrapped in a sigmoid, constraining it to (0, 1), e.g. to model a probability. A naive attempt to implement such a wrapped/derived parameter would be:

import torch


class ConstrainedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))
        self.x = torch.nn.functional.sigmoid(self.x_raw)

    def forward(self) -> torch.Tensor:
        # An actual model would of course use self.x in a more sophisticated way...
        return self.x

The model is only really interested in using the constrained/transformed self.x, e.g. because it is semantically more meaningful than the underlying/untransformed self.x_raw. In a sense, the implementation tries to hide self.x_raw, because it is more confusing to reason in terms of the unconstrained parameter.

However, training this model with something like this...

def main():
    model = ConstrainedModel()
    opt = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.MSELoss()
    y_truth = torch.tensor(0.9)
    for i in range(10000):
        y_predicted = model.forward()
        loss = loss_func.forward(y_predicted, y_truth)
        print(f"iteration: {i+1}    loss: {loss.item()}    x: {model.x.item()}")
        loss.backward()
        opt.step()
        opt.zero_grad()

... will fail in the second iteration with the infamous error:

RuntimeError: Trying to backward through the graph a second time [...]

Note that the reason here is different from the common causes (1, 2) of this error: in this case the issue is hidden by the fact that the model wraps/derives self.x from self.x_raw in its constructor, so self.x permanently holds a reference to a graph that is freed by the first backward pass.
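To make the cause concrete, here is a minimal sketch that reproduces the same failure outside of any Module (the variable names are just for illustration):

import torch

x_raw = torch.nn.Parameter(torch.tensor(0.0))
x = torch.sigmoid(x_raw)  # graph built once, just like self.x in ConstrainedModel.__init__

x.backward()  # the first backward pass frees the graph behind x
x.backward()  # RuntimeError: Trying to backward through the graph a second time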

My standard work-around is to do all the wrapping dynamically inside forward, i.e., this would work:

class ConstrainedModelWorkAround(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self) -> torch.Tensor:
        x = torch.nn.functional.sigmoid(self.x_raw)
        return x

But now the model only exposes model.x_raw and lacks the more human-friendly model.x representation of the parameter, which would be nicer to work with inside the model code and more convenient to monitor during training (e.g. in a TensorBoard display).
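For illustration, a property-based variant would keep the friendly model.x name while still recomputing the sigmoid on every access, so it is not the "static" wrapping I have in mind (just a sketch; ConstrainedModelProperty is a made-up name):

import torch


class ConstrainedModelProperty(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))

    @property
    def x(self) -> torch.Tensor:
        # Recomputed from x_raw on every access, so each forward pass
        # builds a fresh graph and repeated backward passes work.
        return torch.sigmoid(self.x_raw)

    def forward(self) -> torch.Tensor:
        # The model code can keep referring to the friendly self.x,
        # and model.x.item() can be logged during training.
        return self.x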

I'm wondering if I'm missing a trick for achieving a kind of "static" wrapping of parameters instead?

