Transformers (PyTorch): how to resume training when out of memory?


I am using BertModel from Transformers for training, and I want to change max_length per mini-batch. But this leads to OOM (out of memory) errors, because the constant tensors created inside BertModel are not being recycled. I can force recovery with gc.collect() and torch.cuda.empty_cache(), but this slows down training. How can I predict that an OOM is about to occur and call gc.collect() and torch.cuda.empty_cache() only when appropriate? The core code is as follows.

def truncate_inputs(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor,
    labels: torch.Tensor,
    max_length: int,
):
    # Truncate every batch tensor to this mini-batch's dynamic max_length.
    input_ids = input_ids[:, :max_length].contiguous()
    attention_mask = attention_mask[:, :max_length].contiguous()
    token_type_ids = token_type_ids[:, :max_length].contiguous()
    labels = labels[:, :max_length].contiguous()
    return input_ids, attention_mask, token_type_ids, labels
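For illustration, the dynamic max_length could come from the longest unpadded sequence in the batch (a sketch; deriving it from attention_mask is an assumption, that part is not shown above):

# Sketch: pick max_length as the longest real sequence in this mini-batch,
# assuming padding positions are zeros in attention_mask.
max_length = int(attention_mask.sum(dim=1).max().item())
input_ids, attention_mask, token_type_ids, labels = self.truncate_inputs(
    input_ids, attention_mask, token_type_ids, labels, max_length
)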

The recycling callback:

class GCCallBack(TrainerCallback):
    def on_step_end(
        self,
        args: NERTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Force a full garbage collection and release the caching
        # allocator's unused blocks after every optimizer step.
        gc.collect()
        torch.cuda.empty_cache()

If I don't clear the cache, the speed is 10 it/s, but training runs out of memory. With the callback above, the speed is 3 it/s to 5 it/s (dynamic max length). The baseline is 2 it/s (constant max length).
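To avoid paying that cost on every step, I am considering clearing only when the caching allocator is close to filling the device (a minimal sketch; ConditionalGCCallBack and the 0.9 threshold are my own assumptions, not tuned values):

import gc

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState


class ConditionalGCCallBack(TrainerCallback):
    # Hypothetical variant of GCCallBack: recycle only when reserved GPU
    # memory exceeds a fraction of total device memory.

    def __init__(self, threshold: float = 0.9):
        self.threshold = threshold  # assumed value, needs per-GPU tuning

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if not torch.cuda.is_available():
            return
        total = torch.cuda.get_device_properties(0).total_memory
        reserved = torch.cuda.memory_reserved(0)
        if reserved / total > self.threshold:
            # Only pay for gc.collect()/empty_cache() when close to OOM.
            gc.collect()
            torch.cuda.empty_cache()

Would checking torch.cuda.memory_reserved() like this be a reliable enough signal to anticipate the OOM, or is there a better approach?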

