
RNNT Loss Accepting Log Probabilities

BriansIDP opened this issue on Sep 14, 2022

🚀 The feature

@carolineechen The current RNNT loss takes logits as inputs. I wonder if it is possible to have a version that takes log probabilities rather than logits.

Motivation, pitch

I am implementing the tree-constrained pointer generator (TCPGen) in TorchAudio. Since it interpolates the model output distribution with the distribution produced by the TCPGen biasing component, it would be much easier to perform this interpolation outside the RNNT loss (see the sketch below). It would therefore be helpful if the RNNT loss accepted log probabilities as input.
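A minimal sketch of the kind of interpolation described above, assuming an RNNT loss variant that accepts log probabilities. The function name, the `p_gen` weight, and the tensor shapes are illustrative placeholders, not the actual TCPGen implementation:

```python
import torch
import torch.nn.functional as F

def interpolate_log_probs(model_logits, tcpgen_logits, p_gen):
    """Mix the model distribution with the TCPGen biasing distribution.

    model_logits, tcpgen_logits: (batch, T, U, vocab)
    p_gen: (batch, T, U, 1), interpolation weight in [0, 1]
    """
    model_probs = F.softmax(model_logits, dim=-1)
    tcpgen_probs = F.softmax(tcpgen_logits, dim=-1)
    mixed = (1.0 - p_gen) * model_probs + p_gen * tcpgen_probs
    # Return log probabilities so they can be fed to a loss that skips
    # its internal log_softmax.
    return torch.log(mixed.clamp_min(1e-12))
```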

Alternatives

No response

Additional context

No response

BriansIDP avatar Sep 14 '22 19:09 BriansIDP

Hi @BriansIDP, thanks for the request -- I'll take a look into the implementation to see how feasible this is. Just making sure I understand correctly, you want to be able to pass in the log softmax directly into the RNNT loss, rather than it being computed as part of the loss?

carolineechen avatar Sep 20 '22 14:09 carolineechen

Hi @carolineechen. That's right. Thank you!

BriansIDP avatar Sep 20 '22 14:09 BriansIDP

Hi @BriansIDP, sorry for the late reply -- I think we should be able to add this feature, but it likely won't make it into the upcoming release, more likely the one after. I can update this thread when it becomes available in our nightly build.

carolineechen avatar Oct 03 '22 15:10 carolineechen

Hi @carolineechen. Thank you for your reply! When do you reckon the nightly build would be ready for this? Thank you!

BriansIDP avatar Oct 03 '22 15:10 BriansIDP

I would estimate we can have it ready by Nov/Dec.

carolineechen avatar Oct 05 '22 16:10 carolineechen

Hi @BriansIDP, I created a PR to add support for this in #2798. Unit tests seem to be passing, but I haven't gotten the chance to test it in a training environment yet (busy w/ other work, so it might be 1-2 weeks). Could you take a look at the PR to see if the API makes sense for what you're looking for? I'm happy to hear any feedback if you get the chance to play around with it.

cc @xiaohui-zhang

carolineechen avatar Oct 27 '22 15:10 carolineechen

Hi @carolineechen. Thank you so much for the implementation! Could you point me to which API I should be looking at? I only found the changed files on that PR. Also, I'd like to confirm: I have to pull from the branch called rnntl-log-probs and run setup.py again, right? I will only be able to try this feature out next month due to my current resource limitations, but I will update you with any progress I make.

BriansIDP avatar Oct 27 '22 16:10 BriansIDP

I added the API usage to the PR description. Usage should be similar to before, except you can call log_softmax on the logits outside of the RNNT loss function and then pass fused_log_softmax=False as a function argument to treat the input as log probs instead of logits. [docs from the PR]
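A minimal usage sketch under those assumptions (built from the PR branch; the tensor shapes and blank index below are illustrative):

```python
import torch
import torchaudio.functional as F_audio

batch, max_T, max_U, num_classes = 2, 10, 5, 20
logits = torch.randn(batch, max_T, max_U + 1, num_classes, requires_grad=True)
targets = torch.randint(1, num_classes, (batch, max_U), dtype=torch.int32)
logit_lengths = torch.full((batch,), max_T, dtype=torch.int32)
target_lengths = torch.full((batch,), max_U, dtype=torch.int32)

# Apply log_softmax outside the loss (e.g. after the TCPGen interpolation) ...
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

# ... and tell the loss that the input is already log probabilities.
loss = F_audio.rnnt_loss(
    log_probs,
    targets,
    logit_lengths,
    target_lengths,
    blank=0,
    fused_log_softmax=False,
)
loss.backward()
```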

And yes, you'll have to pull in the changes from the PR branch and locally run setup.py again to rebuild torchaudio with the new option.

carolineechen avatar Oct 27 '22 17:10 carolineechen

Great! I think that's the feature I need. Thank you very much!

BriansIDP avatar Oct 27 '22 19:10 BriansIDP

Addressed in #2798.

carolineechen avatar Nov 09 '22 17:11 carolineechen