When fine-tuning for ASR on an M2 Max, I get this warning: "The operator 'aten::_ctc_loss' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications." When I run on CPU, the training loss is fine (it decreases), but training is very slow. When I run on MPS, training is much faster, but the training loss increases and WER becomes 1 or greater.

Is there any way to modify the functional.py file or the CTCLoss class in PyTorch directly so that it works correctly with MPS? I understand that not all ops can be supported on MPS soon, but I want to better understand how these ops work and whether I can try patching the installed PyTorch library. The link for functional.py is here. Or should I instead modify the Wav2Vec2 code in the Hugging Face library that I am using for ASR, at the point where it calls ctc_loss from PyTorch?

I want to give it a try, but I don't know where to start, since I don't have much experience with PyTorch and Hugging Face Transformers.
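
One possible starting point, as a minimal sketch rather than a confirmed fix: instead of editing functional.py, wrap the loss call so the CTC loss is computed explicitly on the CPU in float32 and the result is moved back to the original device. The helper name `ctc_loss_on_cpu` and the tensor shapes below are hypothetical, chosen only for illustration. The sketch uses just `torch.nn.functional.ctc_loss` and standard tensor moves, and it assumes that gradients flow back through the device transfers. Whether this changes the increasing loss on MPS is not something the sketch demonstrates; it only makes the CPU hop explicit so the two paths can be compared.

```python
import torch
import torch.nn.functional as F


def ctc_loss_on_cpu(log_probs, targets, input_lengths, target_lengths,
                    blank=0, reduction="mean", zero_infinity=False):
    """Hypothetical helper: compute CTC loss on the CPU in float32.

    log_probs has shape (T, N, C) and holds log-softmax outputs.
    The returned loss lives on the same device as log_probs.
    """
    device = log_probs.device
    loss = F.ctc_loss(
        log_probs.float().cpu(),   # explicit move instead of implicit fallback
        targets.cpu(),
        input_lengths.cpu(),
        target_lengths.cpu(),
        blank=blank,
        reduction=reduction,
        zero_infinity=zero_infinity,
    )
    return loss.to(device)


# Use MPS when it is available, otherwise fall back to CPU so the sketch runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

torch.manual_seed(0)
T, N, C, S = 50, 4, 32, 10  # time steps, batch size, vocabulary size, target length
logits = torch.randn(T, N, C, device=device, requires_grad=True)
log_probs = logits.log_softmax(dim=-1)
targets = torch.randint(1, C, (N, S), device=device)  # index 0 is reserved for blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss = ctc_loss_on_cpu(log_probs, targets, input_lengths, target_lengths)
loss.backward()  # gradients should arrive on the original device
```

Comparing the value and gradients from this wrapper against a pure-CPU run on the same batch may help show whether the loss computation or some other MPS op is behind the diverging training loss.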