
Any way to intervene in how Hugging Face Transformers calls _ctc_loss from PyTorch?

When trying to fine-tune for ASR on an M2 Max, I get: "The operator 'aten::_ctc_loss' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications." When I run on CPU, the training loss behaves as expected (it decreases), but training is very slow. When I run on MPS, training is much faster, but the loss increases and the WER becomes 1 or greater.

Is there any way to intervene directly in PyTorch's functional.py file or in the CTCLoss class so that it works correctly with MPS? I understand that not all ops can be supported on MPS soon, but I want to better understand how these ops work and whether I can patch the installed PyTorch library, just to give it a try. The link for functional.py is here. Or should I instead intervene in the Wav2Vec2 model from Hugging Face, which I am using for ASR and which calls ctc_loss from PyTorch? I want to try, but I don't know where to start, since I don't have much experience with PyTorch and Hugging Face Transformers.
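One possible direction, rather than editing functional.py on disk: wrap torch.nn.functional.ctc_loss at runtime so that just the CTC loss is computed on CPU while the rest of the model stays on MPS. This is a minimal sketch under my own assumptions (the wrapper name _ctc_loss_on_cpu is mine, and I have not verified that it fixes the increasing-loss behavior on an M2 Max); Wav2Vec2ForCTC looks up nn.functional.ctc_loss at call time, so a patch applied before training should be picked up.

```python
import torch
import torch.nn.functional as F

# Keep a reference to the original implementation.
_original_ctc_loss = F.ctc_loss

def _ctc_loss_on_cpu(log_probs, targets, input_lengths, target_lengths, *args, **kwargs):
    """Compute ctc_loss on CPU and move the result back to the caller's device.

    Autograd tracks the .cpu()/.to(device) copies, so backward() should still
    work; the CTC backward pass itself runs on CPU.
    """
    device = log_probs.device
    loss = _original_ctc_loss(
        log_probs.cpu(),
        targets.cpu(),
        input_lengths.cpu() if torch.is_tensor(input_lengths) else input_lengths,
        target_lengths.cpu() if torch.is_tensor(target_lengths) else target_lengths,
        *args,
        **kwargs,
    )
    return loss.to(device)

# Apply the patch before building the model / Trainer.
F.ctc_loss = _ctc_loss_on_cpu
```

With this in place you would train as usual (e.g. pass labels to Wav2Vec2ForCTC and let it compute the loss internally); only the unsupported aten::_ctc_loss op is forced onto CPU, so the Mac-specific slowdown should be limited to that one call. Whether the gradients then match the all-CPU run is something to verify on a few batches before committing to a full fine-tune.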

