CogVideo

About Attention_map Visualization

Open · Hryxyhe opened this issue 1 year ago · 1 comment

System Info

diffusers>=0.30.0 torch==2.4.0 torchvision==0.19.0

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Reproduction

I want to visualize the 3D full-attention maps in the Expert Transformer blocks. However, with PyTorch > 2.0 (I use 2.4.0), the default attention processor is AttnProcessor2_0, which calls F.scaled_dot_product_attention to compute the final attention output directly from Q, K, and V. F.scaled_dot_product_attention is implemented in native code and only returns that output, so its intermediate attention weights are hard to extract. I also switched to the original AttnProcessor, but it runs out of memory (OOM). It seems that F.scaled_dot_product_attention computes attention more efficiently and saves roughly 10 GB or more of memory. Is that correct?

Expected behavior

Starting from AttnProcessor2_0, how can I obtain the attention weights attn_weight = Q @ K^T without extensive code modifications?

Posted by Hryxyhe on Aug 11 '24 at 09:08
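A quick estimate shows why the fused kernel saves so much memory. The classic AttnProcessor path materializes the full (batch·heads, S, S) score matrix, whereas the flash and memory-efficient kernels behind F.scaled_dot_product_attention work on tiles and never store it. The helper below is a back-of-the-envelope sketch; the batch size, head count, and token count in the example are hypothetical values chosen for illustration, not figures taken from CogVideo.

```python
def attn_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to store a dense (batch, heads, seq_len, seq_len) attention matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def to_gib(num_bytes: int) -> float:
    return num_bytes / 2**30


# Hypothetical example: batch 2 (e.g. with classifier-free guidance), 30 heads,
# 17,776 tokens, fp16 (2 bytes per element), for a single attention call.
print(f"{to_gib(attn_matrix_bytes(2, 30, 17_776)):.1f} GiB")  # prints 35.3 GiB
```

Because the cost grows with S², doubling the token count quadruples this figure, which is why the non-fused path hits OOM long before the fused one does.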

Instead of flash attention, you can compute attention the vanilla way. To avoid OOM, manually split the Q @ K^T matrix multiplication into blocks and process one block at a time.

Posted by tengjiayan20 on Aug 11 '24 at 15:08
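The block-wise approach suggested above can be sketched as follows, under our own assumptions: plain PyTorch tensors already shaped (batch, heads, seq, head_dim), no attention mask, and a hypothetical `on_block` callback for capturing the weights. Splitting over query rows keeps every softmax row complete, so the output matches the unchunked computation. Wiring this into CogVideo's attention processor (including whatever the processor applies to Q and K before attention) is not shown and would need to be adapted.

```python
import math

import torch
import torch.nn.functional as F


@torch.no_grad()
def chunked_attention(q, k, v, block_size=1024, on_block=None):
    """Vanilla softmax(q @ k^T / sqrt(d)) @ v, computed in blocks of query rows.

    q, k, v have shape (batch, heads, seq, head_dim). At any time only a
    (batch, heads, block_size, S_k) slice of the attention matrix exists.
    If given, on_block(start, probs) receives each block of post-softmax weights.
    """
    scale = 1.0 / math.sqrt(q.shape[-1])
    k_t = k.transpose(-2, -1)
    out = q.new_empty((*q.shape[:-1], v.shape[-1]))
    for start in range(0, q.shape[-2], block_size):
        q_blk = q[:, :, start:start + block_size]
        # Each block holds complete rows, so the row-wise softmax is exact.
        probs = torch.softmax((q_blk @ k_t) * scale, dim=-1)
        if on_block is not None:
            on_block(start, probs)
        out[:, :, start:start + block_size] = probs @ v
    return out


# Usage: keep a head-averaged copy of every block of weights on the CPU.
q, k, v = (torch.randn(1, 2, 300, 64) for _ in range(3))
blocks = []
out = chunked_attention(q, k, v, block_size=128,
                        on_block=lambda start, p: blocks.append(p.mean(dim=1).cpu()))
head_avg_map = torch.cat(blocks, dim=1)  # (1, 300, 300)
# Sanity check against PyTorch's fused implementation.
assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-4)
```

Peak extra memory scales with block_size × S_k rather than S_q × S_k, so block_size is the knob to turn down when memory is tight.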

@Hryxyhe you can pass an identity matrix as the value argument of F.scaled_dot_product_attention instead of the actual values; it will then return the attention weights between the queries and the keys.

Posted by tnarek on Sep 09 '24 at 12:08
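The identity-matrix suggestion can be checked with a small sketch. Because softmax(QK^T/√d) @ I equals the weight matrix itself, passing an identity matrix as `value` makes F.scaled_dot_product_attention return the post-softmax attention weights. `attn_weights_via_sdpa` is a hypothetical helper name. Two caveats: the result is after softmax (not the raw Q @ K^T), and it is still a dense (batch, heads, S_q, S_k) tensor, so on its own it does not avoid OOM for long sequences; it is most useful for short sequences or a slice of the queries. Because the value dimension differs from the head dimension, PyTorch may also fall back to its non-fused math backend here.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def attn_weights_via_sdpa(q, k):
    """Return softmax(q @ k^T / sqrt(d)) by passing an identity matrix as `value`.

    q: (batch, heads, S_q, d), k: (batch, heads, S_k, d).
    The output is a dense (batch, heads, S_q, S_k) tensor of post-softmax weights.
    """
    n, h = q.shape[:2]
    s_k = k.shape[-2]
    # (S_k, S_k) identity broadcast over batch and heads without copying.
    eye = torch.eye(s_k, dtype=q.dtype, device=q.device).expand(n, h, s_k, s_k)
    return F.scaled_dot_product_attention(q, k, eye)
```

Restricting the queries, e.g. `attn_weights_via_sdpa(q[:, :, rows], k)`, yields only the rows of the map you want to visualize.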

```python
import math

import torch


def compute_attn_map(q, k, batch_chunk=1, head_chunk=1):
    """
    Computes q @ k^T with chunking over the batch and head dimensions.

    Note: only the matmul is chunked; the full (n, head, S, S) result is still
    allocated up front, so the memory needed for the output is not reduced.

    Args:
        q (torch.Tensor): Query tensor of shape (n, head, S, feature).
        k (torch.Tensor): Key tensor of shape (n, head, S, feature).
        batch_chunk (int): Number of batches to process at a time.
        head_chunk (int): Number of attention heads to process at a time.

    Returns:
        torch.Tensor: Result tensor of shape (n, head, S, S).
    """
    # Ensure q and k have the expected dimensions
    assert q.dim() == 4 and k.dim() == 4, "q and k must be 4D tensors"
    assert q.shape == k.shape, "q and k must have the same shape"

    # Get dimensions
    n, head, S, feature = q.shape

    # Initialize result tensor (on the same device as q, usually the GPU)
    attn_map = torch.zeros(n, head, S, S, device=q.device, dtype=q.dtype)

    # Chunked computation
    for i in range(0, n, batch_chunk):
        for j in range(0, head, head_chunk):
            # Select chunks for current batch and head
            q_chunk = q[i:i+batch_chunk, j:j+head_chunk]  # Shape: (batch_chunk, head_chunk, S, feature)
            k_chunk = k[i:i+batch_chunk, j:j+head_chunk]  # Shape: (batch_chunk, head_chunk, S, feature)

            # Compute q @ k^T
            attn_chunk = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))  # Shape: (batch_chunk, head_chunk, S, S)

            # Assign to the result tensor
            attn_map[i:i+batch_chunk, j:j+head_chunk] = attn_chunk

    return attn_map


def chunk_dot_product(query, key, num_chunks=10):
    """Computes query @ key^T in at most `num_chunks` chunks over the head dimension."""
    n, heads, L, _ = query.shape
    S = key.shape[-2]
    # Ceiling division so every head is covered even when heads % num_chunks != 0
    chunk_size = max(1, math.ceil(heads / num_chunks))

    attn_weight = torch.zeros(n, heads, L, S, device=query.device, dtype=query.dtype)
    for start in range(0, heads, chunk_size):
        q_chunk = query[:, start:start + chunk_size]
        k_chunk = key[:, start:start + chunk_size]
        attn_weight[:, start:start + chunk_size] = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))

    return attn_weight


# Reference (non-fused) implementation adapted from the PyTorch documentation of
# F.scaled_dot_product_attention, modified to return the post-softmax attention
# weights instead of attn_weight @ value (so `value` is unused).
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias  # out-of-place so attn_mask may broadcast
    # The scores are query @ key^T, shape (..., L, S)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight
```

In scaled_dot_product_attention, computing the query-key product (Q @ K^T) runs out of memory, so I used chunk_dot_product or compute_attn_map above instead. That still doesn't solve the problem, because both still allocate the full result, e.g. `attn_map = torch.zeros(n, head, S, S, device=q.device, dtype=q.dtype)`.
Is there any way around this? I want to get Q @ K^T before the softmax is applied.

Posted by JustinKai0527 on Jan 21 '25 at 14:01
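One way around allocating the full score tensor on the GPU, sketched under our own assumptions: compute the pre-softmax scores in blocks of query rows and move each block off the GPU (or reduce it) as soon as it is computed. `stream_scores` is a hypothetical helper, not part of CogVideo or diffusers. Without `reduce_heads=True` it still needs n·head·S·S elements of host memory, so for long video sequences averaging over heads, or passing only selected query rows such as `q[:, :, rows]`, is usually necessary.

```python
import math

import torch


@torch.no_grad()
def stream_scores(q, k, block_size=1024, scale=None, reduce_heads=False, out_device="cpu"):
    """Pre-softmax scores q @ k^T * scale, computed in blocks of query rows.

    q: (n, head, S_q, d), k: (n, head, S_k, d). Only one (n, head, block_size, S_k)
    block lives on q.device at a time; the result is assembled on `out_device`.
    With reduce_heads=True the heads are averaged and the result is (n, S_q, S_k).
    """
    n, h, s_q, d = q.shape
    s_k = k.shape[-2]
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    shape = (n, s_q, s_k) if reduce_heads else (n, h, s_q, s_k)
    out = torch.empty(shape, dtype=q.dtype, device=out_device)
    k_t = k.transpose(-2, -1)
    for start in range(0, s_q, block_size):
        block = (q[:, :, start:start + block_size] @ k_t) * scale  # on q.device
        if reduce_heads:
            block = block.mean(dim=1)
        # The query dimension is second to last in both output layouts.
        out[..., start:start + block_size, :] = block.to(out_device)
    return out
```

Applying `torch.softmax(..., dim=-1)` to the returned scores reproduces the post-softmax map when the heads are not averaged.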