Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)
What does this PR do?
Fixes #11761
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline?
- [ ] Did you read our philosophy doc (important for complex PRs)?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Just wanted to make a note that the reason this has to be a tensor is that it seemingly breaks ONNX export otherwise. I had it implemented the same way earlier but changed it to this after suggestions from the nvidia team. cc @asfiyab-nvidia
ohh thanks for the info @a-r-r-o-w! A 10% speed difference is really not small though (if confirmed)
but I think this head-count ratio is actually determined by the config (inner_dim vs inner_kv_dim) and won't vary at run time, no?
@yiyixuxu Yeah, it shouldn't vary and we can compute this beforehand. I think the problem stemmed from using an integer (or int-like type) for the repeat_interleave instead of a tensor, so it doesn't matter whether we compute it from query.idx(...) / key.idx(...) or pre-calculate the ratio. I'm not sure about the details, but it looks like there were a few similar issues (https://github.com/pytorch/pytorch/issues/100429, for example) which have been marked as resolved. They are very simple examples though, so this "mark dynamic" thing probably does not work with Cosmos (I'm too unfamiliar with ONNX to comment on this).
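To illustrate the point about the ratio being fixed by config: a minimal sketch (with made-up shapes, not the actual Cosmos attention code) showing how the query/KV head ratio can be a plain Python int computed once, rather than a tensor created on every forward pass:

```python
import torch

# Hypothetical GQA-style shapes: query has 8 heads, key/value have 2.
batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)

# The ratio is fixed by the config (inner_dim vs inner_kv_dim), so it can
# be a plain int computed ahead of time instead of a per-call tensor.
ratio = query.shape[1] // key.shape[1]  # 4

# Repeat the KV heads to match the query head count.
key_repeated = key.repeat_interleave(ratio, dim=1)
assert key_repeated.shape == query.shape
```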
I don't think most of our model definitions are ONNX-compatible anyway (has this been checked before?), so I think it might be okay to make a note and break this compatibility?
ohh thanks @a-r-r-o-w
do you have a link to the original conversation about the ONNX export breaking? we could try to look into a solution if there is a reproducible script for the issue
otherwise, I think the easiest way is to add an if/else branch on torch.onnx.is_in_onnx_export(). Let me know what you think!
@yiyixuxu Unfortunately, there's not much to gather from the original conversation. You can find it here: https://github.com/huggingface/diffusers/pull/10660#issuecomment-2752624025
Your suggestion sounds good to me
cc @chenxiao111222 can you add an if torch.onnx.is_in_onnx_export() check and keep the original code path there?
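A sketch of what that branch could look like (the helper name and shapes are hypothetical, not the actual CosmosAttnProcessor2_0 code): keep the tensor-based repeats for ONNX export, and use a plain int in eager mode to avoid the per-call tensor allocation.

```python
import torch

def repeat_kv_heads(key: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Repeat KV heads along dim 1 to match the query head count."""
    if torch.onnx.is_in_onnx_export():
        # Original code path: a tensor repeats value, kept for ONNX export.
        repeats = torch.tensor(query.shape[1] // key.shape[1], device=key.device)
    else:
        # Eager path: a plain Python int, no tensor created on each call.
        repeats = query.shape[1] // key.shape[1]
    return key.repeat_interleave(repeats, dim=1)

# Usage with made-up GQA shapes: 8 query heads, 2 KV heads.
query = torch.randn(2, 8, 16, 64)
key = torch.randn(2, 2, 16, 64)
assert repeat_kv_heads(key, query).shape == query.shape
```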
Gentle ping @chenxiao111222. Let me know if you would like me to update with the required changes 🤗
I'm sorry for the late reply. I mainly use native PyTorch, so I didn't pay attention to the ONNX issue. If you're willing to fix it, please feel free to do so.