
Wrap a pre-trained pytorch model into torch.nn.Module class


I want to learn how to convert a pytorch model into TorchScript. To do that, I first have to define a torch.nn.Module class that wraps the model.
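For reference, my understanding of the basic pattern is that torch.jit.trace takes any torch.nn.Module plus an example input and records the operations it runs. A toy sketch of that pattern (ToyModel here is purely illustrative, not my actual model):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

model = ToyModel()
model.eval()

# tracing runs the module once on the example input and records the ops
traced = torch.jit.trace(model, torch.rand(1, 4))
traced.save('toy_traced.pt')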

I have previously used the HuggingFace Diffusers or Transformers classes to wrap models and convert them into TorchScript. Now I want to know how to define the wrapper class myself. If I only have the downloaded pytorch model, is it possible to define a wrapper class? Or is there anything else I need to know?

Below is my code for a downloaded pre-trained model.

import torch

PATH = 'model.pth'
pretrained_dict = torch.load(PATH)

for key in list(pretrained_dict.keys()):
    print(key)

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0

model = MyModel()
model.load_state_dict(pretrained_dict)
model.eval()
example_input = torch.rand(1, 3, 224, 224)
torch_script = torch.jit.trace(model, example_input)

output:

tok_embeddings.weight
norm.weight
output.weight
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.feed_forward.w1.weight
layers.0.feed_forward.w2.weight
layers.0.feed_forward.w3.weight
layers.0.attention_norm.weight
layers.0.ffn_norm.weight
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.feed_forward.w1.weight
layers.1.feed_forward.w2.weight
layers.1.feed_forward.w3.weight
layers.1.attention_norm.weight
layers.1.ffn_norm.weight
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
...
layers.31.feed_forward.w3.weight
layers.31.attention_norm.weight
layers.31.ffn_norm.weight
rope.freqs

--> 17 model.load_state_dict(pretrained_dict)
     18 model.eval()
     19 example_input = torch.rand(1, 3, 224, 224)

File ~/text-generation-webui-main/installer_files/env/lib/python3.10/site-packages/torch/nn/modules/module.py:2041, in Module.load_state_dict(self, state_dict, strict)
   2036         error_msgs.insert(
   2037             0, 'Missing key(s) in state_dict: {}. '.format(
   2038                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   2040 if len(error_msgs) > 0:
-> 2041     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MyModel:
    Unexpected key(s) in state_dict: "tok_embeddings.weight", "norm.weight", "output.weight", "layers.0.attention.wq.weight", "layers.0.attention.wk.weight", "layers.0.attention.wv.weight", "layers.0.attention.wo.weight", "layers.0.feed_forward.w1.weight", "layers.0.feed_forward.w2.weight", "layers.....
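From the traceback, it looks like load_state_dict only accepts keys that match parameters or submodules registered on the module itself, so my empty MyModel rejects every key as "unexpected". As a partial sketch of what I think a matching wrapper would need to look like (the attribute name tok_embeddings produces the key tok_embeddings.weight; the remaining layers would have to be declared the same way to mirror the checkpoint, which looks like a LLaMA-style transformer):

import torch

PATH = 'model.pth'
pretrained_dict = torch.load(PATH)

class PartialWrapper(torch.nn.Module):
    def __init__(self, vocab_size, dim):
        super().__init__()
        # attribute name "tok_embeddings" -> state_dict key "tok_embeddings.weight"
        self.tok_embeddings = torch.nn.Embedding(vocab_size, dim)

    def forward(self, tokens):
        return self.tok_embeddings(tokens)

# read the sizes from the checkpoint itself so the shapes match exactly
vocab_size, dim = pretrained_dict['tok_embeddings.weight'].shape
model = PartialWrapper(vocab_size, dim)

# strict=False skips the checkpoint keys this partial module does not define;
# shape mismatches would still raise, which is why the sizes come from the file
model.load_state_dict(pretrained_dict, strict=False)
model.eval()

Is this the right direction, i.e. do I really have to rebuild the full architecture attribute by attribute before I can load the weights and trace?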
