From c9c973b3b566602ac27a9d0fa3d0102c80938605 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 27 Oct 2024 11:51:11 -0700
Subject: [PATCH] Experimenting with differential attention

---
 timm/models/vision_transformer.py | 76 +++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 8bc09e94fb..4104151ff2 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -107,6 +107,82 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class DiffAttention(nn.Module):
+    fused_attn: Final[bool]
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: nn.Module = RmsNorm,
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads // 2
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.lambda_init = 0.8
+        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
+
+        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5)
+
+    def _set_lambda_init(self, depth: int):
+        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        q, k, v = self.qkv(x).chunk(3, dim=2)
+        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
+            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
+            q1, q2 = q.unbind(2)
+            k1, k2 = k.unbind(2)
+            attn1 = F.scaled_dot_product_attention(q1, k1, v)
+            attn2 = F.scaled_dot_product_attention(q2, k2, v)
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
+            lambda_full = lambda_1 - lambda_2 + self.lambda_init
+            x = attn1 - lambda_full * attn2
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
+            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
+            lambda_full = lambda_1 - lambda_2 + self.lambda_init
+            attn = attn.view(B, self.num_heads, 2, N, N)
+            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
+            x = attn @ v
+
+        x = self.sub_norm(x)
+        x = x * (1 - self.lambda_init)
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
 class LayerScale(nn.Module):
     def __init__(
             self,
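
Notes on trying this out (a sketch, not part of the commit). DiffAttention splits each of the num_heads heads into two half-width attention maps (head_dim = dim // num_heads // 2) and subtracts the second map, scaled by a learned lambda, from the first, following the Differential Transformer paper. A minimal smoke test, assuming the patch is applied so DiffAttention is importable, with made-up ViT-B/16 sizes (dim 768, 12 heads, 197 tokens):

    import torch
    from timm.models.vision_transformer import DiffAttention

    attn = DiffAttention(dim=768, num_heads=12, qkv_bias=True)
    attn._set_lambda_init(depth=5)  # depth-dependent lambda_init schedule from the paper
    attn.eval()                     # disable proj_drop so the output is deterministic

    x = torch.randn(2, 197, 768)    # (batch, tokens, dim)
    out = attn(x)
    print(out.shape)                # torch.Size([2, 197, 768]); input shape is preserved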
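
The scalar lambda uses the paper's reparameterization, lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init, so at init (all four vectors drawn from N(0, 0.1^2)) the two exp terms are both near 1 and lambda sits close to lambda_init. Continuing the snippet above:

    lam = (torch.exp((attn.lambda_q1 * attn.lambda_k1).sum())
           - torch.exp((attn.lambda_q2 * attn.lambda_k2).sum())
           + attn.lambda_init)
    print(float(lam))  # ~0.666 here: 0.8 - 0.6 * exp(-0.3 * 5) from _set_lambda_init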
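
Both forward branches compute the same quantity, (softmax(Q1 K1^T * scale) - lambda * softmax(Q2 K2^T * scale)) @ V, so toggling the flag should be numerically a no-op in eval mode. One thing worth noting: self.attn_drop is constructed but never applied on either path (the fused calls don't pass dropout_p, and the softmax branch never calls it), so attention dropout is currently inert in this experiment. A quick consistency check, again continuing the snippet:

    attn.fused_attn = True
    y_fused = attn(x)
    attn.fused_attn = False
    y_unfused = attn(x)
    print(torch.allclose(y_fused, y_unfused, atol=1e-5))  # expected: True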