efficientvit
Could someone explain to me how the efficient attention operation works?
def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
    B, _, H, W = list(qkv.size())

    if qkv.dtype == torch.float16:
        qkv = qkv.float()

    # (B, C, H, W) -> (B, heads, 3 * dim, H * W)
    qkv = torch.reshape(
        qkv,
        (
            B,
            -1,
            3 * self.dim,
            H * W,
        ),
    )
    # each of q, k, v: (B, heads, dim, H * W)
    q, k, v = (
        qkv[:, :, 0 : self.dim],
        qkv[:, :, self.dim : 2 * self.dim],
        qkv[:, :, 2 * self.dim :],
    )

    # lightweight linear attention
    q = self.kernel_func(q)
    k = self.kernel_func(k)

    # linear matmul
    trans_k = k.transpose(-1, -2)                         # (B, heads, H * W, dim)

    v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)  # (B, heads, dim + 1, H * W)
    vk = torch.matmul(v, trans_k)                         # (B, heads, dim + 1, dim)
    out = torch.matmul(vk, q)                             # (B, heads, dim + 1, H * W)
    out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)

    out = torch.reshape(out, (B, -1, H, W))
    return out
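For context, this looks like the ReLU linear attention described in the EfficientViT paper (assuming self.kernel_func is a ReLU). Writing phi for the kernel, each output token i is

    out_i = ( sum_j v_j * (phi(k_j) . phi(q_i)) ) / ( sum_j phi(k_j) . phi(q_i) + eps )

with the sums running over all H * W tokens. Because the numerator can be regrouped as (V phi(K)^T) phi(q_i), the k/v part collapses into a dim x dim matrix first, which is what makes the cost linear in the number of tokens instead of quadratic.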
In the above segment, could someone explain why we need the pad operation?
v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
vk = torch.matmul(v, trans_k)
out = torch.matmul(vk, q)
out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
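For reference, my reading of the pad: F.pad(v, (0, 0, 0, 1), value=1) appends a row of ones along the channel dimension, so v goes from (B, heads, dim, H * W) to (B, heads, dim + 1, H * W). After the two matmuls, the extra last channel of out holds sum_j phi(k_j) . phi(q_i), i.e. exactly the normalizer from the formula above, while the first dim channels hold the unnormalized numerator. The line out[:, :, :-1] / (out[:, :, -1:] + self.eps) then slices that normalizer off and divides by it. Below is a minimal standalone sketch that checks this reading against an explicit numerator/denominator computation (made-up sizes, kernel_func assumed to be ReLU):

import torch
import torch.nn.functional as F

B, heads, dim, N, eps = 2, 4, 8, 16, 1e-15   # hypothetical sizes, N = H * W
q = F.relu(torch.randn(B, heads, dim, N))    # phi(q), assuming kernel_func is ReLU
k = F.relu(torch.randn(B, heads, dim, N))    # phi(k)
v = torch.randn(B, heads, dim, N)

# explicit form: numerator_i = (V phi(K)^T) phi(q_i), denominator_i = sum_j phi(k_j) . phi(q_i)
numerator = torch.matmul(torch.matmul(v, k.transpose(-1, -2)), q)             # (B, heads, dim, N)
denominator = torch.matmul(k.sum(dim=-1, keepdim=True).transpose(-1, -2), q)  # (B, heads, 1, N)
out_explicit = numerator / (denominator + eps)

# padded form, as in relu_linear_att: the appended row of ones makes the same
# two matmuls produce the denominator as an extra channel
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)   # (B, heads, dim + 1, N)
vk = torch.matmul(v_pad, k.transpose(-1, -2))              # last row = sum_j phi(k_j)
out = torch.matmul(vk, q)                                  # last row = denominator
out_padded = out[:, :, :-1] / (out[:, :, -1:] + eps)

print(torch.allclose(out_explicit, out_padded, atol=1e-5))  # expected: True

So the pad appears to be a trick to get the normalization denominator out of the same two batched matmuls rather than computing it in a separate pass.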