I often run into situations where I want to train a parameter, but that parameter only gets used by the model in a transformed form. For example, think of a simple scalar parameter defined on (-∞, +∞)
, but the model only wants to use that parameter wrapped in a sigmoid, to constrain it to (0, 1)
e.g. to model a probability. A naive attempt to implement such a wrapped/derived parameter would be:
```python
import torch


class ConstrainedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))
        self.x = torch.nn.functional.sigmoid(self.x_raw)

    def forward(self) -> torch.Tensor:
        # An actual model of course would use self.x in a more sophisticated way...
        return self.x
```
The model is only really interested in using the constrained/transformed self.x, e.g. because it is semantically more relevant than the underlying/untransformed self.x_raw. In a sense, the implementation tries to hide away self.x_raw, because thinking in terms of the unconstrained parameter is more confusing.
However, training this model with something like this...
```python
def main():
    model = ConstrainedModel()
    opt = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.MSELoss()
    y_truth = torch.tensor(0.9)
    for i in range(10000):
        y_predicted = model.forward()
        loss = loss_func.forward(y_predicted, y_truth)
        print(f"iteration: {i+1} loss: {loss.item()} x: {model.x.item()}")
        loss.backward()
        opt.step()
        opt.zero_grad()
```
... will fail in the second iteration with the infamous error:
RuntimeError: Trying to backward through the graph a second time [...]
Note that the reason here is different from the common causes (1, 2) of this error: in this case the issue is hidden by the fact that the model wraps/derives self.x from self.x_raw in its constructor. The sigmoid is applied only once, so self.x is a node of a computation graph that is built only once; after the first backward pass frees that graph, the second iteration tries to backpropagate through the same (already freed) graph again.
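To make this more concrete, here is a minimal sketch (independent of the model code above, using a bare parameter p as an illustrative stand-in) that reproduces the same failure: the sigmoid output is created once, and then two backward passes run through it.

```python
import torch

p = torch.nn.Parameter(torch.tensor(0.0))
x = torch.sigmoid(p)  # graph behind x is built exactly once, like self.x in __init__

loss1 = (x - 0.9) ** 2
loss1.backward()  # first backward pass frees the graph behind x

loss2 = (x - 0.9) ** 2
loss2.backward()  # RuntimeError: Trying to backward through the graph a second time
```

The second backward() hits the already-freed sigmoid node, which is exactly what happens to self.x across training iterations.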
My standard work-around is to perform all wrapping dynamically inside forward, so the graph is rebuilt on every call, i.e., this would work:
```python
class ConstrainedModelWorkAround(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self) -> torch.Tensor:
        x = torch.nn.functional.sigmoid(self.x_raw)
        return x
```
But now the model only exposes model.x_raw and lacks the more human-friendly model.x representation of the parameter, which would be nicer to work with inside the model code and more convenient to monitor during training (e.g. in a TensorBoard display).
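One partial compromise I can think of is to keep the dynamic wrapping but expose the transformed value as a read-only property, so that model.x is still available for logging and inside the model code (a sketch; the class name ConstrainedModelProperty is just illustrative). It still recomputes the sigmoid on every access, though, rather than giving a truly "static" wrapping:

```python
class ConstrainedModelProperty(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.x_raw = torch.nn.Parameter(torch.tensor(0.0))

    @property
    def x(self) -> torch.Tensor:
        # Recomputed on every access, so each call adds a fresh node to the graph.
        return torch.nn.functional.sigmoid(self.x_raw)

    def forward(self) -> torch.Tensor:
        return self.x
```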
I'm wondering if I'm missing a trick to achieve a kind of "static" wrapping of parameters instead?