
Attention Scores Matrix Visualization

Open · bulaikexiansheng opened this issue 1 year ago · 1 comment

Hi, I would like to ask why the attention mask is not used in the prefill stage. I want to output the attention scores matrix during prefill. Is the code below correct?

        if spec: # spec decoding
            key_states, value_states = graph_cache.update(new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx)
        else:
            # update kv cache first
            key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
            if query_states.shape[1] == 1 and (isinstance(graph_cache, RetrievalCache)): 
                if not graph_cache.init_graph:
                    # init graph cache
                    graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
                else:
                    # update graph cache (customized)
                    graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

        # compute the attention scores matrix
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
        attention_scores = attention_scores / self.head_dim ** 0.5
        if attention_mask is not None:
            attention_mask = attention_mask.to(attention_scores.device)
            attention_scores += attention_mask

        # softmax_scale expects a Python float, not a tensor
        attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, softmax_scale=self.head_dim ** -0.5, causal=True)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, attention_scores

bulaikexiansheng · Aug 27, 2024

Hello,

We use the FlashAttention kernel, which already applies the causal mask internally during the prefill phase, so no explicit attention mask is needed there.
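
For reference, here is a minimal sketch (plain PyTorch only; the shapes and sizes are illustrative, not TriForce's actual ones) showing that the causal flag is equivalent to adding an upper-triangular `-inf` mask to the raw scores before the softmax:

    import torch

    bsz, n_heads, q_len, head_dim = 1, 2, 8, 16
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)
    v = torch.randn(bsz, n_heads, q_len, head_dim)

    # manual path: raw scores + explicit causal mask + softmax
    scores = q @ k.transpose(-2, -1) / head_dim ** 0.5        # (b, h, q, k)
    causal_mask = torch.triu(
        torch.full((q_len, q_len), float("-inf")), diagonal=1)
    manual_out = torch.softmax(scores + causal_mask, dim=-1) @ v

    # fused path: torch's SDPA applies the same causal mask internally
    fused_out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True)
    assert torch.allclose(manual_out, fused_out, atol=1e-5)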

Note that computing the attention matrix directly for long sequences can easily cause an OOM error, since it materializes an O(seq_len²) tensor per head.
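
If you only need the matrix for visualization, one workaround is to compute the scores one query chunk at a time so the GPU only ever holds an O(chunk_size × seq_len) slab. This is a sketch, not TriForce code; `attention_scores_chunked` is a hypothetical helper and `chunk_size` an illustrative choice, and it assumes offloading to CPU is acceptable:

    import torch

    def attention_scores_chunked(q, k, chunk_size=1024):
        # q, k: (bsz, n_heads, seq_len, head_dim); hypothetical helper for visualization
        scale = q.shape[-1] ** -0.5
        rows = []
        for start in range(0, q.shape[2], chunk_size):
            # (bsz, n_heads, chunk, seq_len) slab of raw scores
            block = q[:, :, start:start + chunk_size] @ k.transpose(-2, -1) * scale
            rows.append(block.float().cpu())  # offload each chunk immediately
        # full (bsz, n_heads, seq_len, seq_len) matrix assembled on CPU
        return torch.cat(rows, dim=2)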

preminstrel · Aug 27, 2024