Sequence head not working as intended?
Hello,
ESMC obviously has very good general protein representations, but it seems to fail the most basic reconstruction task of returning the input sequence. Is this expected?
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from tqdm.auto import tqdm

client = ESMC.from_pretrained("esmc_300m").to("cuda")  # or "cpu"

accs = []
for seq in tqdm(sequences):  # 1000 random sequences from Uniref50
    protein = ESMProtein(sequence=seq)
    protein_tensor = client.encode(protein)
    input_ids = protein_tensor.sequence[1:-1]
    logits_output = client.logits(
        protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
    )
    preds = logits_output.logits.sequence.argmax(dim=-1)[0][1:-1]
    matching = sum(preds.cpu().numpy() == input_ids.cpu().numpy())
    accs.append(matching / len(input_ids))
print(sum(accs) / len(accs) * 100)
35.51462866990015
I don't know if this is the issue, but if you check the output of the sequence head it returns a 64-dimensional logit matrix of which only the first (23?) tokens are actually used. So before doing an argmax you should first subset to only the valid tokens, then apply a softmax, and then take the argmax. I'm not sure how many unused tokens there are in the decoder, but I think this is the right usage.
Good suggestion, but it should not change the answer.
1.) Softmax will not change the argmax when the full tensor is present, so it is redundant if you just want the prediction rather than the predicted probabilities. I suppose indexing or slicing first could change the answer (only if the argmax lands on a token outside the kept range), but that would be unexpected too.
2.) It was presumably trained with all 64 output dimensions, so it should still guess correctly when given an unnoised input.
Indeed:
preds = logits_output.logits.sequence[:, :, :23].argmax(dim=-1)[0][1:-1] # take only amino acid tokens
34.92075859473515
or
preds = logits_output.logits.sequence[:, :, :23].softmax(dim=-1).argmax(dim=-1)[0][1:-1] # softmax first
34.92075859473515
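As a side note, here is a minimal PyTorch-only check of point 1 (an illustrative sketch, not part of the ESM SDK): softmax is monotonic, so it never changes the argmax over the full tensor, and slicing can only change it when the original argmax falls outside the kept range.

import torch

logits = torch.randn(1, 10, 64)  # toy "sequence head" output: batch x positions x vocab

# softmax is monotonic, so the argmax is unchanged
assert torch.equal(logits.argmax(dim=-1), logits.softmax(dim=-1).argmax(dim=-1))

# slicing only matters at positions where the top token was outside the kept range
full = logits.argmax(dim=-1)
sliced = logits[:, :, :23].argmax(dim=-1)
print((sliced != full).float().mean())  # fraction of positions where the top token had id >= 23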
I still think there is some problem with the indexing...
client.tokenizer('ACDEFGHIKLMNPQRSTVWY')['input_ids']
outputs:
[0, 5, 23, 13, 9, 18, 6, 21, 12, 15, 4, 20, 17, 14, 16, 10, 8, 11, 7, 22, 19, 2]
while
client.tokenizer('<cls><pad><eos><unk>|_')['input_ids']
outputs:
[0, 0, 1, 2, 3, 31, 3, 2]
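To see exactly which id maps to which token, you can dump the vocabulary. A quick sketch, assuming client.tokenizer exposes the standard Hugging Face get_vocab() / convert_ids_to_tokens() methods:

# Dump the id -> token mapping to see where the amino acids sit in the vocabulary
vocab = client.tokenizer.get_vocab()              # token -> id
id_to_token = {i: t for t, i in vocab.items()}    # id -> token
for i in sorted(id_to_token):
    print(i, id_to_token[i])

# Or look up a few ids directly
print(client.tokenizer.convert_ids_to_tokens([0, 1, 2, 3, 4, 5, 20, 23, 31, 32]))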
What do you mean? That looks fine to me.
import torch

seq = ['AAAAAAAAAAAAAAAAA']
with torch.no_grad():
    tok = client.tokenizer(seq, add_special_tokens=True, padding=True)
    ids = torch.tensor(tok['input_ids'], dtype=torch.int64).to('cuda')
    logit_manual = client(ids).sequence_logits
    logit_manual[0, 1:-1, 4:24].argmax(dim=1) + 4
output:
[20, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
That is fine, since the first aa is expected to be 'M' while all the other aa are 'A', matching the input.
I don't understand the +4, but 20 for methionine is still wrong... Not every protein starts with M. This isn't an issue for other PLMs: they reproduce an unnoised input perfectly.
The first 4 token IDs (0-3) are special tokens, so you consider only IDs 4 to 23 and then add 4 to the result to map the output indices back to input IDs. I think 'perfectly' is not something you can expect from a neural network, and ESM2 was also biased towards predicting the first amino acid as M, since most of the sequences in the training set start with M.
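As an alternative to hard-coding the +4 offset, the tokenizer itself can map the predicted ids back to residues. A rough sketch, assuming client.tokenizer supports the usual decode() method:

# Take the argmax over the full vocabulary and let the tokenizer map ids back to residues
pred_ids = logit_manual[0, 1:-1].argmax(dim=-1)        # skip the BOS/EOS positions
pred_seq = client.tokenizer.decode(pred_ids.tolist())  # may be space-separated depending on the tokenizer
print(pred_seq.replace(" ", ""))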
However, predicting unmasked tokens is not how masked language models are intended to be used.
I tried running my code on 1000 random human proteins with your metric and got an accuracy of 0.54, which, compared to a random predictor's accuracy of 0.05, is definitely good performance.
The correct way to benchmark a masked language model is masked token prediction or perplexity. If you plan to further investigate the performance of this model, I would suggest masking 15-20% of the tokens in each sequence and computing (number of correctly predicted tokens) / (number of masked tokens) for each sequence.
Otherwise, if you have the time and hardware, you could try to compute the perplexity of each sequence (mask one token at a time and get the likelihood across the full sequence).
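To make that concrete, here is a rough sketch of the masked-token accuracy metric (not official evaluation code; it assumes the mask token id is exposed as client.tokenizer.mask_token_id and that ESMProteinTensor can be built directly from the modified token tensor):

import torch
from esm.sdk.api import ESMProtein, ESMProteinTensor, LogitsConfig

mask_id = client.tokenizer.mask_token_id  # assumed to be available on the tokenizer
mask_rate = 0.15

masked_accs = []
for seq in sequences:
    protein_tensor = client.encode(ESMProtein(sequence=seq))
    ids = protein_tensor.sequence.clone()

    # pick ~15% of the non-special positions (skip BOS/EOS) to mask
    positions = torch.arange(1, len(ids) - 1, device=ids.device)
    chosen = positions[torch.rand(len(positions), device=ids.device) < mask_rate]
    if len(chosen) == 0:
        continue
    targets = ids[chosen].clone()
    ids[chosen] = mask_id

    out = client.logits(ESMProteinTensor(sequence=ids), LogitsConfig(sequence=True))
    preds = out.logits.sequence.argmax(dim=-1)[0][chosen]
    masked_accs.append((preds.cpu() == targets.cpu()).float().mean().item())

print(sum(masked_accs) / len(masked_accs) * 100)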
Noticed this issue as well. Based on what I found, my guess is that:
- The encode/decode methods provided are working as expected. There doesn't seem to be a need to fiddle around with the output logit indices or worry about the argmax over the predicted tokens hitting special characters.
- ESMC is actually just worse(?) at passing through sequences than ESM3.
- Manually removing the BOS/EOS tokens from the encoded sequence improves pass-through (see the sketch below).
SETUP: I tested ESMC and ESM3, trying to get them to directly pass through the first 1000 sequences of Swiss-Prot. I plotted histograms of the percentage of residues per protein where the encoded target sequence tokens matched the argmax over the predicted logits. I also checked whether the argmax for any predicted residue was a non-AA token.
RESULTS: See this gist if you want to reproduce it yourself. Also, at least with the spot-checking I described, I didn't see the argmax for any prediction land on a special token.
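Re: the BOS/EOS bullet above, this is roughly what "manually removing the BOS/EOS tokens" looks like. A minimal sketch, assuming ESMProteinTensor can be constructed directly from the sliced token tensor:

from esm.sdk.api import ESMProteinTensor

protein_tensor = client.encode(ESMProtein(sequence=seq))
target = protein_tensor.sequence[1:-1]

# Drop BOS/EOS from the encoded tokens before scoring
trimmed = ESMProteinTensor(sequence=target)
logits_output = client.logits(trimmed, LogitsConfig(sequence=True))

preds = logits_output.logits.sequence.argmax(dim=-1)[0]  # no [1:-1] slicing needed now
acc = (preds.cpu() == target.cpu()).float().mean().item()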
I wonder if this might also have something to do with it. The ESM3 paper says:
"In addition to enabling generation, ESM3’s training objective is also effective for representation learning. High masking rates improve the generative capability, while lower masking rates improve representation learning. We chose to train ESM3 with a noise schedule that balances generative capabilities with representation learning (Appendix A.2.2)."
I assume that means ESMC is trained with a far lower masking rate. Maybe that's causing some of these issues?
Maybe. I have a feeling it has to do with some post-training quantization, as I've seen a lot of suspicious integers in the logits and hidden states. Curious what the EvolutionaryScale people think.
"ESM2 was biased in predicting the first amino acid as an M since most of the sequences in the training set start with M."
ESM2 is indeed biased towards M, but it is much better at this pass through task than ESMC or ESM3.
import random
from tqdm.auto import tqdm
from transformers import EsmForMaskedLM, EsmTokenizer

model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").cuda()
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
canonical_amino_acids = 'ACDEFGHIKLMNPQRSTVWY'

accs = []
for seq in tqdm(sequences):  # 10000 random sequences from Uniref50
    input_ids = tokenizer(seq, return_tensors='pt', add_special_tokens=True, truncation=True, max_length=1024).input_ids.cuda()
    logits = model(input_ids).logits
    preds = logits.argmax(dim=-1)[0][1:-1].detach().cpu().numpy()
    input_ids = input_ids.detach().cpu().numpy()[0][1:-1]
    matching = sum(preds == input_ids)
    accs.append(matching / len(input_ids))
print(sum(accs) / len(accs) * 100)

# Replace the first residue of each sequence with a random canonical amino acid
sequences = [random.choice(canonical_amino_acids) + seq[1:] for seq in sequences]

first_pos_accs = []
for seq in tqdm(sequences):  # 10000 random sequences from Uniref50
    input_ids = tokenizer(seq, return_tensors='pt', add_special_tokens=True, truncation=True, max_length=1024).input_ids.cuda()
    logits = model(input_ids).logits
    preds = logits.argmax(dim=-1)[0][1:-1].detach().cpu().numpy()
    input_ids = input_ids.detach().cpu().numpy()[0][1:-1]
    matching = sum(preds == input_ids)
    accs.append(matching / len(input_ids))
    # Check if the first position was predicted correctly
    first_pos_accs.append(preds[0] == input_ids[0])
print(f"Overall accuracy: {sum(accs) / len(accs) * 100:.2f}%")
print(f"First position accuracy: {sum(first_pos_accs) / len(first_pos_accs) * 100:.2f}%")
Regular sequence accuracy:
97.18935011435738
With a random starting residue:
Overall accuracy: 96.86%
First position accuracy: 6.03%
Hi, have you figured out the reason? I find that when I do not add special tokens, it seems to work well.
@xiwilliam this seems like a bug to me, can you take a look? I would expect pretty high pass-through performance with ESMC.