About Attention Map Visualization
System Info
diffusers>=0.3.0 torch==2.4.0 torchvision==0.19.0
Information
- [X] The official example scripts
- [ ] My own modified scripts
Reproduction
I want to visualize the 3D full-attention maps in the Expert Transformer blocks. However, I noticed that with PyTorch > 2.0 (mine is 2.4.0), the default attn_processor is AttnProcessor2_0, which calls F.scaled_dot_product_attention to compute the final attention output directly. F.scaled_dot_product_attention is implemented at the C++ level, so it is hard to modify its output. I also switched to the original AttnProcessor, but that runs out of memory (OOM). It seems that F.scaled_dot_product_attention computes the attention efficiently and saves about 10 GB+ of memory?
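For reference, a quick way to confirm which processors are attached (a generic diffusers snippet of my own; `pipe` and the `.transformer` attribute are assumptions about the setup, not taken from this issue):

```python
# Sanity check: list the attention processors currently attached to the model.
# Assumes `pipe` is an already-loaded diffusers pipeline exposing `.transformer`.
for name, proc in pipe.transformer.attn_processors.items():
    print(name, type(proc).__name__)  # AttnProcessor2_0-style processors are the default on torch >= 2.0
```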
Expected behavior
Based on AttnProcessor2_0, how can I get attn_weight = Q @ K^T without complex code modifications?
Instead of using flash attention, you can use the vanilla attention calculation. To avoid OOM, you can manually split the Q @ K^T matrix multiplication into blocks.
@Hryxyhe you can pass an identity matrix to F.scaled_dot_product_attention instead of the actual value tensor, which will return the (softmax-normalized) attention between the queries and the keys.
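A minimal sketch of that identity-matrix trick (my illustration of the suggestion above, not code from this thread; the function name is made up): feeding an identity matrix as `value` makes SDPA return softmax(Q K^T * scale) directly. Note that this yields the post-softmax attention, and the output is still an (n, head, S, S) tensor, so the memory cost of the map itself is unchanged.

```python
import torch
import torch.nn.functional as F

def sdpa_attention_probs(query, key):
    # query, key: (n, head, S, feature)
    n, h, S, _ = query.shape
    # Identity "value": softmax(QK^T * scale) @ I is the attention matrix itself.
    eye = torch.eye(S, device=query.device, dtype=query.dtype).expand(n, h, S, S)
    # If a fused backend rejects the expanded (non-contiguous) view, add .contiguous()
    # at the cost of actually materializing the identity tensor.
    return F.scaled_dot_product_attention(query, key, eye)
```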
```python
import torch


def compute_attn_map(q, k, batch_chunk=1, head_chunk=1):
    """
    Computes q @ k^T with chunking to avoid OOM issues.

    Args:
        q (torch.Tensor): Query tensor of shape (n, head, S, feature).
        k (torch.Tensor): Key tensor of shape (n, head, S, feature).
        batch_chunk (int): Number of batches to process at a time.
        head_chunk (int): Number of attention heads to process at a time.

    Returns:
        torch.Tensor: Result tensor of shape (n, head, S, S).
    """
    # Ensure q and k have the expected dimensions
    assert q.dim() == 4 and k.dim() == 4, "q and k must be 4D tensors"
    assert q.shape == k.shape, "q and k must have the same shape"

    # Get dimensions
    n, head, S, feature = q.shape

    # Initialize result tensor (on GPU)
    attn_map = torch.zeros(n, head, S, S, device=q.device, dtype=q.dtype)

    # Chunked computation
    for i in range(0, n, batch_chunk):
        for j in range(0, head, head_chunk):
            # Select chunks for current batch and head
            q_chunk = q[i:i + batch_chunk, j:j + head_chunk]  # (batch_chunk, head_chunk, S, feature)
            k_chunk = k[i:i + batch_chunk, j:j + head_chunk]  # (batch_chunk, head_chunk, S, feature)

            # Compute q @ k^T
            attn_chunk = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))  # (batch_chunk, head_chunk, S, S)

            # Assign to the result tensor
            attn_map[i:i + batch_chunk, j:j + head_chunk] = attn_chunk

    return attn_map


def chunk_dot_product(query, key, num_chunks=10):
    """Computes query @ key^T by splitting the head dimension into num_chunks pieces."""
    # Assumes the head dimension is divisible by num_chunks
    assert query.shape[1] % num_chunks == 0, "head dim must be divisible by num_chunks"
    chunk_size = query.shape[1] // num_chunks

    # Full result tensor of shape (n, head, S, S)
    attn_weight = torch.zeros(
        query.shape[0], query.shape[1], query.shape[2], query.shape[2],
        device=query.device, dtype=query.dtype,
    )
    for i in range(num_chunks):
        q_chunk = query[:, i * chunk_size:(i + 1) * chunk_size]
        k_chunk = key[:, i * chunk_size:(i + 1) * chunk_size]
        attn_weight[:, i * chunk_size:(i + 1) * chunk_size] = torch.matmul(q_chunk, k_chunk.transpose(-1, -2))
    return attn_weight
```
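As a quick check that the chunked helpers reproduce the direct product (small shapes; this snippet is my addition rather than code from the thread):

```python
import torch

q = torch.randn(1, 8, 64, 32)
k = torch.randn(1, 8, 64, 32)
ref = q @ k.transpose(-1, -2)
assert torch.allclose(compute_attn_map(q, k, head_chunk=2), ref)
assert torch.allclose(chunk_dot_product(q, k, num_chunks=4), ref)
```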
```python
import math

import torch


# Reference implementation equivalent to F.scaled_dot_product_attention
# (adapted from the PyTorch docs), modified to return the attention weights
# instead of attn_weight @ value:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    # This full (L, S) product is where the OOM happens for long sequences
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight  # `value` is intentionally unused; we only want the weights
```
In scaled_dot_product_attention, the query @ key.transpose(-2, -1) product itself causes the OOM, so I used the chunk_dot_product / compute_attn_map functions above, but that still does not solve the problem, because I have to allocate attn_map = torch.zeros(n, head, S, S, device=q.device, dtype=q.dtype) on the GPU.
I don't know if there is anything that can solve this; I want to get query @ key^T before the softmax that produces attn_weight.
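One possible workaround (my suggestion, not something confirmed in this thread): keep the chunked matmul on the GPU but accumulate the map on the CPU, so the full (n, head, S, S) tensor never has to live in GPU memory. The function below is a hypothetical sketch along those lines.

```python
import torch

def compute_attn_map_offload(q, k, head_chunk=1):
    """Chunked q @ k^T where each finished chunk is moved to CPU immediately.

    Only one (n, head_chunk, S, S) chunk lives on the GPU at a time; the full
    (n, head, S, S) map is accumulated in (possibly very large) CPU memory.
    """
    n, head, S, _ = q.shape
    attn_map = torch.empty(n, head, S, S, dtype=q.dtype, device="cpu")
    for j in range(0, head, head_chunk):
        chunk = torch.matmul(q[:, j:j + head_chunk], k[:, j:j + head_chunk].transpose(-1, -2))
        attn_map[:, j:j + head_chunk] = chunk.cpu()
        del chunk  # free the GPU chunk before computing the next one
    return attn_map
```

If even the CPU copy is too large, downcasting each chunk (e.g. `chunk.half()`) or only saving selected layers/heads would shrink it further.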