
How to implement torch.compile for Mamba models?

Open · yyNoBug opened this issue 1 year ago · 1 comment

Hi, I noticed in your README that torch.compile provides a significant speedup, but I couldn't find where it is applied in train_acc.py. I tried wrapping a model containing Mamba blocks with torch.compile(model), but it raises errors. Could you explain how you applied torch.compile to your zigzag model? Thanks!

yyNoBug · Apr 15 '24 13:04

torch.compile is mainly applied to the indexing operations for the zigzag path, not to the whole model.

see https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L55

and

https://github.com/CompVis/zigma/blob/1e78944ebce400d34a12efd4baba1daad0fae9f3/dis_mamba/mamba_ssm/modules/mamba_simple.py#L60
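
To illustrate the idea of compiling only the indexing step, here is a minimal sketch. It is not the code at the linked lines: the function names (`zigzag_order`, `reorder_tokens`, `restore_tokens`) and the row-wise snake scan are assumptions made for illustration, and the actual zigzag paths in the repository may differ. The rest of the model, including the Mamba blocks, would stay uncompiled.

```python
import torch


def zigzag_order(height: int, width: int) -> torch.Tensor:
    """Hypothetical row-wise snake scan over an H x W token grid.

    Even rows are read left to right, odd rows right to left.
    """
    idx = torch.arange(height * width).reshape(height, width)
    idx[1::2] = idx[1::2].flip(-1)  # reverse every second row
    return idx.reshape(-1)


def reorder_tokens(x: torch.Tensor, order: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, dim); gather tokens along the scan path.
    return x[:, order, :].contiguous()


def restore_tokens(x: torch.Tensor, order: torch.Tensor) -> torch.Tensor:
    # Undo reorder_tokens by indexing with the inverse permutation.
    inverse = torch.argsort(order)
    return x[:, inverse, :].contiguous()


# Compile only the small indexing helpers (assumes PyTorch 2.x).
# Compilation is lazy: it is triggered on the first call, not here.
reorder_compiled = torch.compile(reorder_tokens)
restore_compiled = torch.compile(restore_tokens)
```

Inside a block's forward pass, one would call `reorder_compiled(x, order)` before the Mamba scan and `restore_compiled(y, order)` after it, leaving the custom selective-scan kernels outside the compiled region.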

dongzhuoyao · Apr 15 '24 18:04