I am using `BertModel` from Transformers for training, and I want to vary `max_length` per mini-batch. However, this leads to OOM (out of memory). The reason is that the constant tensors created inside `BertModel` for one sequence length are not being recycled before the next, differently shaped batch. I can force recovery with `gc.collect()` and `torch.cuda.empty_cache()`, but calling them on every step slows training down. How can I predict that an OOM is about to occur and call `gc.collect()` and `torch.cuda.empty_cache()` only when necessary? The core code is as follows.
```python
def truncate_inputs(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor,
    labels: torch.Tensor,
    max_length: int,
):
    input_ids = input_ids[:, :max_length].contiguous()
    attention_mask = attention_mask[:, :max_length].contiguous()
    token_type_ids = token_type_ids[:, :max_length].contiguous()
    labels = labels[:, :max_length].contiguous()
    return input_ids, attention_mask, token_type_ids, labels
```
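For context, this is roughly how I call it per mini-batch (a simplified sketch; the dynamic `max_length` is just the longest non-padded sequence in the batch):

```python
# Simplified sketch of how truncate_inputs is used per mini-batch.
# max_length is derived from the attention mask, so each batch gets
# a different sequence length.
max_length = int(attention_mask.sum(dim=1).max().item())
input_ids, attention_mask, token_type_ids, labels = self.truncate_inputs(
    input_ids, attention_mask, token_type_ids, labels, max_length
)
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    labels=labels,
)
```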
The callback I use to recycle memory:
```python
import gc

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState

class GCCallBack(TrainerCallback):
    def on_step_end(
        self,
        args: NERTrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Force Python GC and release cached CUDA blocks after every step.
        gc.collect()
        torch.cuda.empty_cache()
```
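One idea I am considering is to flush only when the allocator looks close to the device limit, e.g. by comparing `torch.cuda.memory_reserved()` against the GPU's total memory, but I am not sure this is reliable. A minimal sketch of that approach (the `0.9` threshold is an arbitrary guess and would need tuning):

```python
import gc

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState

class ConditionalGCCallBack(TrainerCallback):
    """Flush only when reserved CUDA memory approaches the device limit."""

    def __init__(self, threshold: float = 0.9):
        # Fraction of total device memory above which we pay the cost
        # of gc.collect() + empty_cache(). Arbitrary; tune per GPU/model.
        self.threshold = threshold

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        if not torch.cuda.is_available():
            return
        device = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device).total_memory
        reserved = torch.cuda.memory_reserved(device)
        # Only flush when the allocator's reserved pool is close to
        # the physical memory limit, instead of on every step.
        if reserved / total > self.threshold:
            gc.collect()
            torch.cuda.empty_cache()
```

Would something like this be a sensible way to anticipate OOM, or is there a better signal to watch?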
If I don't clear the cache at all, the speed is 10 it/s, but I eventually run out of memory. With the callback above, the speed drops to 3–5 it/s (with dynamic `max_length`). The baseline with a constant `max_length` is 2 it/s.