How can we get attention weights from example sequence and structure?
Unlike ESM2, the transformer blocks don't take any argument for returning attention weights. How can we get them for an example sequence and structure?
also interested in this feature, if available!
Unfortunately, PyTorch's flash attention doesn't let you do this, since it never materializes the attention matrix. You'll have to hack it in; we'll look into supporting it officially. Here's where the attention is computed. You'll just have to swap in a plain PyTorch implementation of attention to expose the attention matrix.
https://github.com/evolutionaryscale/esm/blob/39a3a6cb1e722347947dc375e3f8e2ba80ed8b59/esm/layers/attention.py#L62-L75
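As a rough sketch (not part of the esm codebase), a plain PyTorch replacement for `F.scaled_dot_product_attention` that also returns the attention matrix might look like this. The function name and signature are illustrative; you'd wire it into the linked `attention.py` yourself:

```python
import torch
import torch.nn.functional as F

def attention_with_weights(q, k, v, attn_mask=None):
    # Manual equivalent of F.scaled_dot_product_attention that also
    # returns the attention matrix, which flash attention never builds.
    # q, k, v: (batch, heads, seq_len, head_dim)
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        # Boolean mask, True = position may be attended to.
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

# Sanity check against PyTorch's fused implementation.
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out, w = attention_with_weights(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, ref, atol=1e-5)
# Each row of the attention matrix sums to 1.
assert torch.allclose(w.sum(-1), torch.ones(2, 4, 8), atol=1e-5)
```

Note this materializes the full `(seq_len, seq_len)` attention matrix per head, so it uses more memory and is slower than the fused kernel; fine for inspection, not for training at scale.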