I want to learn how to convert a PyTorch model into TorchScript. To do that, I first have to define a torch.nn.Module class that wraps the model.
Previously I used the Hugging Face Diffusers or Transformers classes to wrap models and convert them into TorchScript. Now I want to know how to define the wrapper class myself. If all I have is a downloaded PyTorch model file, is it possible to define a wrapper class for it? Or is there anything else I need to know?
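A file like this, loaded with torch.load, is a state_dict: a dictionary from dotted names to tensors. The class definition and the forward logic are not stored in it. So a hand-written wrapper is possible, but it has to rebuild the same module tree. Each dotted key is the attribute path of a parameter, and each tensor's shape has to match the size that the constructor creates. The toy sketch below shows both rules. Block, Net and Empty are hypothetical names chosen for illustration, and Net().state_dict() stands in for the real checkpoint.

```python
import torch
from torch import nn


class Block(nn.Module):
    def __init__(self, dim: int = 8, hidden_dim: int = 16):
        super().__init__()
        # The attribute name "feed_forward" becomes part of the key.
        self.feed_forward = nn.Linear(dim, hidden_dim, bias=False)


class Net(nn.Module):
    def __init__(self, n_layers: int = 2, dim: int = 8, hidden_dim: int = 16):
        super().__init__()
        # ModuleList children are named by index: layers.0, layers.1, ...
        self.layers = nn.ModuleList([Block(dim, hidden_dim) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(dim)


class Empty(nn.Module):
    # Like MyModel in the question: it registers no parameters at all.
    pass


state_dict = Net().state_dict()  # stand-in for torch.load("model.pth")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
# layers.0.feed_forward.weight (16, 8)
# layers.1.feed_forward.weight (16, 8)
# norm.weight (8,)
# norm.bias (8,)

# Same attribute paths and same sizes: strict loading succeeds.
Net().load_state_dict(state_dict)

# Net(hidden_dim=32): names match, but a (16, 8) tensor cannot go into (32, 8).
# Empty(): no attributes at all, so every key is "unexpected".
for candidate in (Net(hidden_dim=32), Empty()):
    try:
        candidate.load_state_dict(state_dict)
    except RuntimeError as err:
        print(type(candidate).__name__, "->", str(err).splitlines()[0])
```

Printing the shapes together with the names is what tells you the constructor arguments. Printing the names alone only tells you the attribute layout.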
Below is my code for a downloaded pre-trained model.
import torch

PATH = 'model.pth'
pretrained_dict = torch.load(PATH)
for key in list(pretrained_dict.keys()):
    print(key)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0


model = MyModel()
model.load_state_dict(pretrained_dict)
model.eval()
example_input = torch.rand(1, 3, 224, 224)
torch_script = torch.jit.trace(model, example_input)
output:
tok_embeddings.weight
norm.weight
output.weight
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.feed_forward.w1.weight
layers.0.feed_forward.w2.weight
layers.0.feed_forward.w3.weight
layers.0.attention_norm.weight
layers.0.ffn_norm.weight
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.feed_forward.w1.weight
layers.1.feed_forward.w2.weight
layers.1.feed_forward.w3.weight
layers.1.attention_norm.weight
layers.1.ffn_norm.weight
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
...
layers.31.feed_forward.w3.weight
layers.31.attention_norm.weight
layers.31.ffn_norm.weight
rope.freqs

--> 17 model.load_state_dict(pretrained_dict)
    18 model.eval()
    19 example_input = torch.rand(1, 3, 224, 224)

File ~/text-generation-webui-main/installer_files/env/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036     error_msgs.insert(
   2037         0, 'Missing key(s) in state_dict: {}. '.format(
   2038             ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042         self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MyModel:
    Unexpected key(s) in state_dict: "tok_embeddings.weight", "norm.weight", "output.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight", "layers.0.attention.wo.weight", "layers.0.feed_forward.w1.weight", "layers.0.feed_forward.w2.weight", "layers.....
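The error means that MyModel has none of the parameters named in the file. It registers no submodules, so its own state_dict is empty and every key is reported as unexpected. load_state_dict only copies tensors into parameters that already exist; it never creates layers. The example input is a second mismatch. A model that starts with tok_embeddings takes integer token ids of shape (batch, seq_len), not an image-shaped float tensor such as torch.rand(1, 3, 224, 224).

The key names (tok_embeddings, attention.wq/wk/wv/wo, feed_forward.w1/w2/w3, attention_norm, ffn_norm, rope.freqs) look like the layout of Meta's LLaMA reference code. The sketch below therefore assumes a LLaMA-style decoder and rebuilds a module tree with exactly those attribute paths.

Everything beyond the names is also an assumption:

- RMSNorm with eps=1e-6.
- Rotary embeddings applied to adjacent pairs of each head's channels.
- A SwiGLU feed-forward.
- Plain causal attention without a KV cache.

This is not the original implementation, and numerical agreement with it is not guaranteed. LlamaLike, Rope, apply_rope and the other class and function names are hypothetical. The demo at the end loads one instance's weights into another with strict=True and traces the result with token ids.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):  # eps is an assumption
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # -> "<prefix>.weight"

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class Rope(nn.Module):
    # Exists so the buffer is named "rope.freqs"; load_state_dict overwrites it.
    def __init__(self, head_dim: int, theta: float = 10000.0):  # theta assumed
        super().__init__()
        freqs = 1.0 / theta ** (torch.arange(0, head_dim, 2).float() / head_dim)
        self.register_buffer("freqs", freqs)

def apply_rope(x, freqs):
    # x: (batch, seq, heads, head_dim); rotate adjacent pairs (0,1), (2,3), ...
    pos = torch.arange(x.shape[1], device=x.device, dtype=freqs.dtype)
    angles = torch.outer(pos, freqs)[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x1 * angles.cos() - x2 * angles.sin(),
                           x1 * angles.sin() + x2 * angles.cos()), dim=-1)
    return rotated.flatten(-2)  # interleave the pairs back into head_dim

class Attention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs):
        b, t, _ = x.shape
        q = apply_rope(self.wq(x).view(b, t, self.n_heads, self.head_dim), freqs)
        k = apply_rope(self.wk(x).view(b, t, self.n_heads, self.head_dim), freqs)
        v = self.wv(x).view(b, t, self.n_heads, self.head_dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        mask = torch.full((t, t), float("-inf"), device=x.device).triu(1)  # causal
        out = torch.softmax(scores + mask, dim=-1) @ v
        return self.wo(out.transpose(1, 2).reshape(b, t, -1))

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))  # SwiGLU

class TransformerBlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, hidden_dim: int):
        super().__init__()
        self.attention = Attention(dim, n_heads)
        self.feed_forward = FeedForward(dim, hidden_dim)
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

    def forward(self, x, freqs):
        h = x + self.attention(self.attention_norm(x), freqs)
        return h + self.feed_forward(self.ffn_norm(h))

class LlamaLike(nn.Module):
    def __init__(self, vocab_size, dim, n_layers, n_heads, hidden_dim):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            [TransformerBlock(dim, n_heads, hidden_dim) for _ in range(n_layers)])
        self.norm = RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        self.rope = Rope(dim // n_heads)

    def forward(self, tokens):  # tokens: int64 ids, shape (batch, seq_len)
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h, self.rope.freqs)
        return self.output(self.norm(h))

# Tiny demo; for model.pth, use the sizes read from the checkpoint's shapes.
reference = LlamaLike(vocab_size=50, dim=16, n_layers=2, n_heads=4, hidden_dim=32)
model = LlamaLike(vocab_size=50, dim=16, n_layers=2, n_heads=4, hidden_dim=32)
model.load_state_dict(reference.state_dict())  # strict=True: names must match
model.eval()
example_input = torch.randint(0, 50, (1, 8))  # token ids, not an image tensor
torch_script = torch.jit.trace(model, example_input)
```

For model.pth itself, the constructor arguments can be read from the tensor shapes:

- tok_embeddings.weight is (vocab_size, dim).
- layers.0.feed_forward.w1.weight is (hidden_dim, dim).
- n_layers is the highest layers.N index plus one, which is 32 in the output above.

n_heads is not visible in any weight shape. If rope.freqs holds head_dim // 2 frequencies (an assumption), then n_heads = dim // (2 * len(rope.freqs)). Otherwise, take it from the model's original config.

Keep strict=True so that any remaining name or shape mismatch is reported rather than silently skipped. Also note that torch.jit.trace only records the operations run for the example input, so check the traced module with the input shapes you actually plan to use.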