TriForce
Attention Scores Matrix Visualization
Hi, I would like to ask why the attention mask is not used in the prefill stage. I want to output the attention score matrix during prefill. Is the code below correct?
if spec:  # spec decoding
    key_states, value_states = graph_cache.update(new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx)
else:
    # update kv cache first
    key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
    if query_states.shape[1] == 1 and isinstance(graph_cache, RetrievalCache):
        if not graph_cache.init_graph:
            # init graph cache
            graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
        else:
            # update graph cache (customized)
            graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

# compute the attention score matrix explicitly
attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
attention_scores /= torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
if attention_mask is not None:
    attention_mask = attention_mask.to(attention_scores.device)
    attention_scores += attention_mask

attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, softmax_scale=1/torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float16)), causal=True)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attention_scores
Hello,
We use the flash attention kernel, which already applies a causal mask during the prefill phase, so no explicit attention_mask needs to be passed there.
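For reference, here is a minimal sketch of what causal=True corresponds to if you materialize the scores yourself. The shapes and names below are illustrative assumptions for a toy example, not TriForce code:

import torch

# hypothetical shapes, for illustration only
bsz, num_heads, seq_len, head_dim = 1, 8, 16, 64
query = torch.randn(bsz, seq_len, num_heads, head_dim)
key = torch.randn(bsz, seq_len, num_heads, head_dim)

# raw scores in the same layout as your snippet: (bsz, num_heads, q_len, kv_len)
scores = torch.einsum("bqhd,bkhd->bhqk", query, key) / (head_dim ** 0.5)

# causal mask: position i may only attend to positions <= i;
# this is the constraint flash attention enforces internally when causal=True,
# which is why no attention_mask is passed during prefill
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), torch.finfo(scores.dtype).min),
    diagonal=1,
)
scores = scores + causal_mask
probs = torch.softmax(scores, dim=-1)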
Also note that it is easy to run into OOM issues when you materialize the full attention matrix directly for long sequences.
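If you still need the score matrix for a long prompt, one way to keep GPU memory bounded is to compute it one block of query rows at a time and offload each block to CPU. This is only a sketch under the same shape assumptions as above; block_size, the float32 cast, and the CPU offload are illustrative choices, not part of TriForce:

import torch

def attention_scores_chunked(query, key, head_dim, block_size=1024):
    # query: (bsz, q_len, num_heads, head_dim); key: (bsz, kv_len, num_heads, head_dim)
    bsz, q_len, num_heads, _ = query.shape
    kv_len = key.shape[1]
    scale = head_dim ** -0.5
    out = torch.empty(bsz, num_heads, q_len, kv_len, dtype=torch.float32, device="cpu")
    for start in range(0, q_len, block_size):
        end = min(start + block_size, q_len)
        # scores for this block of query positions against the full key cache
        block = torch.einsum("bqhd,bkhd->bhqk", query[:, start:end].float(), key.float()) * scale
        # apply the causal constraint for these query positions
        q_pos = torch.arange(start, end, device=block.device)[:, None]
        k_pos = torch.arange(kv_len, device=block.device)[None, :]
        block = block.masked_fill(k_pos > q_pos, torch.finfo(block.dtype).min)
        out[:, :, start:end] = block.cpu()
    return out

Even with offloading, the full q_len x kv_len matrix can be very large, so for visualization it is usually better to keep only the layers or heads you actually want to plot.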