RNNT Loss Accepting Log Probabilities
🚀 The feature
@carolineechen The current RNNT loss takes logits as input. Would it be possible to have a version that takes log probabilities rather than logits?
Motivation, pitch
I am implementing the tree-constrained pointer generator (TCPGen) in TorchAudio. Because it interpolates between the model output distribution and the distribution derived from the TCPGen biasing component, it would be much easier to perform this interpolation outside the RNNT loss (a rough sketch follows). It would therefore be helpful if the RNNT loss accepted log probabilities as input.
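For illustration, something along these lines; the tensor names, shapes, and the `p_gen` weight are placeholders for this sketch, not TCPGen's actual interface:

```python
import torch

# Hypothetical shapes: (batch, max_T, max_U, vocab). Names are illustrative only.
batch, max_T, max_U, vocab = 2, 10, 5, 30
joiner_logits = torch.randn(batch, max_T, max_U, vocab)     # RNNT joiner output
tcpgen_logits = torch.randn(batch, max_T, max_U, vocab)     # biasing component output
p_gen = torch.sigmoid(torch.randn(batch, max_T, max_U, 1))  # interpolation weight

# Interpolate the two distributions in probability space ...
probs = (1 - p_gen) * joiner_logits.softmax(dim=-1) + p_gen * tcpgen_logits.softmax(dim=-1)
# ... then take the log so the result could be fed to an RNNT loss
# that accepts log probabilities instead of logits.
log_probs = probs.clamp_min(1e-10).log()
```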
Alternatives
No response
Additional context
No response
Hi @BriansIDP, thanks for the request -- I'll look into the implementation to see how feasible this is. Just to make sure I understand correctly: you want to be able to pass the log softmax directly into the RNNT loss, rather than having it computed as part of the loss?
Hi @carolineechen. That's right. Thank you!
Hi @BriansIDP, sorry for the late reply -- I think we should be able to add this feature, but it likely won't be available in the upcoming release, probably the one after. I'll update this thread when it becomes available in our nightly build.
Hi @carolineechen, thank you for your reply! When do you reckon this will be ready in the nightly build?
I would estimate we can have it ready by Nov/Dec
Hi @BriansIDP, I created a PR to add support for this in #2798. Unit tests seem to be passing, but I haven't had the chance to test it in a training environment yet (busy with other work, so it might take 1-2 weeks). Could you take a look at the PR to see whether the API makes sense for what you're looking for? I'm happy to hear any feedback if you get the chance to play around with it.
cc @xiaohui-zhang
Hi @carolineechen. Thank you so much for the implementation! Could you point me to which API I should be looking at? I only found changed files on that PR. Also, to confirm: I have to pull from the rnntl-log-probs branch and run setup.py again, right? I will only be able to try this feature out next month due to current resource limitations, but I will update you with any progress I make.
I added the API usage to the PR description. Usage should be similar to before, except that you can call log_softmax on the logits outside of the RNNT loss function and then pass fused_log_softmax=False as a function argument to treat the input as log probs instead of logits. [docs from the PR]
And yes, you'll have to pull in the changes from the PR branch and locally run setup.py again to rebuild.
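Roughly, usage would look something like the sketch below (toy shapes and dtypes; assumes a local torchaudio build that includes the fused_log_softmax flag from the PR):

```python
import torch
import torch.nn.functional as F
import torchaudio

# Toy tensors; shapes follow torchaudio.functional.rnnt_loss:
# log_probs: (batch, max_T, max_U, vocab), targets: (batch, max_U - 1)
batch, max_T, max_U, vocab = 2, 10, 5, 30
logits = torch.randn(batch, max_T, max_U, vocab, requires_grad=True)
targets = torch.randint(1, vocab, (batch, max_U - 1), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_T, dtype=torch.int32)
target_lengths = torch.full((batch,), max_U - 1, dtype=torch.int32)

# Apply log_softmax outside the loss, then disable the fused computation
# so the input is treated as log probabilities rather than logits.
log_probs = F.log_softmax(logits, dim=-1)
loss = torchaudio.functional.rnnt_loss(
    log_probs, targets, logit_lengths, target_lengths,
    blank=0,
    fused_log_softmax=False,
)
```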
Great! I think that's the feature I need. Thank you very much!
addressed in #2798