Avoiding memory leaks when training on MPS
To release memory held by the MPS caching allocator back to the system, call the following function after each training step (note that clearing the cache this often adds a small per-step overhead):
torch.mps.empty_cache()
For example, you can use a trainer callback:
import torch
import transformers

class RAMCleaner(transformers.TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        # Release cached MPS memory after every optimizer step.
        print("Emptying cache...")
        torch.mps.empty_cache()

trainer = transformers.Trainer(
    # ... other options ...
    callbacks=[RAMCleaner()],
)
trainer.train()
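To check whether clearing the cache is actually helping, you can log the allocator's usage around a call to torch.mps.empty_cache(). A minimal sketch, assuming PyTorch 2.0 or later (torch.mps.current_allocated_memory and torch.backends.mps.is_available are part of PyTorch's public API); the helper name mps_allocated_mb is ours, not part of any library:

```python
import torch

def mps_allocated_mb() -> float:
    # torch.mps.current_allocated_memory() reports bytes currently
    # held by tensors on the MPS device; returns 0.0 off-device.
    if torch.backends.mps.is_available():
        return torch.mps.current_allocated_memory() / 1e6
    return 0.0

# Example: compare usage before and after clearing the cache.
before = mps_allocated_mb()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
print(f"MPS memory: {before:.1f} MB -> {mps_allocated_mb():.1f} MB")
```

Dropping a call like this into the callback's on_step_end makes it easy to confirm the cache is shrinking rather than growing over the course of training.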