diff --git a/pyproject.toml b/pyproject.toml
index e452de6..dad80f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,6 +58,7 @@ dependencies = [
     "pydantic",
     "supervision",
     "matplotlib",
+    "soft_moe",
 ]
 
 [project.optional-dependencies]
diff --git a/rfdetr/config.py b/rfdetr/config.py
index da57456..7cff8c4 100644
--- a/rfdetr/config.py
+++ b/rfdetr/config.py
@@ -30,6 +30,8 @@ class ModelConfig(BaseModel):
     resolution: int = 560
     group_detr: int = 13
     gradient_checkpointing: bool = False
+    MoE: bool = False
+    MoE_params: List[int] = [32, 1]
 
 class RFDETRBaseConfig(ModelConfig):
     encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
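For reference, the two new `ModelConfig` fields above (`MoE` and `MoE_params`) are meant to be toggled from the Python API. Below is a minimal usage sketch, not part of the patch, assuming the package's usual `RFDETRBase` / `model.train()` entry points and that keyword overrides reach the new config fields the same way they do for existing ones such as `resolution`; the dataset path is a placeholder.

```python
from rfdetr import RFDETRBase

# Hypothetical usage: enable the Soft-MoE decoder FFN with 32 experts and
# 1 slot per expert (the defaults added to ModelConfig above).
model = RFDETRBase(MoE=True, MoE_params=[32, 1])

# Training proceeds as usual; the decoder prints a warning at build time that
# the full pretrained weights cannot be loaded with the MoE architecture.
model.train(dataset_dir="path/to/dataset", epochs=10, batch_size=4)
```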
diff --git a/rfdetr/models/transformer.py b/rfdetr/models/transformer.py
index 81e13e5..06520fc 100644
--- a/rfdetr/models/transformer.py
+++ b/rfdetr/models/transformer.py
@@ -20,6 +20,7 @@ from typing import Optional
 
 import torch
+from soft_moe import SoftMoELayerWrapper
 import torch.nn.functional as F
 from torch import nn, Tensor
 
@@ -39,6 +40,18 @@ def forward(self, x):
         x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
         return x
 
+class FFNBlock(nn.Module):
+    def __init__(self, d_model, dim_feedforward, dropout):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(d_model, dim_feedforward),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(dim_feedforward, d_model),
+        )
+
+    def forward(self, x):
+        return self.net(x)
 
 def gen_sineembed_for_position(pos_tensor, dim=128):
     # n_query, bs, _ = pos_tensor.size()
@@ -136,7 +149,8 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
                  num_feature_levels=4, dec_n_points=4,
                  lite_refpoint_refine=False,
                  decoder_norm_type='LN',
-                 bbox_reparam=False):
+                 bbox_reparam=False,
+                 MoE=False, MoE_params=[32, 1]):
         super().__init__()
 
         self.encoder = None
@@ -145,7 +159,9 @@ def __init__(self, d_model=512, sa_nhead=8, ca_nhead=8, num_queries=300,
             group_detr=group_detr,
             num_feature_levels=num_feature_levels,
             dec_n_points=dec_n_points,
-            skip_self_attn=False,)
+            skip_self_attn=False,
+            MoE=MoE,
+            MoE_params=MoE_params)
         assert decoder_norm_type in ['LN', 'Identity']
         norm = {
             "LN": lambda channels: nn.LayerNorm(channels),
@@ -441,7 +457,7 @@ class TransformerDecoderLayer(nn.Module):
     def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.1,
                  activation="relu", normalize_before=False, group_detr=1,
                  num_feature_levels=4, dec_n_points=4,
-                 skip_self_attn=False):
+                 skip_self_attn=False, MoE=False, MoE_params=[32, 1]):
         super().__init__()
         # Decoder Self-Attention
         self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=sa_nhead, dropout=dropout, batch_first=True)
@@ -453,11 +469,34 @@ def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.
             d_model, n_levels=num_feature_levels, n_heads=ca_nhead, n_points=dec_n_points)
         self.nhead = ca_nhead
-
-        # Implementation of Feedforward model
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        # Implementation of the feedforward model or the MoE layer (done by @LeosCtrt)
+        self.MoE = MoE
+        if self.MoE:
+            print("\n" + "="*80)
+            print("Loading Mixture of Experts (MoE) Architecture")
+            print("="*80)
+            print(f"Experts Count    : {MoE_params[0]}")
+            print(f"Slots per Expert : {MoE_params[1]}")
+            print("-"*80)
+            print("Warning: This custom architecture prevents loading full pretrained weights.")
+            print("Note: It may be slightly slower but could improve accuracy.")
+            print("="*80 + "\n")
+
+            self.moe_layer = SoftMoELayerWrapper(
+                dim=d_model,
+                num_experts=MoE_params[0],
+                slots_per_expert=MoE_params[1],
+                layer=FFNBlock,
+                d_model=d_model,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout
+            )
+        else:
+            self.linear1 = nn.Linear(d_model, dim_feedforward)
+            self.dropout = nn.Dropout(dropout)
+            self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.activation = _get_activation_fn(activation)
 
         self.norm2 = nn.LayerNorm(d_model)
         self.norm3 = nn.LayerNorm(d_model)
@@ -465,7 +504,6 @@ def __init__(self, d_model, sa_nhead, ca_nhead, dim_feedforward=2048, dropout=0.
         self.dropout2 = nn.Dropout(dropout)
         self.dropout3 = nn.Dropout(dropout)
 
-        self.activation = _get_activation_fn(activation)
         self.normalize_before = normalize_before
         self.group_detr = group_detr
@@ -521,7 +559,10 @@ def forward_post(self, tgt, memory,
         tgt = tgt + self.dropout2(tgt2)
         tgt = self.norm2(tgt)
 
-        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        if self.MoE:
+            tgt2 = self.moe_layer(tgt)
+        else:
+            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
         tgt = tgt + self.dropout3(tgt2)
         tgt = self.norm3(tgt)
         return tgt
@@ -571,6 +612,8 @@ def build_transformer(args):
         lite_refpoint_refine=args.lite_refpoint_refine,
         decoder_norm_type=args.decoder_norm,
         bbox_reparam=args.bbox_reparam,
+        MoE=args.MoE,
+        MoE_params=args.MoE_params
     )
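To make the decoder change concrete, here is a self-contained sketch of the new FFN-or-MoE branch in isolation: a `SoftMoELayerWrapper` whose experts are instances of the `FFNBlock` added in this patch, applied to a batch of query embeddings. It assumes the `soft_moe` package added to `pyproject.toml` exposes `SoftMoELayerWrapper` with exactly the keyword arguments used in the patch; the shapes and hyperparameters are illustrative.

```python
import torch
from torch import nn
from soft_moe import SoftMoELayerWrapper


class FFNBlock(nn.Module):
    """Standard transformer FFN, mirroring the FFNBlock added in the patch."""

    def __init__(self, d_model, dim_feedforward, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x):
        return self.net(x)


d_model = 256
moe_ffn = SoftMoELayerWrapper(
    dim=d_model,
    num_experts=32,      # MoE_params[0]
    slots_per_expert=1,  # MoE_params[1]
    layer=FFNBlock,      # expert constructor; remaining kwargs are passed to it
    d_model=d_model,
    dim_feedforward=2048,
    dropout=0.1,
)

# A batch of decoder query embeddings: (batch, num_tokens, d_model).
# Soft-MoE dispatches tokens to expert slots and combines the expert outputs
# back per token, so the output has the same shape as the input and the layer
# is a drop-in replacement for linear2(dropout(activation(linear1(x)))).
x = torch.randn(2, 300, d_model)
out = moe_ffn(x)
print(out.shape)  # expected: torch.Size([2, 300, 256])
```

Since the expert weights do not exist in the published checkpoints, enabling `MoE=True` trains the decoder FFN experts from scratch, which is what the constructor's warning about pretrained weights refers to.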