diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..204fa90300 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -373,6 +373,7 @@ ) from .models.ctrl import CTRLForSequenceClassification, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel from .models.cvt import CvtForImageClassification, CvtModel, CvtPreTrainedModel +from .models.d_fine import DFineForObjectDetection, DFineModel, DFinePreTrainedModel from .models.dac import DacModel, DacPreTrainedModel from .models.data2vec import ( Data2VecAudioForAudioFrameClassification, @@ -463,6 +464,7 @@ DPRReader, ) from .models.dpt import DPTForDepthEstimation, DPTImageProcessor, DPTModel, DPTPreTrainedModel +from .models.efficientloftr import EfficientLoFTRForKeypointMatching, EfficientLoFTRModel, EfficientLoFTRPreTrainedModel from .models.efficientnet import ( EfficientNetForImageClassification, EfficientNetImageProcessor, @@ -650,6 +652,7 @@ ) from .models.granite import GraniteForCausalLM, GraniteModel, GranitePreTrainedModel from .models.granitemoe import GraniteMoeForCausalLM, GraniteMoeModel, GraniteMoePreTrainedModel +from .models.granitemoehybrid import GraniteMoeHybridForCausalLM, GraniteMoeHybridModel, GraniteMoeHybridPreTrainedModel from .models.granitemoeshared import GraniteMoeSharedForCausalLM, GraniteMoeSharedModel, GraniteMoeSharedPreTrainedModel from .models.grounding_dino import ( GroundingDinoForObjectDetection, @@ -666,6 +669,7 @@ HeliumModel, HeliumPreTrainedModel, ) +from .models.hgnet_v2 import HGNetV2Backbone, HGNetV2ForImageClassification, HGNetV2PreTrainedModel from .models.hiera import ( HieraBackbone, HieraForImageClassification, @@ -1190,7 +1194,7 @@ RoFormerModel, RoFormerPreTrainedModel, ) -from .models.rt_detr import RTDetrForObjectDetection, RTDetrModel, RTDetrPreTrainedModel +from .models.rt_detr import RTDetrForObjectDetection, RTDetrImageProcessor, RTDetrModel, RTDetrPreTrainedModel from .models.rt_detr_v2 import RTDetrV2ForObjectDetection, RTDetrV2Model, RTDetrV2PreTrainedModel from .models.rwkv import RwkvForCausalLM, RwkvModel, RwkvPreTrainedModel from .models.sam import SamImageProcessor, SamModel, SamPreTrainedModel, SamProcessor diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..92bf330354 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -54,6 +54,7 @@ convnextv2, ctrl, cvt, + d_fine, dac, data2vec, dbrx, @@ -67,6 +68,7 @@ distilbert, dpr, dpt, + efficientloftr, efficientnet, electra, emu3, @@ -99,8 +101,10 @@ gptj, granite, granitemoe, + granitemoehybrid, granitemoeshared, groupvit, + hgnet_v2, hiera, hubert, idefics, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..4171b2cc83 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -124,6 +124,7 @@ ("granitemoeshared", "GraniteMoeSharedConfig"), ("groupvit", "GroupViTConfig"), ("helium", "HeliumConfig"), + ("hgnet_v2", "HGNetV2Config"), ("hiera", "HieraConfig"), ("hubert", "HubertConfig"), ("ibert", "IBertConfig"), @@ -392,6 +393,7 @@ ("granitemoeshared", "GraniteMoeSharedMoe"), ("groupvit", "GroupViT"), ("helium", "Helium"), + ("hgnet_v2", "HGNet-V2"), ("hiera", "Hiera"), ("hubert", "Hubert"), ("ibert", "I-BERT"), diff --git a/mindone/transformers/models/auto/image_processing_auto.py 
b/mindone/transformers/models/auto/image_processing_auto.py index baea350d79..bca22c86a5 100644 --- a/mindone/transformers/models/auto/image_processing_auto.py +++ b/mindone/transformers/models/auto/image_processing_auto.py @@ -73,6 +73,7 @@ ("owlv2", ("Owlv2ImageProcessor",)), ("owlvit", ("OwlViTImageProcessor",)), ("qwen2_5_vl", ("Qwen2VLImageProcessor",)), + ("rt_detr", ("RTDetrImageProcessor",)), ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..0a7d4602c9 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -118,6 +118,7 @@ ("groupvit", "GroupViTModel"), ("grounding-dino", "GroundingDinoModel"), ("helium", "HeliumModel"), + ("hgnet_v2", "HGNetV2Backbone"), ("hiera", "HieraModel"), ("hubert", "HubertModel"), ("ibert", "IBertModel"), @@ -583,6 +584,7 @@ ("dinov2_with_registers", "Dinov2WithRegistersForImageClassification"), ("efficientnet", "EfficientNetForImageClassification"), ("focalnet", "FocalNetForImageClassification"), + ("hgnet_v2", "HGNetV2ForImageClassification"), ("hiera", "HieraForImageClassification"), ("ijepa", "IJepaForImageClassification"), ("imagegpt", "ImageGPTForImageClassification"), @@ -1212,6 +1214,7 @@ ("dinov2", "Dinov2Backbone"), ("dinov2_with_registers", "Dinov2WithRegistersBackbone"), ("focalnet", "FocalNetBackbone"), + ("hgnet_v2", "HGNetV2Backbone"), ("hiera", "HieraBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"), ("pvt_v2", "PvtV2Backbone"), diff --git a/mindone/transformers/models/d_fine/__init__.py b/mindone/transformers/models/d_fine/__init__.py new file mode 100644 index 0000000000..9ac736d2ef --- /dev/null +++ b/mindone/transformers/models/d_fine/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_d_fine import * diff --git a/mindone/transformers/models/d_fine/modeling_d_fine.py b/mindone/transformers/models/d_fine/modeling_d_fine.py new file mode 100644 index 0000000000..0013706eec --- /dev/null +++ b/mindone/transformers/models/d_fine/modeling_d_fine.py @@ -0,0 +1,2084 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/d_fine/modular_d_fine.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_d_fine.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Optional, Union + +from transformers.models.d_fine.configuration_d_fine import DFineConfig + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import Tensor, mint, nn, ops + +from ...activations import ACT2CLS, ACT2FN +from ...image_transforms import center_to_corners_format, corners_to_center_format +from ...mindspore_adapter import dtype_to_max +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ModelOutput, mindspore_int +from ...utils.backbone_utils import load_backbone + + +def multi_scale_deformable_attention_v2( + value: Tensor, + value_spatial_shapes: Tensor, + sampling_locations: Tensor, + attention_weights: Tensor, + num_points_list: list[int], + method="default", +) -> Tensor: + batch_size, _, num_heads, hidden_dim = value.shape + _, num_queries, num_heads, num_levels, num_points = sampling_locations.shape + value_list = ( + value.permute(0, 2, 3, 1) + .flatten(0, 1) + .split([height * width for height, width in value_spatial_shapes], dim=-1) + ) + # sampling_offsets [8, 480, 8, 12, 2] + if method == "default": + sampling_grids = 2 * sampling_locations - 1 + elif method == "discrete": + sampling_grids = sampling_locations + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_grids = sampling_grids.split(num_points_list, dim=-2) + sampling_value_list = [] + for level_id, (height, width) in enumerate(value_spatial_shapes): + # batch_size, height*width, num_heads, hidden_dim + # -> batch_size, height*width, num_heads*hidden_dim + # -> batch_size, num_heads*hidden_dim, height*width + # -> batch_size*num_heads, hidden_dim, height, width + value_l_ = value_list[level_id].reshape(batch_size * num_heads, hidden_dim, height, width) + # batch_size, num_queries, num_heads, num_points, 2 + # -> batch_size, num_heads, num_queries, num_points, 2 + # -> batch_size*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[level_id] + # batch_size*num_heads, hidden_dim, num_queries, num_points + if method == "default": + sampling_value_l_ = mint.nn.functional.grid_sample( + value_l_.float(), sampling_grid_l_.float(), mode="bilinear", padding_mode="zeros", align_corners=False + ).to(value_l_.dtype) + elif method == "discrete": + sampling_coord = (sampling_grid_l_ * ms.tensor([[width, height]]) + 0.5).to(ms.int64) + + # Separate clamping for x and y coordinates + sampling_coord_x = sampling_coord[..., 0].clamp(0, width - 1) + sampling_coord_y = sampling_coord[..., 1].clamp(0, height - 1) + + # Combine the clamped coordinates + sampling_coord = mint.stack([sampling_coord_x, sampling_coord_y], dim=-1) + sampling_coord = sampling_coord.reshape(batch_size * num_heads, num_queries * num_points_list[level_id], 2) + sampling_idx = mint.arange(sampling_coord.shape[0]).unsqueeze(-1).tile((1, sampling_coord.shape[1])) + sampling_value_l_ = value_l_[sampling_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] + sampling_value_l_ = sampling_value_l_.permute(0, 2, 1).reshape( + batch_size * 
num_heads, hidden_dim, num_queries, num_points_list[level_id] + ) + sampling_value_list.append(sampling_value_l_) + # (batch_size, num_queries, num_heads, num_levels, num_points) + # -> (batch_size, num_heads, num_queries, num_levels, num_points) + # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.permute(0, 2, 1, 3).reshape( + batch_size * num_heads, 1, num_queries, sum(num_points_list) + ) + output = ( + (mint.concat(sampling_value_list, dim=-1) * attention_weights) + .sum(-1) + .view(batch_size, num_heads * hidden_dim, num_queries) + ) + return output.swapaxes(1, 2).contiguous() + + +class DFineMultiscaleDeformableAttention(nn.Cell): + def __init__(self, config: DFineConfig): + """ + D-Fine version of multiscale deformable attention + """ + super().__init__() + self.d_model = config.d_model + self.n_heads = config.decoder_attention_heads + self.n_levels = config.num_feature_levels + self.offset_scale = config.decoder_offset_scale + self.decoder_method = config.decoder_method + self.n_points = config.decoder_n_points + + if isinstance(self.n_points, list): + num_points_list = self.n_points + else: + num_points_list = [self.n_points for _ in range(self.n_levels)] + + self.num_points_list = num_points_list + num_points_scale = [1 / n for n in self.num_points_list for _ in range(n)] + self.num_points_scale = ms.Parameter( + ms.tensor(num_points_scale, dtype=ms.float32), requires_grad=False, name="num_points_scale" + ) + + self.total_points = self.n_heads * sum(self.num_points_list) + + self.sampling_offsets = mint.nn.Linear(self.d_model, self.total_points * 2) + self.attention_weights = mint.nn.Linear(self.d_model, self.total_points) + + self.ms_deformable_attn_core = multi_scale_deformable_attention_v2 + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + reference_points=None, + encoder_hidden_states=None, + spatial_shapes=None, + spatial_shapes_list=None, + ) -> tuple[ms.Tensor, ms.Tensor]: + batch_size, num_queries, _ = hidden_states.shape + batch_size, sequence_length, _ = encoder_hidden_states.shape + + if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length: + raise ValueError( + "Make sure to align the spatial shapes with the sequence length of the encoder hidden states" + ) + + # Reshape for multi-head attention + value = encoder_hidden_states.reshape(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads) + if attention_mask is not None: + value = value.masked_fill(~attention_mask[..., None], float(0)) + + sampling_offsets: ms.Tensor = self.sampling_offsets(hidden_states) + sampling_offsets = sampling_offsets.reshape(batch_size, num_queries, self.n_heads, sum(self.num_points_list), 2) + + attention_weights = self.attention_weights(hidden_states).reshape( + batch_size, num_queries, self.n_heads, sum(self.num_points_list) + ) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = ms.tensor(spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.n_levels, 1, 2) + sampling_locations = ( + reference_points.reshape(batch_size, sequence_length, 1, self.n_levels, 1, 2) + + sampling_offsets / offset_normalizer + ) + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=hidden_states.dtype).unsqueeze(-1) + offset = sampling_offsets * 
num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." + ) + + output = self.ms_deformable_attn_core( + value, + spatial_shapes_list, + sampling_locations, + attention_weights, + self.num_points_list, + self.decoder_method, + ) + + return output, attention_weights + + +class DFineGate(nn.Cell): + def __init__(self, d_model: int): + super().__init__() + self.gate = mint.nn.Linear(2 * d_model, 2 * d_model) + self.norm = mint.nn.LayerNorm(d_model) + + def construct(self, second_residual: ms.Tensor, hidden_states: ms.Tensor) -> ms.Tensor: + gate_input = mint.cat([second_residual, hidden_states], dim=-1) + gates = mint.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + hidden_states = self.norm(gate1 * second_residual + gate2 * hidden_states) + return hidden_states + + +class DFineMultiheadAttention(nn.Cell): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the Deformable DETR paper). + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + if self.head_dim * num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self.k_proj = mint.nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = mint.nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = mint.nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = mint.nn.Linear(embed_dim, embed_dim, bias=bias) + + def _reshape(self, tensor: ms.Tensor, seq_len: int, batch_size: int): + return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2).contiguous() + + def with_pos_embed(self, tensor: ms.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_embeddings: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[ms.Tensor, Optional[ms.Tensor], Optional[tuple[ms.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, target_len, embed_dim = hidden_states.shape + # add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # get queries, keys and values + query_states = self.q_proj(hidden_states) * self.scaling + key_states = self._reshape(self.k_proj(hidden_states), -1, batch_size) + value_states = self._reshape(self.v_proj(hidden_states_original), -1, batch_size) + + proj_shape = (batch_size * self.num_heads, -1, self.head_dim) + query_states = self._reshape(query_states, target_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + source_len = key_states.shape[1] + + attn_weights = mint.bmm(query_states, key_states.swapaxes(1, 2)) + + if 
attn_weights.shape != (batch_size * self.num_heads, target_len, source_len): + raise ValueError( + f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" + f" {attn_weights.shape}" + ) + + # expand attention_mask + if attention_mask is not None: + # [seq_len, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] + attention_mask = attention_mask.broadcast_to((batch_size, 1, *attention_mask.shape)) + + if attention_mask is not None: + if attention_mask.shape != (batch_size, 1, target_len, source_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" + f" {attention_mask.shape}" + ) + attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask + attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) + else: + attn_weights_reshaped = None + + attn_probs = mint.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = mint.bmm(attn_probs, value_states) + + if attn_output.shape != (batch_size * self.num_heads, target_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" + f" {attn_output.shape}" + ) + + attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) + attn_output = attn_output.swapaxes(1, 2) + attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +class DFineDecoderLayer(nn.Cell): + def __init__(self, config: DFineConfig): + super().__init__() + # self-attention + self.self_attn = DFineMultiheadAttention( + embed_dim=config.d_model, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.decoder_activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = mint.nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + + # override the encoder attention module with d-fine version + self.encoder_attn = DFineMultiscaleDeformableAttention(config=config) + # feedforward neural networks + self.fc1 = mint.nn.Linear(config.d_model, config.decoder_ffn_dim) + self.fc2 = mint.nn.Linear(config.decoder_ffn_dim, config.d_model) + self.final_layer_norm = mint.nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + # gate + self.gateway = DFineGate(config.d_model) + + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: Optional[ms.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[ms.Tensor, Any, Any]: + """ + Args: + hidden_states (`ms.Tensor`): + Input to the layer of shape `(seq_len, batch, 
embed_dim)`. + position_embeddings (`ms.Tensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`ms.Tensor`, *optional*): + Reference points. + spatial_shapes (`ms.Tensor`, *optional*): + Spatial shapes. + level_start_index (`ms.Tensor`, *optional*): + Level start index. + encoder_hidden_states (`ms.Tensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`ms.Tensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + # Self Attention + hidden_states_2, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states_2 = mint.nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training) + hidden_states = hidden_states + hidden_states_2 + hidden_states = self.self_attn_layer_norm(hidden_states) + residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states = hidden_states if position_embeddings is None else hidden_states + position_embeddings + hidden_states_2, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + ) + + hidden_states_2 = mint.nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training) + hidden_states = self.gateway(residual, hidden_states_2) + + # Fully Connected + hidden_states_2 = self.activation_fn(self.fc1(hidden_states)) + hidden_states_2 = mint.nn.functional.dropout(hidden_states_2, p=self.activation_dropout, training=self.training) + hidden_states_2 = self.fc2(hidden_states_2) + hidden_states_2 = mint.nn.functional.dropout(hidden_states_2, p=self.dropout, training=self.training) + hidden_states = hidden_states + hidden_states_2 + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +@dataclass +class DFineModelOutput(ModelOutput): + r""" + last_hidden_state (`ms.Tensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`ms.Tensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). 
+ initial_reference_points (`ms.Tensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points used for the first decoder layer. + init_reference_points (`ms.Tensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`): + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_topk_bboxes (`ms.Tensor` of shape `(batch_size, sequence_length, 4)`): + Logits of predicted bounding boxes coordinates in the encoder stage. + enc_outputs_class (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`ms.Tensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Logits of predicted bounding boxes coordinates in the first stage. + denoising_meta_values (`dict`): + Extra dictionary for the denoising related values. + """ + + last_hidden_state: Optional[ms.Tensor] = None + intermediate_hidden_states: Optional[ms.Tensor] = None + intermediate_logits: Optional[ms.Tensor] = None + intermediate_reference_points: Optional[ms.Tensor] = None + intermediate_predicted_corners: Optional[ms.Tensor] = None + initial_reference_points: Optional[ms.Tensor] = None + decoder_hidden_states: Optional[tuple[ms.Tensor]] = None + decoder_attentions: Optional[tuple[ms.Tensor]] = None + cross_attentions: Optional[tuple[ms.Tensor]] = None + encoder_last_hidden_state: Optional[ms.Tensor] = None + encoder_hidden_states: Optional[tuple[ms.Tensor]] = None + encoder_attentions: Optional[tuple[ms.Tensor]] = None + init_reference_points: Optional[ms.Tensor] = None + enc_topk_logits: Optional[ms.Tensor] = None + enc_topk_bboxes: Optional[ms.Tensor] = None + enc_outputs_class: Optional[ms.Tensor] = None + enc_outputs_coord_logits: Optional[ms.Tensor] = None + denoising_meta_values: Optional[dict] = None + + +@dataclass +class DFineObjectDetectionOutput(ModelOutput): + r""" + loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `labels` are provided)): + Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a + bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized + scale-invariant IoU loss. + loss_dict (`Dict`, *optional*): + A dictionary containing the individual losses. Useful for logging. + logits (`ms.Tensor` of shape `(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (`ms.Tensor` of shape `(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These + values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding + possible padding). 
You can use [`~DFineImageProcessor.post_process_object_detection`] to retrieve the + unnormalized (absolute) bounding boxes. + auxiliary_outputs (`list[Dict]`, *optional*): + Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`) + and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and + `pred_boxes`) for each decoder layer. + last_hidden_state (`ms.Tensor` of shape `(batch_size, num_queries, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. + intermediate_hidden_states (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + init_reference_points (`ms.Tensor` of shape `(batch_size, num_queries, 4)`): + Initial reference points sent through the Transformer decoder. + enc_topk_logits (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Logits of predicted bounding boxes coordinates in the encoder. + enc_topk_bboxes (`ms.Tensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Logits of predicted bounding boxes coordinates in the encoder. + enc_outputs_class (`ms.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are + picked as region proposals in the first stage. Output of bounding box binary classification (i.e. + foreground and background). + enc_outputs_coord_logits (`ms.Tensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`): # noqa: E501 + Logits of predicted bounding boxes coordinates in the first stage. 
+ denoising_meta_values (`dict`): + Extra dictionary for the denoising related values + """ + + loss: Optional[ms.Tensor] = None + loss_dict: Optional[dict] = None + logits: Optional[ms.Tensor] = None + pred_boxes: Optional[ms.Tensor] = None + auxiliary_outputs: Optional[list[dict]] = None + last_hidden_state: Optional[ms.Tensor] = None + intermediate_hidden_states: Optional[ms.Tensor] = None + intermediate_logits: Optional[ms.Tensor] = None + intermediate_reference_points: Optional[ms.Tensor] = None + intermediate_predicted_corners: Optional[ms.Tensor] = None + initial_reference_points: Optional[ms.Tensor] = None + decoder_hidden_states: Optional[tuple[ms.Tensor]] = None + decoder_attentions: Optional[tuple[ms.Tensor]] = None + cross_attentions: Optional[tuple[ms.Tensor]] = None + encoder_last_hidden_state: Optional[ms.Tensor] = None + encoder_hidden_states: Optional[tuple[ms.Tensor]] = None + encoder_attentions: Optional[tuple[ms.Tensor]] = None + init_reference_points: Optional[tuple[ms.Tensor]] = None + enc_topk_logits: Optional[ms.Tensor] = None + enc_topk_bboxes: Optional[ms.Tensor] = None + enc_outputs_class: Optional[ms.Tensor] = None + enc_outputs_coord_logits: Optional[ms.Tensor] = None + denoising_meta_values: Optional[dict] = None + + +class DFineFrozenBatchNorm2d(nn.Cell): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without which any other models than + torchvision.models.resnet[18,34,50,101] produce nans. + """ + + def __init__(self, n): + super().__init__() + self.weight = ms.Parameter(mint.ones(n), requires_grad=False, name="weight") + self.bias = ms.Parameter(mint.zeros(n), requires_grad=False, name="bias") + self.running_mean = ms.Parameter(mint.zeros(n), requires_grad=False, name="running_mean") + self.running_var = ms.Parameter(mint.ones(n), requires_grad=False, name="running_var") + + # TODO + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def construct(self, x): + # move reshapes to the beginning + # to make it user-friendly + weight = self.weight.reshape(1, -1, 1, 1) + bias = self.bias.reshape(1, -1, 1, 1) + running_var = self.running_var.reshape(1, -1, 1, 1) + running_mean = self.running_mean.reshape(1, -1, 1, 1) + epsilon = 1e-5 + scale = weight * (running_var + epsilon).rsqrt() + bias = bias - running_mean * scale + return x * scale + bias + + +def replace_batch_norm(model): + r""" + Recursively replace all `mint.nn.BatchNorm2d` with `DFineFrozenBatchNorm2d`. + + Args: + model (nn.Cell): + input model + """ + for name, module in model.name_cells().items(): + if isinstance(module, mint.nn.BatchNorm2d): + new_module = DFineFrozenBatchNorm2d(module.num_features) + + # model._modules[name] = new_module + setattr(model, name, new_module) + + if len(list(module.name_cells())) > 0: + replace_batch_norm(module) + + +class DFineConvEncoder(nn.Cell): + """ + Convolutional backbone using the modeling_d_fine_resnet.py. + + mint.nn.BatchNorm2d layers are replaced by DFineFrozenBatchNorm2d as defined above. 
+ https://github.com/lyuwenyu/RT-DETR/blob/main/DFine_pytorch/src/nn/backbone/presnet.py#L142 + """ + + def __init__(self, config): + super().__init__() + + backbone = load_backbone(config) + + if config.freeze_backbone_batch_norms: + # replace batch norm by frozen batch norm + replace_batch_norm(backbone) + self.model = backbone + self.intermediate_channel_sizes = self.model.channels + + def construct(self, pixel_values: ms.Tensor, pixel_mask: ms.Tensor): + # send pixel_values through the model to get list of feature maps + features = self.model(pixel_values).feature_maps + + out = [] + for feature_map in features: + # downsample pixel_mask to match shape of corresponding feature_map + mask = mint.nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(ms.bool_)[0] + out.append((feature_map, mask)) + return out + + +class DFineEncoderLayer(nn.Cell): + def __init__(self, config: DFineConfig): + super().__init__() + self.normalize_before = config.normalize_before + + # self-attention + self.self_attn = DFineMultiheadAttention( + embed_dim=config.encoder_hidden_dim, + num_heads=config.num_attention_heads, + dropout=config.dropout, + ) + self.self_attn_layer_norm = mint.nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.encoder_activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = mint.nn.Linear(config.encoder_hidden_dim, config.encoder_ffn_dim) + self.fc2 = mint.nn.Linear(config.encoder_ffn_dim, config.encoder_hidden_dim) + self.final_layer_norm = mint.nn.LayerNorm(config.encoder_hidden_dim, eps=config.layer_norm_eps) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: ms.Tensor, + position_embeddings: Optional[ms.Tensor] = None, + output_attentions: bool = False, + **kwargs, + ): + """ + Args: + hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`ms.Tensor`): attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + position_embeddings (`ms.Tensor`, *optional*): + Object queries (also called content embeddings), to be added to the hidden states. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + if self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = mint.nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + if self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + residual = hidden_states + + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = mint.nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + + hidden_states = mint.nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + if not self.normalize_before: + hidden_states = self.final_layer_norm(hidden_states) + + if self.training: + if mint.isinf(hidden_states).any() or mint.isnan(hidden_states).any(): + clamp_value = dtype_to_max(hidden_states.dtype) - 1000 + hidden_states = mint.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +def inverse_sigmoid(x, eps=1e-5): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return mint.log(x1 / x2) + + +def get_contrastive_denoising_training_group( + targets, + num_classes, + num_queries, + class_embed, + num_denoising_queries=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, +): + """ + Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes. + + Args: + targets (`list[dict]`): + The target objects, each containing 'class_labels' and 'boxes' for objects in an image. + num_classes (`int`): + Total number of classes in the dataset. + num_queries (`int`): + Number of query slots in the transformer. + class_embed (`callable`): + A function or a model layer to embed class labels. + num_denoising_queries (`int`, *optional*, defaults to 100): + Number of denoising queries. + label_noise_ratio (`float`, *optional*, defaults to 0.5): + Ratio of noise applied to labels. + box_noise_scale (`float`, *optional*, defaults to 1.0): + Scale of noise applied to bounding boxes. + Returns: + `tuple` comprising various elements: + - **input_query_class** (`ms.Tensor`) -- + Class queries with applied label noise. + - **input_query_bbox** (`ms.Tensor`) -- + Bounding box queries with applied box noise. + - **attn_mask** (`ms.Tensor`) -- + Attention mask for separating denoising and reconstruction queries. + - **denoising_meta_values** (`dict`) -- + Metadata including denoising positive indices, number of groups, and split sizes. 
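To make the denoising group arithmetic above concrete, here is a minimal plain-Python sketch (toy numbers, illustrative only, not part of the diff):

```python
# How the denoising group sizes are derived, mirroring the arithmetic in
# get_contrastive_denoising_training_group (toy batch, made-up numbers).
num_denoising_queries = 100          # e.g. config.num_denoising
num_ground_truths = [3, 7, 5]        # ground-truth boxes per image in the batch
max_gt_num = max(num_ground_truths)  # 7

num_groups = max(num_denoising_queries // max_gt_num, 1)  # 14

# Each group holds a positive and a negative copy of every (padded) ground truth,
# so the denoising part of the query set occupies this many slots:
total_denoising_queries = max_gt_num * 2 * num_groups     # 7 * 2 * 14 = 196
print(total_denoising_queries)
```

The attention mask built in the function body below then keeps these denoising slots from attending across groups and keeps the matching queries from seeing them at all.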
+ """ + + if num_denoising_queries <= 0: + return None, None, None, None + + num_ground_truths = [len(t["class_labels"]) for t in targets] + + max_gt_num = max(num_ground_truths) + if max_gt_num == 0: + return None, None, None, None + + num_groups_denoising_queries = num_denoising_queries // max_gt_num + num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries + # pad gt to max_num of a batch + batch_size = len(num_ground_truths) + + input_query_class = mint.full([batch_size, max_gt_num], num_classes, dtype=ms.int32) + input_query_bbox = mint.zeros([batch_size, max_gt_num, 4]) + pad_gt_mask = mint.zeros([batch_size, max_gt_num], dtype=ms.bool_) + + for i in range(batch_size): + num_gt = num_ground_truths[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]["class_labels"] + input_query_bbox[i, :num_gt] = targets[i]["boxes"] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile((1, 2 * num_groups_denoising_queries)) + input_query_bbox = input_query_bbox.tile((1, 2 * num_groups_denoising_queries, 1)) + pad_gt_mask = pad_gt_mask.tile((1, 2 * num_groups_denoising_queries)) + # positive and negative mask + negative_gt_mask = mint.zeros([batch_size, max_gt_num * 2, 1]) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + denoise_positive_idx = mint.nonzero(positive_gt_mask)[:, 1] + denoise_positive_idx = mint.split( + denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths] + ) + # total denoising queries + num_denoising_queries = mindspore_int(max_gt_num * 2 * num_groups_denoising_queries) + + if label_noise_ratio > 0: + mask = mint.rand_like(input_query_class, dtype=ms.float32) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = mint.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = mint.where(mask & pad_gt_mask, new_label, input_query_class) + + if box_noise_scale > 0: + known_bbox = center_to_corners_format(input_query_bbox) + diff = mint.tile(input_query_bbox[..., 2:] * 0.5, (1, 1, 2)) * box_noise_scale + rand_sign = mint.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = mint.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = corners_to_center_format(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + input_query_class = class_embed(input_query_class) + + target_size = num_denoising_queries + num_queries + attn_mask = mint.full([target_size, target_size], False, dtype=ms.bool_) + # match query cannot see the reconstruction + attn_mask[num_denoising_queries:, :num_denoising_queries] = True + + # reconstructions cannot see each other + for i in range(num_groups_denoising_queries): + idx_block_start = max_gt_num * 2 * i + idx_block_end = max_gt_num * 2 * (i + 1) + attn_mask[idx_block_start:idx_block_end, :idx_block_start] = True + attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = True + + denoising_meta_values = { + "dn_positive_idx": denoise_positive_idx, + "dn_num_group": num_groups_denoising_queries, + "dn_num_split": 
[num_denoising_queries, num_queries], + } + + return input_query_class, input_query_bbox, attn_mask, denoising_meta_values + + +def _get_clones(partial_module, N): + return nn.CellList([partial_module() for i in range(N)]) + + +class DFinePreTrainedModel(PreTrainedModel): + config: DFineConfig + base_model_prefix = "d_fine" + main_input_name = "pixel_values" + _no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"] + + def _init_weights(self, module): + """Initialize the weights""" + # initialize linear layer bias value according to a given probability value. + pass + + +class DFineIntegral(nn.Cell): + """ + A static layer that calculates integral results from a distribution. + + This layer computes the target location using the formula: `sum{Pr(n) * W(n)}`, + where Pr(n) is the softmax probability vector representing the discrete + distribution, and W(n) is the non-uniform Weighting Function. + + Args: + max_num_bins (int): Max number of the discrete bins. Default is 32. + It can be adjusted based on the dataset or task requirements. + """ + + def __init__(self, config: DFineConfig): + super().__init__() + self.max_num_bins = config.max_num_bins + + def construct(self, pred_corners: ms.Tensor, project: ms.Tensor) -> ms.Tensor: + batch_size, num_queries, _ = pred_corners.shape + pred_corners = F.softmax(pred_corners.reshape(-1, self.max_num_bins + 1), dim=1) + pred_corners = F.linear(pred_corners, project).reshape(-1, 4) + pred_corners = pred_corners.reshape(batch_size, num_queries, -1) + return pred_corners + + +@dataclass +class DFineDecoderOutput(ModelOutput): + r""" + intermediate_hidden_states (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`): + Stacked intermediate hidden states (output of each layer of the decoder). + intermediate_logits (`ms.Tensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`): + Stacked intermediate logits (logits of each layer of the decoder). + intermediate_reference_points (`ms.Tensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`): + Stacked intermediate reference points (reference points of each layer of the decoder). + intermediate_predicted_corners (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked intermediate predicted corners (predicted corners of each layer of the decoder). + initial_reference_points (`ms.Tensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`): + Stacked initial reference points (initial reference points of each layer of the decoder). + cross_attentions (`tuple(ms.Tensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): # noqa: E501 + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. 
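The `DFineIntegral` head defined above turns each edge's discrete distribution into a single offset by taking an expectation over the bin values W(n) produced by `weighting_function`. A minimal sketch of that reduction, written in the elementwise form of the `F.linear` call (toy tensors, illustrative only, not part of the diff):

```python
import mindspore as ms
import mindspore.mint.nn.functional as F

max_num_bins = 4
# Hypothetical logits for the four edges of one query box:
# shape (4 edges, max_num_bins + 1 bins).
pred_corners = ms.tensor([[2.0, 0.5, 0.1, 0.1, 0.1]] * 4, ms.float32)
# Stand-in for weighting_function(...): the bin values W(n).
project = ms.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], ms.float32)

probs = F.softmax(pred_corners, dim=-1)   # Pr(n) for each edge
offsets = (probs * project).sum(-1)       # sum{Pr(n) * W(n)}, one offset per edge
print(offsets.shape)                      # (4,)
```

In the decoder these per-edge offsets are passed to `distance2bbox` (defined below) to update the reference boxes.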
+ """ + + last_hidden_state: Optional[ms.Tensor] = None + intermediate_hidden_states: Optional[ms.Tensor] = None + intermediate_logits: Optional[ms.Tensor] = None + intermediate_reference_points: Optional[ms.Tensor] = None + intermediate_predicted_corners: Optional[ms.Tensor] = None + initial_reference_points: Optional[ms.Tensor] = None + hidden_states: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[ms.Tensor]] = None + cross_attentions: Optional[tuple[ms.Tensor]] = None + + +def weighting_function(max_num_bins: int, up: ms.Tensor, reg_scale: int) -> ms.Tensor: + """ + Generates the non-uniform Weighting Function W(n) for bounding box regression. + + Args: + max_num_bins (int): Max number of the discrete bins. + up (Tensor): Controls upper bounds of the sequence, + where maximum offset is ±up * H / W. + reg_scale (float): Controls the curvature of the Weighting Function. + Larger values result in flatter weights near the central axis W(max_num_bins/2)=0 + and steeper weights at both ends. + Returns: + Tensor: Sequence of Weighting Function. + """ + upper_bound1 = abs(up[0]) * abs(reg_scale) + upper_bound2 = abs(up[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2)) + left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)] + values = [-upper_bound2] + left_values + [mint.zeros_like(up[0][None])] + right_values + [upper_bound2] + values = mint.cat(values, 0) + return values + + +def distance2bbox(points, distance: ms.Tensor, reg_scale: float) -> ms.Tensor: + """ + Decodes edge-distances into bounding box coordinates. + + Args: + points (`ms.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height] + distance (`ms.Tensor`): + (batch_size, num_boxes, 4) or (num_boxes, 4), representing distances from the point to the left, top, right, and bottom boundaries. + reg_scale (`float`): + Controls the curvature of the Weighting Function. + Returns: + `ms.Tensor`: Bounding boxes in (batch_size, num_boxes, 4) or (num_boxes, 4) format, representing [x_center, y_center, width, height] + """ + reg_scale = abs(reg_scale) + top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale) + top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale) + bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale) + bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale) + + bboxes = mint.stack([top_left_x, top_left_y, bottom_right_x, bottom_right_y], -1) + + return corners_to_center_format(bboxes) + + +class DFineDecoder(DFinePreTrainedModel): + """ + D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR). + + This decoder refines object detection predictions through iterative updates across multiple layers, + utilizing attention mechanisms, location quality estimators, and distribution refinement techniques + to improve bounding box accuracy and robustness. 
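The per-layer refinement relies on `distance2bbox` above: each predicted value is a signed offset of one box edge, measured in units of the reference box size divided by `reg_scale`, so all-zero distances reproduce the reference box exactly. A short plain-Python walk-through of the same arithmetic (made-up numbers, illustrative only, not part of the diff):

```python
# One reference box in (center_x, center_y, width, height) form.
cx, cy, w, h = 0.5, 0.5, 0.2, 0.4
# Predicted signed distances for the left, top, right, bottom edges.
d_left, d_top, d_right, d_bottom = 0.1, 0.0, -0.1, 0.2
reg_scale = 4.0

top_left_x = cx - (0.5 * reg_scale + d_left) * (w / reg_scale)
top_left_y = cy - (0.5 * reg_scale + d_top) * (h / reg_scale)
bottom_right_x = cx + (0.5 * reg_scale + d_right) * (w / reg_scale)
bottom_right_y = cy + (0.5 * reg_scale + d_bottom) * (h / reg_scale)

# With zero distances this reduces to x0 = cx - w/2, x1 = cx + w/2 (and likewise
# for y), i.e. the unchanged reference box; the real function finally converts
# the corners back to center format with corners_to_center_format.
print(top_left_x, top_left_y, bottom_right_x, bottom_right_y)
```

In the decoder loop, the refined boxes produced this way are detached and handed to the next layer as updated reference points.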
+ """ + + def __init__(self, config: DFineConfig): + super().__init__(config) + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + + self.dropout = config.dropout + self.layers = nn.CellList( + [DFineDecoderLayer(config) for _ in range(config.decoder_layers)] + + [DFineDecoderLayer(config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + self.query_pos_head = DFineMLPPredictionHead(config, 4, 2 * config.d_model, config.d_model, num_layers=2) + + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.reg_scale = ms.Parameter(ms.tensor([config.reg_scale]), requires_grad=False, name="reg_scale") + self.max_num_bins = config.max_num_bins + self.d_model = config.d_model + self.layer_scale = config.layer_scale + self.pre_bbox_head = DFineMLP(config.hidden_size, config.hidden_size, 4, 3) + self.integral = DFineIntegral(config) + self.num_head = config.decoder_attention_heads + self.up = ms.Parameter(ms.tensor([config.up]), requires_grad=False, name="up") + self.lqe_layers = nn.CellList([DFineLQE(config) for _ in range(config.decoder_layers)]) + + # Initialize weights and apply final processing + self.post_init() + + def construct( + self, + encoder_hidden_states: ms.Tensor, + reference_points: ms.Tensor, + inputs_embeds: ms.Tensor, + spatial_shapes, + level_start_index=None, + spatial_shapes_list=None, + output_hidden_states=None, + encoder_attention_mask=None, + memory_mask=None, + output_attentions=None, + return_dict=None, + ) -> DFineDecoderOutput: + r""" + Args: + inputs_embeds (`ms.Tensor` of shape `(batch_size, num_queries, hidden_size)`): + The query embeddings that are passed into the decoder. + encoder_hidden_states (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected + in `[0, 1]`: + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + position_embeddings (`ms.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings that are added to the queries and keys in each self-attention layer. + reference_points (`ms.Tensor` of shape `(batch_size, num_queries, 4)` is `as_two_stage` else `(batch_size, num_queries, 2)` or , *optional*): + Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area. + spatial_shapes (`ms.Tensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of the feature maps. + level_start_index (`ms.Tensor` of shape `(num_feature_levels)`, *optional*): + Indexes for the start of each feature level. In range `[0, sequence_length]`. + valid_ratios (`ms.Tensor` of shape `(batch_size, num_feature_levels, 2)`, *optional*): + Ratio of valid area in each feature level. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + intermediate_predicted_corners = () + initial_reference_points = () + + output_detach = pred_corners_undetach = 0 + + project = weighting_function(self.max_num_bins, self.up, self.reg_scale) + ref_points_detach = F.sigmoid(reference_points) + + for i, decoder_layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = decoder_layer( + hidden_states=hidden_states, + position_embeddings=query_pos_embed, + reference_points=ref_points_input, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = output[0] + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + new_reference_points = F.sigmoid(self.pre_bbox_head(output[0]) + inverse_sigmoid(ref_points_detach)) + ref_points_initial = ops.stop_gradient(new_reference_points) + + # Refine bounding box corners using FDR, integrating previous layer's corrections + if self.bbox_embed is not None: + pred_corners = self.bbox_embed[i](hidden_states + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox(ref_points_initial, self.integral(pred_corners, project), self.reg_scale) + pred_corners_undetach = pred_corners + ref_points_detach = ops.stop_gradient(inter_ref_bbox) + + output_detach = ops.stop_gradient(hidden_states) + + intermediate += (hidden_states,) + + if self.class_embed is not None and (self.training or i == self.eval_idx): + scores = self.class_embed[i](hidden_states) + # Add initial logits and reference points with pre-bbox head + if i == 0: + intermediate_logits += (scores,) + intermediate_reference_points += (new_reference_points,) + # Lqe does not affect the performance here. 
+ scores = self.lqe_layers[i](scores, pred_corners) + intermediate_logits += (scores,) + intermediate_reference_points += (inter_ref_bbox,) + initial_reference_points += (ref_points_initial,) + intermediate_predicted_corners += (pred_corners,) + + if output_attentions: + all_self_attns += (output[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (output[2],) + + # Keep batch_size as first dimension + intermediate = mint.stack(intermediate) + if self.class_embed is not None and self.bbox_embed is not None: + intermediate_logits = mint.stack(intermediate_logits, dim=1) + intermediate_predicted_corners = mint.stack(intermediate_predicted_corners, dim=1) + initial_reference_points = mint.stack(initial_reference_points, dim=1) + intermediate_reference_points = mint.stack(intermediate_reference_points, dim=1) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + intermediate_predicted_corners, + initial_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + + return DFineDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=intermediate, + intermediate_logits=intermediate_logits, + intermediate_reference_points=intermediate_reference_points, + intermediate_predicted_corners=intermediate_predicted_corners, + initial_reference_points=initial_reference_points, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class DFineModel(DFinePreTrainedModel): + def __init__(self, config: DFineConfig): + super().__init__(config) + + # Create backbone + self.backbone = DFineConvEncoder(config) + intermediate_channel_sizes = self.backbone.intermediate_channel_sizes + num_backbone_outs = len(config.decoder_in_channels) + encoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = intermediate_channel_sizes[_] + encoder_input_proj_list.append( + nn.SequentialCell( + mint.nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False), + mint.nn.BatchNorm2d(config.encoder_hidden_dim), + ) + ) + self.encoder_input_proj = nn.CellList(encoder_input_proj_list) + self.encoder = DFineHybridEncoder(config=config) + + # denoising part + if config.num_denoising > 0: + self.denoising_class_embed = mint.nn.Embedding( + config.num_labels + 1, config.d_model, padding_idx=config.num_labels + ) + + # decoder embedding + if config.learn_initial_query: + self.weight_embedding = mint.nn.Embedding(config.num_queries, config.d_model) + + # encoder head + self.enc_output = nn.SequentialCell( + mint.nn.Linear(config.d_model, config.d_model), + mint.nn.LayerNorm(config.d_model, eps=config.layer_norm_eps), + ) + self.enc_score_head = mint.nn.Linear(config.d_model, config.num_labels) + self.enc_bbox_head = DFineMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) + + # init encoder output anchors and valid_mask + if config.anchor_image_size: + self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype) + num_backbone_outs = len(config.decoder_in_channels) + decoder_input_proj_list = [] + for _ in range(num_backbone_outs): + in_channels = config.decoder_in_channels[_] + decoder_input_proj_list.append( + nn.SequentialCell( + mint.nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False), + 
mint.nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + for _ in range(config.num_feature_levels - num_backbone_outs): + decoder_input_proj_list.append( + nn.SequentialCell( + mint.nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False), + mint.nn.BatchNorm2d(config.d_model, config.batch_norm_eps), + ) + ) + in_channels = config.d_model + self.decoder = DFineDecoder(config) + decoder_input_proj = [] + in_channels = config.decoder_in_channels[-1] + for _ in range(num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(mint.nn.Identity()) + else: + conv = mint.nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False) + batchnorm = mint.nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.SequentialCell(conv, batchnorm)) + for _ in range(config.num_feature_levels - num_backbone_outs): + if config.hidden_size == config.decoder_in_channels[-1]: + decoder_input_proj.append(mint.nn.Identity()) + else: + conv = mint.nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False) + batchnorm = mint.nn.BatchNorm2d(config.d_model, config.batch_norm_eps) + decoder_input_proj.append(nn.SequentialCell(conv, batchnorm)) + self.decoder_input_proj = nn.CellList(decoder_input_proj) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def freeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(False) + + def unfreeze_backbone(self): + for param in self.backbone.parameters(): + param.requires_grad_(True) + + def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=ms.float32): + if spatial_shapes is None: + spatial_shapes = [ + [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)] + for s in self.config.feat_strides + ] + anchors = [] + for level, (height, width) in enumerate(spatial_shapes): + grid_y, grid_x = mint.meshgrid( + mint.arange(end=height).to(dtype), mint.arange(end=width).to(dtype), indexing="ij" + ) + grid_xy = mint.stack([grid_x, grid_y], -1) + grid_xy = grid_xy.unsqueeze(0) + 0.5 + grid_xy[..., 0] /= width + grid_xy[..., 1] /= height + wh = mint.ones_like(grid_xy) * grid_size * (2.0**level) + anchors.append(mint.concat([grid_xy, wh], -1).reshape(-1, height * width, 4)) + # define the valid range for anchor coordinates + eps = 1e-2 + anchors = mint.concat(anchors, 1) + valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) + anchors = mint.log(anchors / (1 - anchors)) + anchors = mint.where(valid_mask, anchors, ms.tensor(dtype_to_max(dtype), dtype=dtype)) + + return anchors, valid_mask + + def construct( + self, + pixel_values: ms.Tensor, + pixel_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple[ms.Tensor], DFineModelOutput]: + r""" + inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you + can choose to directly pass a flattened representation of an image. 
+ decoder_inputs_embeds (`ms.Tensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an + embedded representation. + labels (`list[Dict]` of len `(batch_size,)`, *optional*): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the + following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch + respectively). The class labels themselves should be a `ms.Tensor` of len `(number of bounding boxes + in the image,)` and the boxes a `ms.Tensor` of shape `(number of bounding boxes in the image, 4)`. + + Examples: + + ```python + >>> from mindone.transformers import AutoImageProcessor, DFineModel + >>> from PIL import Image + >>> import requests + >>> import mindspore as ms + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("PekingU/DFine_r50vd") + >>> model = DFineModel.from_pretrained("PekingU/DFine_r50vd") + + >>> inputs = image_processor(images=image, return_tensors="ms") + + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 300, 256] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, num_channels, height, width = pixel_values.shape + + if pixel_mask is None: + pixel_mask = mint.ones(((batch_size, height, width))) + + features = self.backbone(pixel_values, pixel_mask) + + proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)] + + if encoder_outputs is None: + encoder_outputs = self.encoder( + proj_feats, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if output_hidden_states else None, + attentions=encoder_outputs[2] + if len(encoder_outputs) > 2 + else encoder_outputs[1] + if output_attentions + else None, + ) + + # Equivalent to def _get_encoder_input + # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_pytorch/src/zoo/DFine/DFine_decoder.py#L412 + sources = [] + for level, source in enumerate(encoder_outputs[0]): + sources.append(self.decoder_input_proj[level](source)) + + # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage + if self.config.num_feature_levels > len(sources): + _len_sources = len(sources) + sources.append(self.decoder_input_proj[_len_sources](encoder_outputs[0])[-1]) + for i in range(_len_sources + 1, self.config.num_feature_levels): + sources.append(self.decoder_input_proj[i](encoder_outputs[0][-1])) + + # Prepare encoder inputs (by flattening) + source_flatten = [] + spatial_shapes_list = [] + spatial_shapes = mint.empty((len(sources), 2), dtype=ms.int64) + for level, 
source in enumerate(sources): + height, width = source.shape[-2:] + spatial_shapes[level, 0] = height + spatial_shapes[level, 1] = width + spatial_shapes_list.append((height, width)) + source = source.flatten(2).swapaxes(1, 2) + source_flatten.append(source) + source_flatten = mint.cat(source_flatten, 1) + level_start_index = mint.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + + # prepare denoising training + if self.training and self.config.num_denoising > 0 and labels is not None: + ( + denoising_class, + denoising_bbox_unact, + attention_mask, + denoising_meta_values, + ) = get_contrastive_denoising_training_group( + targets=labels, + num_classes=self.config.num_labels, + num_queries=self.config.num_queries, + class_embed=self.denoising_class_embed, + num_denoising_queries=self.config.num_denoising, + label_noise_ratio=self.config.label_noise_ratio, + box_noise_scale=self.config.box_noise_scale, + ) + else: + denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None + + batch_size = len(source_flatten) + dtype = source_flatten.dtype + + # prepare input for decoder + if self.training or self.config.anchor_image_size is None: + # Pass spatial_shapes as tuple to make it hashable and make sure + # lru_cache is working for generate_anchors() + spatial_shapes_tuple = tuple(spatial_shapes_list) + anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, dtype=dtype) + else: + anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = anchors.to(dtype), valid_mask.to(dtype) + + # use the valid_mask to selectively retain values in the feature map where the mask is `True` + memory = valid_mask.to(source_flatten.dtype) * source_flatten + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = mint.topk(enc_outputs_class.max(-1)[0], self.config.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_logits.gather( + dim=1, index=topk_ind.unsqueeze(-1).tile((1, 1, enc_outputs_coord_logits.shape[-1])) + ) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = mint.concat([denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather( + dim=1, index=topk_ind.unsqueeze(-1).tile((1, 1, enc_outputs_class.shape[-1])) + ) + + # extract region features + if self.config.learn_initial_query: + target = self.weight_embedding.tile((batch_size, 1, 1)) + else: + target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).tile((1, 1, output_memory.shape[-1]))) + target = ops.stop_gradient(target) + + if denoising_class is not None: + target = mint.concat([denoising_class, target], 1) + + init_reference_points = ops.stop_gradient(reference_points_unact) + + # decoder + decoder_outputs = self.decoder( + inputs_embeds=target, + encoder_hidden_states=source_flatten, + encoder_attention_mask=attention_mask, + reference_points=init_reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + enc_outputs = tuple( + value + for value in [enc_topk_logits, enc_topk_bboxes, enc_outputs_class, enc_outputs_coord_logits] + if value is not None + ) + dn_outputs = 
tuple(value if value is not None else None for value in [denoising_meta_values]) + tuple_outputs = decoder_outputs + encoder_outputs + (init_reference_points,) + enc_outputs + dn_outputs + + return tuple_outputs + + return DFineModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + intermediate_logits=decoder_outputs.intermediate_logits, + intermediate_reference_points=decoder_outputs.intermediate_reference_points, + intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners, + initial_reference_points=decoder_outputs.initial_reference_points, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + init_reference_points=init_reference_points, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + enc_outputs_class=enc_outputs_class, + enc_outputs_coord_logits=enc_outputs_coord_logits, + denoising_meta_values=denoising_meta_values, + ) + + +class DFineForObjectDetection(DFinePreTrainedModel): + # When using clones, all layers > 0 will be clones, but layer 0 *is* required + _tied_weights_keys = ["bbox_embed", "class_embed"] + # We can't initialize the model on meta device as some weights are modified during the initialization + _no_split_modules = None + + def __init__(self, config: DFineConfig): + super().__init__(config) + + # D-FINE encoder-decoder model + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx + self.model = DFineModel(config) + scaled_dim = round(config.layer_scale * config.hidden_size) + num_pred = config.decoder_layers + self.class_embed = nn.CellList([mint.nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.CellList( + [ + DFineMLP(config.hidden_size, config.hidden_size, 4 * (config.max_num_bins + 1), 3) + for _ in range(self.eval_idx + 1) + ] + + [ + DFineMLP(scaled_dim, scaled_dim, 4 * (config.max_num_bins + 1), 3) + for _ in range(config.decoder_layers - self.eval_idx - 1) + ] + ) + + # here self.model.decoder.bbox_embed is null, but not self.bbox_embed + self.model.decoder.class_embed = self.class_embed + self.model.decoder.bbox_embed = self.bbox_embed + + # Initialize weights and apply final processing + self.post_init() + + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
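+ # Each auxiliary output is a {"logits", "pred_boxes"} dict for one intermediate decoder layer.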
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)] + + def construct( + self, + pixel_values: ms.Tensor, + pixel_mask: Optional[ms.Tensor] = None, + encoder_outputs: Optional[ms.Tensor] = None, + inputs_embeds: Optional[ms.Tensor] = None, + decoder_inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[list[dict]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[tuple[ms.Tensor], DFineObjectDetectionOutput]: + r""" + Example: + + ```python + >>> import mindspore as ms + >>> from mindone.transformers.image_utils import load_image + >>> from mindone.transformers import AutoImageProcessor, DFineForObjectDetection + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = load_image(url) + + >>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco") + >>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="ms") + + >>> # forward pass + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> list(logits.shape) + [1, 300, 80] + + >>> boxes = outputs.pred_boxes + >>> list(boxes.shape) + [1, 300, 4] + + >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax) + >>> target_sizes = ms.tensor([image.size[::-1]]) + >>> results = image_processor.post_process_object_detection(outputs, threshold=0.9, target_sizes=target_sizes) + >>> result = results[0] # first image in batch + + >>> for score, label, box in zip(result["scores"], result["labels"], result["boxes"]): + ... box = [round(i, 2) for i in box.tolist()] + ... print( + ... f"Detected {model.config.id2label[label.item()]} with confidence " + ... f"{round(score.item(), 3)} at location {box}" + ... 
) + Detected cat with confidence 0.958 at location [344.49, 23.4, 639.84, 374.27] + Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33] + Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57] + Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74] + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + pixel_values, + pixel_mask=pixel_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + denoising_meta_values = outputs.denoising_meta_values if return_dict else outputs[-1] if self.training else None + + outputs_class = outputs.intermediate_logits if return_dict else outputs[2] + outputs_coord = outputs.intermediate_reference_points if return_dict else outputs[3] + predicted_corners = outputs.intermediate_predicted_corners if return_dict else outputs[4] + initial_reference_points = outputs.initial_reference_points if return_dict else outputs[5] + + logits = outputs_class[:, -1] + pred_boxes = outputs_coord[:, -1] + + loss, loss_dict, auxiliary_outputs, enc_topk_logits, enc_topk_bboxes = None, None, None, None, None + if labels is not None: + enc_topk_logits = outputs.enc_topk_logits if return_dict else outputs[-5] + enc_topk_bboxes = outputs.enc_topk_bboxes if return_dict else outputs[-4] + loss, loss_dict, auxiliary_outputs = self.loss_function( + logits, + labels, + pred_boxes, + self.config, + outputs_class, + outputs_coord, + enc_topk_logits=enc_topk_logits, + enc_topk_bboxes=enc_topk_bboxes, + denoising_meta_values=denoising_meta_values, + predicted_corners=predicted_corners, + initial_reference_points=initial_reference_points, + **kwargs, + ) + + if not return_dict: + if auxiliary_outputs is not None: + output = (logits, pred_boxes) + (auxiliary_outputs,) + outputs + else: + output = (logits, pred_boxes) + outputs + return ((loss, loss_dict) + output) if loss is not None else output + + return DFineObjectDetectionOutput( + loss=loss, + loss_dict=loss_dict, + logits=logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + last_hidden_state=outputs.last_hidden_state, + intermediate_hidden_states=outputs.intermediate_hidden_states, + intermediate_logits=outputs.intermediate_logits, + intermediate_reference_points=outputs.intermediate_reference_points, + intermediate_predicted_corners=outputs.intermediate_predicted_corners, + initial_reference_points=outputs.initial_reference_points, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + init_reference_points=outputs.init_reference_points, + enc_topk_logits=outputs.enc_topk_logits, + enc_topk_bboxes=outputs.enc_topk_bboxes, + enc_outputs_class=outputs.enc_outputs_class, + enc_outputs_coord_logits=outputs.enc_outputs_coord_logits, + denoising_meta_values=outputs.denoising_meta_values, + ) + + 
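+# The prediction heads below serve D-FINE's detection outputs: DFineMLPPredictionHead regresses plain box
+# coordinates for the encoder head, DFineMLP regresses 4 * (max_num_bins + 1) corner-bin logits per decoder
+# layer (the FDR distributions), and DFineLQE converts that bin distribution into a localization quality
+# score added to the classification logits.
+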
+# taken from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class DFineMLPPredictionHead(nn.Cell): + """ + Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates, + height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + Origin from https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/DFine_paddle/ppdet/modeling/transformers/utils.py#L453 + + """ + + def __init__(self, config, input_dim, d_model, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [d_model] * (num_layers - 1) + self.layers = nn.CellList([mint.nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])]) + + def construct(self, x): + for i, layer in enumerate(self.layers): + x = mint.nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class DFineMLP(nn.Cell): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, act: str = "relu"): + super().__init__() + self.num_layers = num_layers + hidden_dims = [hidden_dim] * (num_layers - 1) + input_dims = [input_dim] + hidden_dims + output_dims = hidden_dims + [output_dim] + self.layers = nn.CellList([mint.nn.Linear(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, output_dims)]) + self.act = ACT2CLS[act]() + + def construct(self, stat_features: ms.Tensor) -> ms.Tensor: + for i, layer in enumerate(self.layers): + stat_features = self.act(layer(stat_features)) if i < self.num_layers - 1 else layer(stat_features) + return stat_features + + +class DFineLQE(nn.Cell): + def __init__(self, config: DFineConfig): + super().__init__() + self.top_prob_values = config.top_prob_values + self.max_num_bins = config.max_num_bins + self.reg_conf = DFineMLP(4 * (self.top_prob_values + 1), config.lqe_hidden_dim, 1, config.lqe_layers) + + def construct(self, scores: ms.Tensor, pred_corners: ms.Tensor) -> ms.Tensor: + batch_size, length, _ = pred_corners.shape + prob = F.softmax(pred_corners.reshape(batch_size, length, 4, self.max_num_bins + 1), dim=-1) + prob_topk, _ = prob.topk(self.top_prob_values, dim=-1) + stat = mint.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(batch_size, length, -1)) + scores = scores + quality_score + return scores + + +class DFineConvNormLayer(nn.Cell): + def __init__( + self, + config: DFineConfig, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + groups: int = 1, + padding: Optional[int] = None, + activation: Optional[str] = None, + ): + super().__init__() + self.conv = mint.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + groups=groups, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + self.norm = mint.nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = mint.nn.Identity() if activation is None else ACT2CLS[activation]() + + def construct(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class DFineRepVggBlock(nn.Cell): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". 
+ """ + + def __init__(self, config: DFineConfig, in_channels: int, out_channels: int): + super().__init__() + + activation = config.activation_function + hidden_channels = in_channels + self.conv1 = DFineConvNormLayer(config, hidden_channels, out_channels, 3, 1, padding=1) + self.conv2 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, padding=0) + self.activation = mint.nn.Identity() if activation is None else ACT2CLS[activation]() + + def construct(self, x): + y = self.conv1(x) + self.conv2(x) + return self.activation(y) + + +class DFineCSPRepLayer(nn.Cell): + """ + Cross Stage Partial (CSP) network layer with RepVGG blocks. + """ + + def __init__( + self, config: DFineConfig, in_channels: int, out_channels: int, num_blocks: int, expansion: float = 1.0 + ): + super().__init__() + in_channels = in_channels + out_channels = out_channels + activation = config.activation_function + + hidden_channels = int(out_channels * expansion) + self.conv1 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.conv2 = DFineConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation) + self.bottlenecks = nn.CellList( + [DFineRepVggBlock(config, hidden_channels, hidden_channels) for _ in range(num_blocks)] + ) + if hidden_channels != out_channels: + self.conv3 = DFineConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation) + else: + self.conv3 = mint.nn.Identity() + + def construct(self, hidden_state: ms.Tensor) -> ms.Tensor: + hidden_state_1 = self.conv1(hidden_state) + for bottleneck in self.bottlenecks: + hidden_state_1 = bottleneck(hidden_state_1) + hidden_state_2 = self.conv2(hidden_state) + hidden_state_3 = self.conv3(hidden_state_1 + hidden_state_2) + return hidden_state_3 + + +class DFineRepNCSPELAN4(nn.Cell): + def __init__(self, config: DFineConfig, act: str = "silu", numb_blocks: int = 3): + super().__init__() + conv1_dim = config.encoder_hidden_dim * 2 + conv2_dim = config.encoder_hidden_dim + conv3_dim = config.encoder_hidden_dim * 2 + conv4_dim = round(config.hidden_expansion * config.encoder_hidden_dim // 2) + self.conv_dim = conv3_dim // 2 + self.conv1 = DFineConvNormLayer(config, conv1_dim, conv3_dim, 1, 1, activation=act) + self.csp_rep1 = DFineCSPRepLayer(config, conv3_dim // 2, conv4_dim, num_blocks=numb_blocks) + self.conv2 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act) + self.csp_rep2 = DFineCSPRepLayer(config, conv4_dim, conv4_dim, num_blocks=numb_blocks) + self.conv3 = DFineConvNormLayer(config, conv4_dim, conv4_dim, 3, 1, activation=act) + self.conv4 = DFineConvNormLayer(config, conv3_dim + (2 * conv4_dim), conv2_dim, 1, 1, activation=act) + + def construct(self, input_features: ms.Tensor) -> ms.Tensor: + # Split initial features into two branches after first convolution + split_features = list(self.conv1(input_features).split((self.conv_dim, self.conv_dim), 1)) + + # Process branches sequentially + branch1 = self.csp_rep1(split_features[-1]) + branch1 = self.conv2(branch1) + branch2 = self.csp_rep2(branch1) + branch2 = self.conv3(branch2) + + split_features.extend([branch1, branch2]) + merged_features = mint.cat(split_features, 1) + merged_features = self.conv4(merged_features) + return merged_features + + +class DFineSCDown(nn.Cell): + def __init__(self, config: DFineConfig, kernel_size: int, stride: int): + super().__init__() + self.conv1 = DFineConvNormLayer(config, config.encoder_hidden_dim, config.encoder_hidden_dim, 1, 1) + self.conv2 = 
DFineConvNormLayer( + config, + config.encoder_hidden_dim, + config.encoder_hidden_dim, + kernel_size, + stride, + config.encoder_hidden_dim, + ) + + def construct(self, input_features: ms.Tensor) -> ms.Tensor: + input_features = self.conv1(input_features) + input_features = self.conv2(input_features) + return input_features + + +class DFineEncoder(nn.Cell): + def __init__(self, config: DFineConfig): + super().__init__() + + self.layers = nn.CellList([DFineEncoderLayer(config) for _ in range(config.encoder_layers)]) + + def construct(self, src, src_mask=None, pos_embed=None, output_attentions: bool = False) -> ms.Tensor: + hidden_states = src + for layer in self.layers: + hidden_states = layer( + hidden_states, + attention_mask=src_mask, + position_embeddings=pos_embed, + output_attentions=output_attentions, + ) + return hidden_states + + +class DFineHybridEncoder(nn.Cell): + """ + Decoder consisting of a projection layer, a set of `DFineEncoder`, a top-down Feature Pyramid Network + (FPN) and a bottom-up Path Aggregation Network (PAN). More details on the paper: https://huggingface.co/papers/2304.08069 + + Args: + config: DFineConfig + """ + + def __init__(self, config: DFineConfig): + nn.Cell.__init__(self) + self.config = config + self.in_channels = config.encoder_in_channels + self.num_fpn_stages = len(self.in_channels) - 1 + self.feat_strides = config.feat_strides + self.encoder_hidden_dim = config.encoder_hidden_dim + self.encode_proj_layers = config.encode_proj_layers + self.positional_encoding_temperature = config.positional_encoding_temperature + self.eval_size = config.eval_size + self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels] + self.out_strides = self.feat_strides + + # encoder transformer + self.encoder = nn.CellList([DFineEncoder(config) for _ in range(len(self.encode_proj_layers))]) + # top-down fpn + self.lateral_convs = [] + self.fpn_blocks = [] + for _ in range(len(self.in_channels) - 1, 0, -1): + lateral_layer = DFineConvNormLayer(config, self.encoder_hidden_dim, self.encoder_hidden_dim, 1, 1) + self.lateral_convs.append(lateral_layer) + num_blocks = round(3 * config.depth_mult) + fpn_layer = DFineRepNCSPELAN4(config, numb_blocks=num_blocks) + self.fpn_blocks.append(fpn_layer) + self.lateral_convs = nn.CellList(self.lateral_convs) + self.fpn_blocks = nn.CellList(self.fpn_blocks) + + # bottom-up pan + self.downsample_convs = [] + self.pan_blocks = [] + for _ in range(len(self.in_channels) - 1): + self.downsample_convs.append(DFineSCDown(config, 3, 2)) + num_blocks = round(3 * config.depth_mult) + self.pan_blocks.append(DFineRepNCSPELAN4(config, numb_blocks=num_blocks)) + self.downsample_convs = nn.CellList(self.downsample_convs) + self.pan_blocks = nn.CellList(self.pan_blocks) + + @staticmethod + def build_2d_sincos_position_embedding(width, height, embed_dim=256, temperature=10000.0, dtype=ms.float32): + grid_w = mint.arange(mindspore_int(width)).to(dtype) + grid_h = mint.arange(mindspore_int(height)).to(dtype) + grid_w, grid_h = mint.meshgrid(grid_w, grid_h, indexing="ij") + if embed_dim % 4 != 0: + raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") + pos_dim = embed_dim // 4 + omega = mint.arange(pos_dim).to(dtype) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return mint.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def construct( + self, + 
inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + spatial_shapes=None, + level_start_index=None, + valid_ratios=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + [What are attention masks?](../glossary#attention-mask) + position_embeddings (`ms.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + spatial_shapes (`ms.Tensor` of shape `(num_feature_levels, 2)`): + Spatial shapes of each feature map. + level_start_index (`ms.Tensor` of shape `(num_feature_levels)`): + Starting index of each feature map. + valid_ratios (`ms.Tensor` of shape `(batch_size, num_feature_levels, 2)`): + Ratio of valid area in each feature level. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = inputs_embeds + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # encoder + if self.config.encoder_layers > 0: + for i, enc_ind in enumerate(self.encode_proj_layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + height, width = hidden_states[enc_ind].shape[2:] + # flatten [batch, channel, height, width] to [batch, height*width, channel] + src_flatten = hidden_states[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + width, + height, + self.encoder_hidden_dim, + self.positional_encoding_temperature, + dtype=src_flatten.dtype, + ) + else: + pos_embed = None + + layer_outputs = self.encoder[i]( + src_flatten, + pos_embed=pos_embed, + output_attentions=output_attentions, + ) + hidden_states[enc_ind] = ( + layer_outputs[0].permute(0, 2, 1).reshape(-1, self.encoder_hidden_dim, height, width).contiguous() + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states[enc_ind],) + + # top-down FPN + fpn_feature_maps = [hidden_states[-1]] + for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)): + backbone_feature_map = hidden_states[self.num_fpn_stages - idx - 1] + top_fpn_feature_map = 
fpn_feature_maps[-1] + # apply lateral block + top_fpn_feature_map = lateral_conv(top_fpn_feature_map) + fpn_feature_maps[-1] = top_fpn_feature_map + # apply fpn block + top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest") + fused_feature_map = mint.concat([top_fpn_feature_map, backbone_feature_map], dim=1) + new_fpn_feature_map = fpn_block(fused_feature_map) + fpn_feature_maps.append(new_fpn_feature_map) + + fpn_feature_maps = fpn_feature_maps[::-1] + + # bottom-up PAN + pan_feature_maps = [fpn_feature_maps[0]] + for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)): + top_pan_feature_map = pan_feature_maps[-1] + fpn_feature_map = fpn_feature_maps[idx + 1] + downsampled_feature_map = downsample_conv(top_pan_feature_map) + fused_feature_map = mint.concat([downsampled_feature_map, fpn_feature_map], dim=1) + new_pan_feature_map = pan_block(fused_feature_map) + pan_feature_maps.append(new_pan_feature_map) + + if not return_dict: + return tuple(v for v in [pan_feature_maps, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=pan_feature_maps, hidden_states=encoder_states, attentions=all_attentions + ) + + +__all__ = ["DFineModel", "DFinePreTrainedModel", "DFineForObjectDetection"] diff --git a/mindone/transformers/models/efficientloftr/__init__.py b/mindone/transformers/models/efficientloftr/__init__.py new file mode 100644 index 0000000000..1461cb5f81 --- /dev/null +++ b/mindone/transformers/models/efficientloftr/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_efficientloftr import * diff --git a/mindone/transformers/models/efficientloftr/modeling_efficientloftr.py b/mindone/transformers/models/efficientloftr/modeling_efficientloftr.py new file mode 100644 index 0000000000..c791e52133 --- /dev/null +++ b/mindone/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -0,0 +1,1272 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
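+"""MindSpore EfficientLoFTR model."""
+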
+from dataclasses import dataclass +from typing import Callable, Optional, Union + +from transformers.models.efficientloftr.configuration_efficientloftr import EfficientLoFTRConfig + +import mindspore as ms +from mindspore import mint, nn + +from ...activations import ACT2CLS, ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BackboneOutput +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TransformersKwargs, can_return_tuple, mindspore_int +from ...utils.generic import check_model_inputs + + +@dataclass +class KeypointMatchingOutput(ModelOutput): + r""" + matches (`ms.Tensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`ms.Tensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`ms.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + hidden_states (`tuple[ms.Tensor, ...]`, *optional*): + Tuple of `ms.Tensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)`, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`) + attentions (`tuple[ms.Tensor, ...]`, *optional*): + Tuple of `ms.Tensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) + """ + + matches: Optional[ms.Tensor] = None + matching_scores: Optional[ms.Tensor] = None + keypoints: Optional[ms.Tensor] = None + hidden_states: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[ms.Tensor]] = None + + +class EfficientLoFTRRotaryEmbedding(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + self.config = config + self.rope_type = config.rope_scaling["rope_type"] + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, _ = self.rope_init_fn(self.config) + inv_freq_expanded = inv_freq[None, None, None, :].float().broadcast_to((1, 1, 1, -1)) + + embed_height, embed_width = config.embedding_size + i_indices = mint.ones((embed_height, embed_width)).cumsum(0).float().unsqueeze(-1) + j_indices = mint.ones((embed_height, embed_width)).cumsum(1).float().unsqueeze(-1) + + emb = mint.zeros((1, embed_height, embed_width, self.config.hidden_size // 2)) + emb[:, :, :, 0::2] = i_indices * inv_freq_expanded + emb[:, :, :, 1::2] = j_indices * inv_freq_expanded + + self.inv_freq = emb + + def construct( + self, x: ms.Tensor, position_ids: Optional[tuple[ms.Tensor, ms.Tensor]] = None + ) -> tuple[ms.Tensor, ms.Tensor]: + emb = self.inv_freq.float() + sin = emb.sin() + cos = emb.cos() + + sin = sin.repeat_interleave(2, dim=-1) + cos = cos.repeat_interleave(2, dim=-1) + + sin = sin.to(dtype=x.dtype) + cos = cos.to(dtype=x.dtype) + + return cos, sin + + +# Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR +class EfficientLoFTRConvNormLayer(nn.Cell): + def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None): + super().__init__() + self.conv = mint.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding=(kernel_size - 1) // 2 if padding is None else padding, + bias=False, + ) + 
self.norm = mint.nn.BatchNorm2d(out_channels, config.batch_norm_eps) + self.activation = mint.nn.Identity() if activation is None else ACT2CLS[activation]() + + def construct(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.norm(hidden_state) + hidden_state = self.activation(hidden_state) + return hidden_state + + +class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer): + """ + RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again". + """ + + def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int): + super().__init__() + in_channels = config.stage_block_in_channels[stage_idx][block_idx] + out_channels = config.stage_block_out_channels[stage_idx][block_idx] + stride = config.stage_block_stride[stage_idx][block_idx] + activation = config.activation_function + self.conv1 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1 + ) + self.conv2 = EfficientLoFTRConvNormLayer( + config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0 + ) + self.identity = mint.nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None + self.activation = mint.nn.Identity() if activation is None else ACT2FN[activation] + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + if self.identity is not None: + identity_out = self.identity(hidden_states) + else: + identity_out = 0 + hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out + hidden_states = self.activation(hidden_states) + return hidden_states + + +class EfficientLoFTRRepVGGStage(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig, stage_idx: int): + super().__init__() + self.blocks = [] + for block_idx in range(config.stage_num_blocks[stage_idx]): + self.blocks.append( + EfficientLoFTRRepVGGBlock( + config, + stage_idx, + block_idx, + ) + ) + self.blocks = nn.CellList(self.blocks) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + for block in self.blocks: + hidden_states = block(hidden_states) + return hidden_states + + +class EfficientLoFTRepVGG(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + + self.stages = [] + + for stage_idx in range(len(config.stage_stride)): + stage = EfficientLoFTRRepVGGStage(config, stage_idx) + self.stages.append(stage) + self.stages = nn.CellList(self.stages) + + def construct(self, hidden_states: ms.Tensor) -> list[ms.Tensor]: + outputs = [] + for stage in self.stages: + hidden_states = stage(hidden_states) + outputs.append(hidden_states) + + # Exclude first stage in outputs + outputs = outputs[1:] + return outputs + + +class EfficientLoFTRAggregationLayer(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + + hidden_size = config.hidden_size + + self.q_aggregation = mint.nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=config.q_aggregation_kernel_size, + padding=0, + stride=config.q_aggregation_stride, + bias=False, + groups=hidden_size, + ) + self.kv_aggregation = mint.nn.MaxPool2d( + kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride + ) + self.norm = mint.nn.LayerNorm(hidden_size) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + query_states = hidden_states + is_cross_attention = encoder_hidden_states is not None + kv_states = encoder_hidden_states if 
is_cross_attention else hidden_states + + query_states = self.q_aggregation(query_states) + if kv_states.dtype == ms.bfloat16: + kv_states = self.kv_aggregation(kv_states.float()).to(ms.bfloat16) + else: + kv_states = self.kv_aggregation(kv_states) + query_states = query_states.permute(0, 2, 3, 1) + kv_states = kv_states.permute(0, 2, 3, 1) + hidden_states = self.norm(query_states) + encoder_hidden_states = self.norm(kv_states) + return hidden_states, encoder_hidden_states + + +# Copied from transformers.models.cohere.modeling_cohere.rotate_half +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. + x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = mint.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +# Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +# Copied from transformers.models.cohere.modeling_cohere.repeat_kv +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of mint.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.eager_attention_forward +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +class EfficientLoFTRAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->EfficientLoFTR + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = mint.nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[ms.Tensor, Optional[ms.Tensor]]: + batch_size, seq_len, dim = hidden_states.shape + input_shape = hidden_states.shape[:-1] + + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + + key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim) + value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).swapaxes(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) + 
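+ # Reshape queries and keys to (batch_size, num_heads, seq_len, head_dim) before the attention call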
+ query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).swapaxes(1, 2) + key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).swapaxes(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class EfficientLoFTRMLP(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.fc1 = mint.nn.Linear(hidden_size * 2, intermediate_size, bias=False) + self.activation = ACT2FN[config.mlp_activation_function] + self.fc2 = mint.nn.Linear(intermediate_size, hidden_size, bias=False) + self.layer_norm = mint.nn.LayerNorm(hidden_size) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class EfficientLoFTRAggregatedAttention(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig, layer_idx: int): + super().__init__() + + self.q_aggregation_kernel_size = config.q_aggregation_kernel_size + self.aggregation = EfficientLoFTRAggregationLayer(config) + self.attention = EfficientLoFTRAttention(config, layer_idx) + self.mlp = EfficientLoFTRMLP(config) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + batch_size, embed_dim, _, _ = hidden_states.shape + + # Aggregate features + aggregated_hidden_states, aggregated_encoder_hidden_states = self.aggregation( + hidden_states, encoder_hidden_states + ) + _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape + + # Multi-head attention + aggregated_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, embed_dim) + aggregated_encoder_hidden_states = aggregated_encoder_hidden_states.reshape(batch_size, -1, embed_dim) + attn_output, _ = self.attention( + aggregated_hidden_states, + aggregated_encoder_hidden_states, + position_embeddings=position_embeddings, + **kwargs, + ) + + # Upsample features + # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w + attn_output = attn_output.permute(0, 2, 1) + attn_output = attn_output.reshape(batch_size, embed_dim, aggregated_h, aggregated_w) + attn_output = mint.nn.functional.interpolate( + attn_output, + scale_factor=self.q_aggregation_kernel_size, + mode="bilinear", + align_corners=False, + recompute_scale_factor=True, + ) + intermediate_states = mint.cat([hidden_states, attn_output], dim=1) + intermediate_states = intermediate_states.permute(0, 2, 3, 1) + output_states = self.mlp(intermediate_states) + output_states = output_states.permute(0, 3, 1, 2) + + hidden_states = hidden_states + output_states + + return hidden_states + + +class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer): + 
def __init__(self, config: EfficientLoFTRConfig, layer_idx: int): + super().__init__() + + self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx) + + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + batch_size, _, embed_dim, height, width = hidden_states.shape + + hidden_states = hidden_states.reshape(-1, embed_dim, height, width) + hidden_states = self.self_attention(hidden_states, position_embeddings=position_embeddings, **kwargs) + + encoder_hidden_states = hidden_states.reshape(-1, 2, embed_dim, height, width) + encoder_hidden_states = encoder_hidden_states.flip((1,)) + encoder_hidden_states = encoder_hidden_states.reshape(-1, embed_dim, height, width) + + hidden_states = self.cross_attention(hidden_states, encoder_hidden_states, **kwargs) + hidden_states = hidden_states.reshape(batch_size, -1, embed_dim, height, width) + + return hidden_states + + +class EfficientLoFTRLocalFeatureTransformer(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + self.layers = nn.CellList( + [ + EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i) + for i in range(config.num_attention_layers) + ] + ) + + def construct( + self, + hidden_states: ms.Tensor, + position_embeddings: tuple[ms.Tensor, ms.Tensor], + **kwargs: Unpack[TransformersKwargs], + ) -> ms.Tensor: + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings=position_embeddings, **kwargs) + return hidden_states + + +class EfficientLoFTROutConvBlock(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int): + super().__init__() + + self.out_conv1 = mint.nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False) + self.out_conv2 = mint.nn.Conv2d( + intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False + ) + self.batch_norm = mint.nn.BatchNorm2d(intermediate_size) + self.activation = ACT2CLS[config.mlp_activation_function]() + self.out_conv3 = mint.nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False) + + def construct(self, hidden_states: ms.Tensor, residual_states: ms.Tensor) -> ms.Tensor: + residual_states = self.out_conv1(residual_states) + residual_states = residual_states + hidden_states + residual_states = self.out_conv2(residual_states) + residual_states = self.batch_norm(residual_states) + residual_states = self.activation(residual_states) + residual_states = self.out_conv3(residual_states) + residual_states = mint.nn.functional.interpolate( + residual_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + return residual_states + + +class EfficientLoFTRFineFusionLayer(nn.Cell): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__() + + self.fine_kernel_size = config.fine_kernel_size + + fine_fusion_dims = config.fine_fusion_dims + self.out_conv = mint.nn.Conv2d( + fine_fusion_dims[0], fine_fusion_dims[0], kernel_size=1, stride=1, padding=0, bias=False + ) + self.out_conv_layers = [] + for i in range(1, len(fine_fusion_dims)): + out_conv = EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1]) + self.out_conv_layers.append(out_conv) + self.out_conv_layers = nn.CellList(self.out_conv_layers) + + def forward_pyramid( + self, + hidden_states: ms.Tensor, + 
residual_states: list[ms.Tensor], + ) -> ms.Tensor: + hidden_states = self.out_conv(hidden_states) + hidden_states = mint.nn.functional.interpolate( + hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False + ) + for i, layer in enumerate(self.out_conv_layers): + hidden_states = layer(hidden_states, residual_states[i]) + + return hidden_states + + def construct( + self, + coarse_features: ms.Tensor, + residual_features: list[ms.Tensor], + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + For each image pair, compute the fine features of pixels. + In both images, compute a patch of fine features center cropped around each coarse pixel. + In the first image, the feature patch is kernel_size large and long. + In the second image, it is (kernel_size + 2) large and long. + """ + batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape + + coarse_features = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width) + residual_features = list(reversed(residual_features)) + + # 1. Fine feature extraction + fine_features = self.forward_pyramid(coarse_features, residual_features) + _, fine_embed_dim, fine_height, fine_width = fine_features.shape + + fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width) + fine_features_0 = fine_features[:, 0] + fine_features_1 = fine_features[:, 1] + + # 2. Unfold all local windows in crops + stride = int(fine_height // coarse_height) + fine_features_0 = mint.nn.functional.unfold( + fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0 + ) + _, _, seq_len = fine_features_0.shape + fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len) + fine_features_0 = fine_features_0.permute(0, 3, 2, 1) + + fine_features_1 = mint.nn.functional.unfold( + fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1 + ) + fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len) + fine_features_1 = fine_features_1.permute(0, 3, 2, 1) + + return fine_features_0, fine_features_1 + + +class EfficientLoFTRPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientLoFTRConfig + base_model_prefix = "efficientloftr" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_flash_attn = True + _supports_sdpa = True + _can_record_outputs = { + "hidden_states": EfficientLoFTRRepVGGBlock, + "attentions": EfficientLoFTRAttention, + } + + def _init_weights(self, module: nn.Cell) -> None: + """Initialize the weights""" + pass + + # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR + def extract_one_channel_pixel_values(self, pixel_values: ms.Tensor) -> ms.Tensor: + """ + Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same, + extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR. 
This is + a workaround for the issue discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + pixel_values: ms.Tensor of shape (batch_size, 3, height, width) + + Returns: + pixel_values: ms.Tensor of shape (batch_size, 1, height, width) + + """ + return pixel_values[:, 0, :, :][:, None, :, :] + + +class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel): + def __init__(self, config: EfficientLoFTRConfig): + super().__init__(config) + + self.config = config + self.backbone = EfficientLoFTRepVGG(config) + self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config) + self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config) + + self.post_init() + + @check_model_inputs + def construct( + self, + pixel_values: ms.Tensor, + labels: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor + >>> from mindone.transformers import AutoModel + >>> import mindspore as ms + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" # noqa: E501 + >>> image1 = Image.open(requests.get(url, stream=True).raw) + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" # noqa: E501 + >>> image2 = Image.open(requests.get(url, stream=True).raw) + >>> images = [image1, image2] + + >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr") + >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr") + + >>> inputs = processor(images, return_tensors="np") + >>> inputs = {k: ms.tensor(v) for k, v in inputs.items()} + >>> outputs = model(**inputs) + ```""" + if labels is not None: + raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.") + + if pixel_values.ndim != 5 or pixel_values.shape[1] != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + pixel_values = self.extract_one_channel_pixel_values(pixel_values) + + # 1. Local Feature CNN + features = self.backbone(pixel_values) + # Last stage outputs are coarse outputs + coarse_features = features[-1] + # Rest is residual features used in EfficientLoFTRFineFusionLayer + residual_features = features[:-1] + coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:] + + # 2. 
Coarse-level LoFTR module + cos, sin = self.rotary_emb(coarse_features) + cos = cos.broadcast_to((batch_size * 2, -1, -1, -1)).reshape(batch_size * 2, -1, coarse_embed_dim) + sin = sin.broadcast_to((batch_size * 2, -1, -1, -1)).reshape(batch_size * 2, -1, coarse_embed_dim) + position_embeddings = (cos, sin) + + coarse_features = coarse_features.reshape(batch_size, 2, coarse_embed_dim, coarse_height, coarse_width) + coarse_features = self.local_feature_transformer( + coarse_features, position_embeddings=position_embeddings, **kwargs + ) + + features = (coarse_features,) + tuple(residual_features) + + return BackboneOutput(feature_maps=features) + + +def mask_border(tensor: ms.Tensor, border_margin: int, value: Union[bool, float, int]) -> ms.Tensor: + """ + Mask a tensor border with a given value + + Args: + tensor (`ms.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The tensor to mask + border_margin (`int`) : + The size of the border + value (`Union[bool, int, float]`): + The value to place in the tensor's borders + + Returns: + tensor (`ms.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + The masked tensor + """ + if border_margin <= 0: + return tensor + + tensor[:, :border_margin, :border_margin, :border_margin, :border_margin] = value + tensor[:, -border_margin:, -border_margin:, -border_margin:, -border_margin:] = value + return tensor + + +def create_meshgrid( + height: Union[int, ms.Tensor], + width: Union[int, ms.Tensor], + normalized_coordinates: bool = False, + dtype: Optional[ms.Type] = None, +) -> ms.Tensor: + """ + Copied from kornia library : kornia/kornia/utils/grid.py:26 + + Generate a coordinate grid for an image. + + When the flag ``normalized_coordinates`` is set to True, the grid is + normalized to be in the range :math:`[-1,1]` to be consistent with the mindspore + function :py:func:`mint.nn.functional.grid_sample`. + + Args: + height (`int`): + The image height (rows). + width (`int`): + The image width (cols). + normalized_coordinates (`bool`): + Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the + MindSpore function :py:func:`mint.nn.functional.grid_sample`. + dtype (`ms.Type`): + The data type of the generated grid. + + Return: + grid (`ms.Tensor` of shape `(1, height, width, 2)`): + The grid tensor. + + Example: + >>> create_meshgrid(2, 2) + tensor([[[[-1., -1.], + [ 1., -1.]], + + [[-1., 1.], + [ 1., 1.]]]]) + + >>> create_meshgrid(2, 2, normalized_coordinates=False) + tensor([[[[0., 0.], + [1., 0.]], + + [[0., 1.], + [1., 1.]]]]) + + """ + xs = mint.linspace(0, width - 1, width, dtype=dtype) + ys = mint.linspace(0, height - 1, height, dtype=dtype) + if normalized_coordinates: + xs = (xs / (width - 1) - 0.5) * 2 + ys = (ys / (height - 1) - 0.5) * 2 + grid = mint.stack(mint.meshgrid(ys, xs, indexing="ij"), dim=-1) + grid = grid.permute(1, 0, 2).unsqueeze(0) + return grid + + +def spatial_expectation2d(input: ms.Tensor, normalized_coordinates: bool = True) -> ms.Tensor: + r""" + Copied from kornia library : kornia/geometry/subpix/dsnt.py:76 + Compute the expectation of coordinate values using spatial probabilities. + + The input heatmap is assumed to represent a valid spatial probability distribution, + which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`. + + Args: + input (`ms.Tensor` of shape `(batch_size, embed_dim, height, width)`): + The input tensor representing dense spatial probabilities. 
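For readers skimming `create_meshgrid` above: the `normalized_coordinates` branch simply rescales pixel indices 0..size-1 into [-1, 1], the convention expected by grid-sample style operations. A tiny sketch of that mapping (plain Python, illustrative only):

```python
def normalize_coordinate(index: int, size: int) -> float:
    # same rescaling as the normalized_coordinates branch above
    return (index / (size - 1) - 0.5) * 2

width = 5
print([normalize_coordinate(x, width) for x in range(width)])
# [-1.0, -0.5, 0.0, 0.5, 1.0]
```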
+ normalized_coordinates (`bool`): + Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return + the coordinates in the range of the input shape. + + Returns: + output (`ms.Tensor` of shape `(batch_size, embed_dim, 2)`) + Expected value of the 2D coordinates. Output order of the coordinates is (x, y). + + Examples: + >>> heatmaps = ms.tensor([[[ + ... [0., 0., 0.], + ... [0., 0., 0.], + ... [0., 1., 0.]]]]) + >>> spatial_expectation2d(heatmaps, False) + tensor([[[1., 2.]]]) + + """ + batch_size, embed_dim, height, width = input.shape + + # Create coordinates grid. + grid = create_meshgrid(height, width, normalized_coordinates) + grid = grid.to(input.dtype) + + pos_x = grid[..., 0].reshape(-1) + pos_y = grid[..., 1].reshape(-1) + + input_flat = input.view(batch_size, embed_dim, -1) + + # Compute the expectation of the coordinates. + expected_y = mint.sum(pos_y * input_flat, -1, keepdim=True) + expected_x = mint.sum(pos_x * input_flat, -1, keepdim=True) + + output = mint.cat([expected_x, expected_y], -1) + + return output.view(batch_size, embed_dim, 2) + + +class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel): + """EfficientLoFTR dense image matcher + + Given two images, we determine the correspondences by: + 1. Extracting coarse and fine features through a backbone + 2. Transforming coarse features through self and cross attention + 3. Matching coarse features to obtain coarse coordinates of matches + 4. Obtaining full resolution fine features by fusing transformed and backbone coarse features + 5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement + + Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou. + Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed + In CVPR, 2024. https://arxiv.org/abs/2403.04765 + """ + + def __init__(self, config: EfficientLoFTRConfig): + super().__init__(config) + + self.config = config + self.efficientloftr = EfficientLoFTRModel(config) + self.refinement_layer = EfficientLoFTRFineFusionLayer(config) + + self.post_init() + + def _get_matches_from_scores(self, scores: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor]: + """ + Based on a keypoint score matrix, compute the best keypoint matches between the first and second image. + Since each image pair can have different number of matches, the matches are concatenated together for all pair + in the batch and a batch_indices tensor is returned to specify which match belong to which element in the batch. + + Note: + This step can be done as a postprocessing step, because does not involve any model weights/params. + However, we keep it in the modeling code for consistency with other keypoint matching models. + + Args: + scores (`ms.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`): + Scores of keypoints + + Returns: + matched_indices (`ms.Tensor` of shape `(2, num_matches)`): + Indices representing which pixel in the first image matches which pixel in the second image + matching_scores (`ms.Tensor` of shape `(num_matches,)`): + Scores of each match + """ + batch_size, height0, width0, height1, width1 = scores.shape + + scores = scores.view(batch_size, height0 * width0, height1 * width1) + + # For each keypoint, get the best match + max_0 = scores.max(2, keepdim=True).values + max_1 = scores.max(1, keepdim=True).values + + # 1. Thresholding + mask = scores > self.config.coarse_matching_threshold + + # 2. 
Border removal + mask = mask.reshape(batch_size, height0, width0, height1, width1) + mask = mask_border(mask, self.config.coarse_matching_border_removal, False) + mask = mask.reshape(batch_size, height0 * width0, height1 * width1) + + # 3. Mutual nearest neighbors + mask = mask * (scores == max_0) * (scores == max_1) + + # 4. Fine coarse matches + masked_scores = scores * mask + matching_scores_0, max_indices_0 = masked_scores.max(1) + matching_scores_1, max_indices_1 = masked_scores.max(2) + + matching_indices = mint.cat([max_indices_0, max_indices_1]).reshape(batch_size, 2, -1) + matching_scores = mint.stack([matching_scores_0, matching_scores_1], dim=1) + + # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found + matching_indices = mint.where(matching_scores > 0, matching_indices, -1) + + return matching_indices, matching_scores + + def _coarse_matching( + self, coarse_features: ms.Tensor, coarse_scale: float + ) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]: + """ + For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8) + * (image_width / 8 elements)) from the first image to the second image. + + Note: + This step can be done as a postprocessing step, because does not involve any model weights/params. + However, we keep it in the modeling code for consistency with other keypoint matching models. + + Args: + coarse_features (`ms.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`): + Coarse features + coarse_scale (`float`): Scale between the image size and the coarse size + + Returns: + keypoints (`ms.Tensor` of shape `(batch_size, 2, num_matches, 2)`): + Keypoints coordinates. + matching_scores (`ms.Tensor` of shape `(batch_size, 2, num_matches)`): + The confidence matching score of each keypoint. + matched_indices (`ms.Tensor` of shape `(batch_size, 2, num_matches)`): + Indices which indicates which keypoint in an image matched with which keypoint in the other image. For + both image in the pair. + """ + batch_size, _, embed_dim, height, width = coarse_features.shape + + # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim) + coarse_features = coarse_features.permute(0, 1, 3, 4, 2) + coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim) + + coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5 + coarse_features_0 = coarse_features[:, 0] + coarse_features_1 = coarse_features[:, 1] + + similarity = coarse_features_0 @ coarse_features_1.swapaxes(-1, -2) + similarity = similarity / self.config.coarse_matching_temperature + + if self.config.coarse_matching_skip_softmax: + confidence = similarity + else: + confidence = mint.nn.functional.softmax(similarity, 1) * mint.nn.functional.softmax(similarity, 2) + + confidence = confidence.view(batch_size, height, width, height, width) + matched_indices, matching_scores = self._get_matches_from_scores(confidence) + + keypoints = mint.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale + + return keypoints, matching_scores, matched_indices + + def _get_first_stage_fine_matching( + self, + fine_confidence: ms.Tensor, + coarse_matched_keypoints: ms.Tensor, + fine_window_size: int, + fine_scale: float, + ) -> tuple[ms.Tensor, ms.Tensor]: + """ + For each coarse pixel, retrieve the highest fine confidence score and index. 
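Stepping back to the coarse matching above for a moment: after thresholding and border removal, a pair survives only if it is a mutual nearest neighbour, i.e. each pixel is the other's best-scoring candidate. A toy numpy sketch of just that filter (an illustration, not the model code):

```python
import numpy as np

scores = np.array([[0.9, 0.1, 0.0],
                   [0.2, 0.1, 0.7],
                   [0.8, 0.3, 0.1]])  # rows: pixels of image 0, columns: pixels of image 1
mutual = (scores == scores.max(axis=1, keepdims=True)) & (scores == scores.max(axis=0, keepdims=True))
print(np.argwhere(mutual))
# [[0 0]
#  [1 2]] -> pixel 0 matches pixel 0, pixel 1 matches pixel 2; row 2 loses column 0 to row 0
```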
+ The index represents the matching between a pixel position in the fine window in the first image and a pixel + position in the fine window of the second image. + For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38 + (2474 // 64) in the fine window of the first image, and the index 42 in the second image. This means that 38 + which corresponds to the position (4, 6) (4 // 8 and 4 % 8) is matched with the position (5, 2). In this example + the coarse matched coordinate will be shifted to the matched fine coordinates in the first and second image. + + Note: + This step can be done as a postprocessing step, because does not involve any model weights/params. + However, we keep it in the modeling code for consistency with other keypoint matching models. + + Args: + fine_confidence (`ms.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`): + First stage confidence of matching fine features between the first and the second image + coarse_matched_keypoints (`ms.Tensor` of shape `(2, num_matches, 2)`): + Coarse matched keypoint between the first and the second image. + fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + + Returns: + indices (`ms.Tensor` of shape `(2, num_matches, 1)`): + Indices of the fine coordinate matched in the fine window + fine_matches (`ms.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + """ + batch_size, num_keypoints, _, _ = fine_confidence.shape + fine_kernel_size = mindspore_int(fine_window_size**0.5) + + fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1) + values, indices = mint.max(fine_confidence, dim=-1) + indices = indices[..., None] + indices_0 = indices // fine_window_size + indices_1 = indices % fine_window_size + + grid = create_meshgrid( + fine_kernel_size, + fine_kernel_size, + normalized_coordinates=False, + dtype=fine_confidence.dtype, + ) + grid = grid - (fine_kernel_size // 2) + 0.5 + grid = grid.reshape(1, 1, -1, 2).broadcast_to((batch_size, num_keypoints, -1, -1)) + delta_0 = mint.gather(grid, 1, indices_0.unsqueeze(-1).broadcast_to((-1, -1, -1, 2))).squeeze(2) + delta_1 = mint.gather(grid, 1, indices_1.unsqueeze(-1).broadcast_to((-1, -1, -1, 2))).squeeze(2) + + fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale + fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale + + indices = mint.stack([indices_0, indices_1], dim=1) + fine_matches = mint.stack([fine_matches_0, fine_matches_1], dim=1) + + return indices, fine_matches + + def _get_second_stage_fine_matching( + self, + indices: ms.Tensor, + fine_matches: ms.Tensor, + fine_confidence: ms.Tensor, + fine_window_size: int, + fine_scale: float, + ) -> ms.Tensor: + """ + For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position. + After applying softmax to these confidences, compute the 2D spatial expected coordinates. + Shift the first stage fine matching with these expected coordinates. + + Note: + This step can be done as a postprocessing step, because does not involve any model weights/params. + However, we keep it in the modeling code for consistency with other keypoint matching models. 
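The worked example in the `_get_first_stage_fine_matching` docstring above can be checked with a few lines of plain Python; note that the window positions come from dividing the per-image indices (38 and 42) by the kernel size:

```python
fine_window_size = 64        # 8 * 8
fine_kernel_size = 8
best_flat_index = 2474       # argmax over the flattened 64 x 64 fine confidence matrix

index_image0 = best_flat_index // fine_window_size   # 38: position inside the first image's window
index_image1 = best_flat_index % fine_window_size    # 42: position inside the second image's window
pos0 = (index_image0 // fine_kernel_size, index_image0 % fine_kernel_size)  # (4, 6)
pos1 = (index_image1 // fine_kernel_size, index_image1 % fine_kernel_size)  # (5, 2)
assert (index_image0, index_image1, pos0, pos1) == (38, 42, (4, 6), (5, 2))
```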
+ + Args: + indices (`ms.Tensor` of shape `(batch_size, 2, num_keypoints)`): + Indices representing the position of each keypoint in the fine window + fine_matches (`ms.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the first fine stage + fine_confidence (`ms.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`): + Second stage confidence of matching fine features between the first and the second image + fine_window_size (`int`): + Size of the window used to refine matches + fine_scale (`float`): + Scale between the size of fine features and coarse features + + Returns: + fine_matches (`ms.Tensor` of shape `(2, num_matches, 2)`): + Coordinates of matched keypoints after the second fine stage + """ + batch_size, num_keypoints, _, _ = fine_confidence.shape + fine_kernel_size = mindspore_int(fine_window_size**0.5) + + indices_0 = indices[:, 0] + indices_1 = indices[:, 1] + indices_1_i = indices_1 // fine_kernel_size + indices_1_j = indices_1 % fine_kernel_size + + # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3) + batch_indices = mint.arange(batch_size).reshape(batch_size, 1, 1, 1) + matches_indices = mint.arange(num_keypoints).reshape(1, num_keypoints, 1, 1) + indices_0 = indices_0[..., None] + indices_1_i = indices_1_i[..., None] + indices_1_j = indices_1_j[..., None] + + delta = create_meshgrid(3, 3, normalized_coordinates=True).to(ms.int64) + delta = delta[None, ...] + + indices_1_i = indices_1_i + delta[..., 1] + indices_1_j = indices_1_j + delta[..., 0] + + fine_confidence = fine_confidence.reshape( + batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3) + fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j] + fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9) + fine_confidence = mint.nn.functional.softmax( + fine_confidence / self.config.fine_matching_regress_temperature, dim=-1 + ) + + heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3) + fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0] + + fine_matches_0 = fine_matches[:, 0] + fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale) + + fine_matches = mint.stack([fine_matches_0, fine_matches_1], dim=1) + + return fine_matches + + def _fine_matching( + self, + fine_features_0: ms.Tensor, + fine_features_1: ms.Tensor, + coarse_matched_keypoints: ms.Tensor, + fine_scale: float, + ) -> ms.Tensor: + """ + For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine + features in the first image and the second image. + + Fine features are sliced in two part : + - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slicedim (64 - 8 + = 56 by default) features. + - The second part used for the second stage are the last config.fine_matching_slicedim (8 by default) features. + + Each part is used to compute a fine confidence tensor of the following shape : + (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size) + They correspond to the score between each fine pixel in the first image and each fine pixel in the second image. 
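To make the second refinement stage above concrete: the 3x3 patch of confidences around the first-stage match is softmaxed into a probability mass, and the expected (x, y) in normalized [-1, 1] coordinates becomes a sub-pixel offset (which the model then scales by `(3 // 2) * fine_scale`). A rough numpy sketch with made-up numbers:

```python
import numpy as np

patch = np.array([[0.1, 0.2, 0.1],
                  [0.2, 1.5, 0.9],
                  [0.1, 0.3, 0.1]])            # hypothetical 3x3 confidences, peak at the centre
weights = np.exp(patch) / np.exp(patch).sum()  # softmax over the 9 cells
xs, ys = np.meshgrid([-1.0, 0.0, 1.0], [-1.0, 0.0, 1.0])
offset = ((weights * xs).sum(), (weights * ys).sum())
print(offset)  # roughly (0.08, 0.01): a small shift toward the 0.9 cell on the right
```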
+ + Args: + fine_features_0 (`ms.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`): + Fine features from the first image + fine_features_1 (`ms.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2) + ** 2)`): + Fine features from the second image + coarse_matched_keypoints (`ms.Tensor` of shape `(2, num_matches, 2)`): + Keypoint coordinates found in coarse matching for the first and second image + fine_scale (`int`): + Scale between the size of fine features and coarse features + + Returns: + fine_coordinates (`ms.Tensor` of shape `(2, num_matches, 2)`): + Matched keypoint between the first and the second image. All matched keypoints are concatenated in the + second dimension. + + """ + batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape + fine_matching_slice_dim = self.config.fine_matching_slice_dim + + fine_kernel_size = mindspore_int(fine_window_size**0.5) + + # Split fine features into first and second stage features + split_fine_features_0 = mint.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1) + split_fine_features_1 = mint.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1) + + # Retrieve first stage fine features + fine_features_0 = split_fine_features_0[0] + fine_features_1 = split_fine_features_1[0] + + # Normalize first stage fine features + fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5 + fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5 + + # Compute first stage confidence + fine_confidence = fine_features_0 @ fine_features_1.swapaxes(-1, -2) + fine_confidence = mint.nn.functional.softmax(fine_confidence, 1) * mint.nn.functional.softmax( + fine_confidence, 2 + ) + fine_confidence = fine_confidence.reshape( + batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2 + ) + fine_confidence = fine_confidence[..., 1:-1, 1:-1] + first_stage_fine_confidence = fine_confidence.reshape( + batch_size, num_keypoints, fine_window_size, fine_window_size + ) + + fine_indices, fine_matches = self._get_first_stage_fine_matching( + first_stage_fine_confidence, + coarse_matched_keypoints, + fine_window_size, + fine_scale, + ) + + # Retrieve second stage fine features + fine_features_0 = split_fine_features_0[1] + fine_features_1 = split_fine_features_1[1] + + # Normalize second stage fine features + fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5 + + # Compute second stage fine confidence + second_stage_fine_confidence = fine_features_0 @ fine_features_1.swapaxes(-1, -2) + + fine_coordinates = self._get_second_stage_fine_matching( + fine_indices, + fine_matches, + second_stage_fine_confidence, + fine_window_size, + fine_scale, + ) + + return fine_coordinates + + @can_return_tuple + def construct( + self, + pixel_values: ms.Tensor, + labels: Optional[ms.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> KeypointMatchingOutput: + r""" + Examples: + + ```python + >>> from transformers import AutoImageProcessor + >>> from mindone.transformers import AutoModel + >>> import mindspore as ms + >>> from PIL import Image + >>> import requests + + >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true" # noqa: E501 + >>> image1 = Image.open(requests.get(url, stream=True).raw) + >>> url = 
"https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true" # noqa: E501 + >>> image2 = Image.open(requests.get(url, stream=True).raw) + >>> images = [image1, image2] + + >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr") + >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr") + + >>> inputs = processor(images, return_tensors="np") + >>> inputs = {k: ms.tensor(v) for k, v in inputs.items()} + >>> outputs = model(**inputs) + ```""" + if labels is not None: + raise ValueError("SuperGlue is not trainable, no labels should be provided.") + + # 1. Extract coarse and residual features + model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs) + features = model_outputs.feature_maps + + # 2. Compute coarse-level matching + coarse_features = features[0] + coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:] + batch_size, _, channels, height, width = pixel_values.shape + coarse_scale = height / coarse_height + coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching( + coarse_features, coarse_scale + ) + + # 3. Fine-level refinement + residual_features = features[1:] + fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features) + + # Filter fine features with coarse matches indices + _, _, num_keypoints = coarse_matching_scores.shape + batch_indices = mint.arange(batch_size)[..., None] + fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]] + fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]] + + # 4. Computer fine-level matching + fine_height = mindspore_int(coarse_height * coarse_scale) + fine_scale = height / fine_height + matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale) + + matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width + matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height + + return KeypointMatchingOutput( + matches=coarse_matched_indices, + matching_scores=coarse_matching_scores, + keypoints=matching_keypoints, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"] diff --git a/mindone/transformers/models/granitemoehybrid/__init__.py b/mindone/transformers/models/granitemoehybrid/__init__.py new file mode 100644 index 0000000000..41f8f9caf4 --- /dev/null +++ b/mindone/transformers/models/granitemoehybrid/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .modeling_granitemoehybrid import * diff --git a/mindone/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/mindone/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py new file mode 100644 index 0000000000..630e6e3896 --- /dev/null +++ b/mindone/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -0,0 +1,1622 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granitemoehybrid.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional, TypedDict, Union + +from transformers.models.granitemoehybrid.configuration_granitemoehybrid import GraniteMoeHybridConfig + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import mint, nn, ops + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...generation import GenerationMixin +from ...mindspore_adapter import dtype_to_min +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import can_return_tuple, logging + +selective_state_update = None + +causal_conv1d_update, causal_conv1d_fn = None, None + + +logger = logging.get_logger(__name__) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return mint.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`ms.Tensor`): The query tensor. + k (`ms.Tensor`): The key tensor. + cos (`ms.Tensor`): The cosine part of the rotary embedding. + sin (`ms.Tensor`): The sine part of the rotary embedding. + position_ids (`ms.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
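As a small illustration of `rotate_half` above (plain Python lists, not model code): the second half of the feature vector is negated and moved in front of the first half, which is the standard manipulation used by rotary position embeddings.

```python
x = [1, 2, 3, 4]
half = len(x) // 2
rotated = [-v for v in x[half:]] + x[:half]
print(rotated)  # [-3, -4, 1, 2]
```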
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(ms.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor: + """ + This is the equivalent of mint.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim)) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = mint.matmul(query, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoeHybrid +# no longer copied after attention refactors +class GraniteMoeHybridAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + + self.scaling = config.attention_multiplier + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = mint.nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = mint.nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[ms.Tensor] = None, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, # None or rope embeddings + **kwargs, + ) -> tuple[ms.Tensor, Optional[ms.Tensor], Optional[tuple[ms.Tensor]]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).swapaxes(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).swapaxes(1, 2) + + cos, sin = position_embeddings if position_embeddings is not None else (None, None) + if position_embeddings is not None: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + +class HybridMambaAttentionDynamicCache(Cache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
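A sketch of how the hybrid cache described above splits layers by type, with a hypothetical `layers_block_type` (illustration only): attention layers get a growing key/value cache, while mamba layers keep fixed-size `conv_states` / `ssm_states` and only empty key/value placeholders.

```python
layers_block_type = ["mamba", "mamba", "attention", "mamba"]   # stand-in for config.layers_block_type
transformer_layers = [i for i, block in enumerate(layers_block_type) if block != "mamba"]
print(transformer_layers)  # [2] -> the only layer whose key/value cache grows with the sequence
```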
+ """ + + key_cache = None + value_cache = None + is_compileable = False + + def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=ms.float16): + super().__init__(layer_classes=DynamicLayer) + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + conv_kernel_size = config.mamba_d_conv + ssm_state_size = config.mamba_d_state + + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "mamba": + self.conv_states += [ + mint.zeros( + ( + batch_size, + (config.mamba_expand * config.hidden_size + 2 * config.mamba_n_groups * ssm_state_size), + conv_kernel_size, + ), + dtype=dtype, + ) + ] + self.ssm_states += [ + mint.zeros( + ( + batch_size, + config.mamba_n_heads, + config.mamba_d_head, + ssm_state_size, + ), + dtype=dtype, + ) + ] + else: + self.conv_states += [ms.tensor([[]] * batch_size)] + self.ssm_states += [ms.tensor([[]] * batch_size)] + self.transformer_layers.append(i) + + self.key_cache = [ms.tensor([[]] * batch_size) for _ in range(config.num_hidden_layers)] + self.value_cache = [ms.tensor([[]] * batch_size) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: ms.Tensor, + value_states: ms.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[ms.Tensor, ms.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = mint.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = mint.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: ms.Tensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx) + + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx) + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[ms.Tensor], tuple[ms.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[ms.Tensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: ms.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return mint.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.shape[-1] + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].broadcast_to((*input_tensor.shape, chunk_size)) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = mint.tril(mint.ones((chunk_size, chunk_size), dtype=ms.bool_), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = mint.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = mint.tril(mint.ones((chunk_size, chunk_size), dtype=ms.bool_), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -ms.tensor(float("inf"))) + return tensor_segsum + + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2Mixer +class GraniteMoeHybridMambaLayer(nn.Cell): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + The are a few differences between this and Mamba2Mixer: + - The variable use_precomputed_states is slightly different due to the HybridCache structure + - There's a few non-obvious bugs fixed with batching in the slow path that exist in main + - Some extra variables that our layer doesn't need have been removed + - We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged + """ + + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_n_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.mamba_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + self.use_bias = config.mamba_proj_bias + + self.layer_norm_epsilon = config.rms_norm_eps + + self.n_groups = config.mamba_n_groups + self.head_dim = config.mamba_d_head + self.chunk_size = config.mamba_chunk_size + + # FIXME: + self.time_step_limit = (0.0, float("inf")) + self.time_step_min = 0.001 + self.time_step_max = 0.1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + has_bias=config.mamba_conv_bias, + kernel_size=self.conv_kernel_size, + group=self.conv_dim, + pad_mode="pad", + padding=self.conv_kernel_size - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = mint.nn.Linear( + self.hidden_size, + projection_size, + bias=self.use_bias, + ) + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = ms.Parameter(mint.ones(self.num_heads), name="dt_bias") + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. 
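Looking back at `segment_sum` above: entry (i, j) of the result holds the sum of the inputs over positions j+1..i, with zeros on the diagonal and -inf above it, so that exponentiating it yields a causal, decayed mask. A small numpy cross-check of that reading (illustration only):

```python
import numpy as np

a = np.array([0.5, 1.0, 2.0])
n = len(a)
expected = np.full((n, n), -np.inf)
for i in range(n):
    for j in range(i + 1):
        expected[i, j] = a[j + 1 : i + 1].sum()   # sum over the open segment (j, i]
print(expected)
# [[ 0.  -inf -inf]
#  [ 1.   0.  -inf]
#  [ 3.   2.   0. ]]
```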
Keeps the memory bounded + A = mint.arange(1, self.num_heads + 1) + self.A_log = ms.Parameter(mint.log(A), name="A_log") + self.A_log._no_weight_decay = True + self.norm = GraniteMoeHybridRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = ms.Parameter(mint.ones(self.num_heads), name="D") + self.D._no_weight_decay = True + + self.out_proj = mint.nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + else: + logger.warning_once("The fast path for GraniteMoeHybrid will be used when running the model on a GPU") + + # fmt: off + def slow_forward( + self, + input_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + ): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + gate, hidden_states_B_C, dt = projected_states.split( + [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_params.conv_states[self.layer_idx].shape[0] + == cache_params.ssm_states[self.layer_idx].shape[0] + == batch_size + and cache_position is not None + and cache_position[0] > 0 + ) + + # 2. Convolution sequence transformation + if use_precomputed_states: + cache_params.conv_states[self.layer_idx] = cache_params.conv_states[self.layer_idx].roll(shifts=-1, dims=-1) + cache_params.conv_states[self.layer_idx][:, :, -1] = hidden_states_B_C[:, 0, :] + + conv_states = cache_params.conv_states[self.layer_idx] + + hidden_states_B_C = mint.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.swapaxes(1, 2) + conv_states = mint.nn.functional.pad( + hidden_states_B_C_transposed, (self.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.swapaxes(1, 2))[..., :seq_len].swapaxes(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = mint.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -mint.exp(self.A_log.float()) # [num_heads] + if use_precomputed_states: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.swapaxes(1, 2).broadcast_to((batch_size, dt.shape[-1], self.head_dim)) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].broadcast_to((self.dt_bias.shape[0], self.head_dim)) + + dt = mint.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = mint.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].broadcast_to((self.num_heads, self.head_dim, self.ssm_state_size)).to( + dtype=ms.float32 + ) + # [bsz, num_heads, head_dim, state_size] + dA = (mint.exp(dt[..., None] * A)) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1])).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1])).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = mint.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].broadcast_to((self.D.shape[0], self.head_dim)) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
+ else: + # begin ssd naive implementation without einsums + dt = mint.nn.functional.softplus(dt + self.dt_bias) + dt = mint.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = mint.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = mint.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = mint.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if use_precomputed_states: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = mint.zeros_like(states[:, :1]) + states = mint.cat([previous_states, states], dim=1) + decay_chunk = mint.exp(segment_sum(mint.nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.swapaxes(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = mint.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def construct( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + seq_idx: Optional[ms.Tensor] = None, + **kwargs, + ): + if seq_idx is not None: + raise NotImplementedError( + "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`" + ) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class GraniteMoeHybridRMSNormGated(ms.nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size), name="weight") + self.variance_epsilon = eps + + def construct(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + + if gate is not None: + hidden_states = hidden_states * mint.nn.functional.silu(gate.to(ms.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +class GraniteMoeHybridMLP(nn.Cell): + """ + MLP layer for shared experts + + Args: + config: + Configuration object with model hyperparameters. 
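For reference on the two RMS norms above (the gated variant additionally multiplies by `silu(gate)` before normalizing): each feature vector is divided by its root mean square. A numpy sketch without the learned weight, for illustration only:

```python
import numpy as np

def rms_normalize(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    variance = np.mean(x.astype(np.float64) ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps)

print(rms_normalize(np.array([3.0, 4.0])))  # [0.8485... 1.1314...], i.e. x / sqrt((9 + 16) / 2)
```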
+ """ + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = mint.nn.Linear(self.input_size, self.hidden_size * 2, bias=False) + self.output_linear = mint.nn.Linear(self.hidden_size, self.input_size, bias=False) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.input_linear(hidden_states) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + hidden_states = self.output_linear(hidden_states) + return hidden_states + + +class GraniteFlashAttentionKwargs(TypedDict, total=False): + """ + Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage. + Use cases include padding-free training. + + Attributes: + cu_seq_lens_q (`ms.Tensor`) + Gets cumulative sequence length for query state. + cu_seq_lens_k (`ms.Tensor`) + Gets cumulative sequence length for key state. + max_length_q (`int`): + Maximum sequence length for query state. + max_length_k (`int`): + Maximum sequence length for key state. + seq_idx (`ms.Tensor): + Index of each packed sequence. + """ + + cu_seq_lens_q: ms.Tensor + cu_seq_lens_k: ms.Tensor + max_length_q: int + max_length_k: int + seq_idx: ms.Tensor + + +class GraniteMoeHybridRMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteMoeHybridRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size), name="weight") + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteMoeHybridParallelExperts(nn.Cell): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Initialize the GraniteMoeHybridParallelExperts module. + The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with + many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and + [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the + [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py) + used in vllm. + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.weight = ms.Parameter(mint.empty(num_experts, output_size, input_size), name="weight") + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def construct(self, inputs, expert_size): + """ + Forward pass of the GraniteMoeHybridParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. 
+ """ + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = mint.cat(output_list, dim=0) + return results + + +class GraniteMoeHybridTopKGating(nn.Cell): + def __init__(self, input_size: int, num_experts: int, top_k: int): + """ + Initialize the top-k gating mechanism. + Args: + input_size (`int`): + Size of the input. + num_experts (`int`): + Number of experts. + top_k (`int`): + Number of top experts to select. + """ + super().__init__() + + self.num_experts = num_experts + self.input_size = input_size + self.top_k = top_k + + self.layer = mint.nn.Linear(input_size, num_experts, bias=False) + + def construct(self, hidden_states): + # compute the top_k routing decision + logits = self.layer(hidden_states).float() # [batch_size x seq_len, num_experts] + top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1) # [num_tokens, top_k] + top_k_gates = mint.softmax(top_k_logits, dim=1).type_as(hidden_states) # [num_tokens, top_k] + + # compute number of input given to each expert + zeros = mint.zeros( + [top_k_gates.shape[0], self.num_experts], dtype=top_k_gates.dtype + ) # [num_tokens, num_experts] + gates = zeros.scatter(1, top_k_indices, 1) # [num_tokens, num_experts] + expert_size = gates.long().sum(0) # [num_experts,] + # (and `DataDependentOutputException`) + expert_size = expert_size.tolist() + + # sort and group input tokens according to expert assignment + top_k_experts = top_k_indices.flatten() # [num_tokens * top_k] + _, index_sorted_experts = top_k_experts.sort(0) # [num_tokens * top_k] + batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc") # [num_tokens * top_k] + + # gather the gate values for grouped input tokens + top_k_gates = top_k_gates.flatten() # [num_tokens * top_k] + batch_gates = top_k_gates[index_sorted_experts] # [num_tokens * top_k] + + return index_sorted_experts, batch_index, batch_gates, expert_size, logits + + +class GraniteMoeHybridMoE(nn.Cell): + """ + A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. + + Args: + config: + Configuration object with model hyperparameters. + """ + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.intermediate_size + self.activation = ACT2FN[config.hidden_act] + self.input_linear = GraniteMoeHybridParallelExperts( + config.num_local_experts, self.input_size, self.hidden_size * 2 + ) + self.output_linear = GraniteMoeHybridParallelExperts( + config.num_local_experts, self.hidden_size, self.input_size + ) + + self.router = GraniteMoeHybridTopKGating( + input_size=self.input_size, + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + ) + + def construct(self, layer_input): + """ + Forward pass of the mixture of experts layer. + + Args: + layer_input (Tensor): + Input tensor. + + Returns: + Tensor: + Output tensor. + Tensor: + Router logits. 
+ """ + bsz, length, emb_size = layer_input.shape + layer_input = layer_input.reshape(-1, emb_size) + _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input) + + expert_inputs = layer_input[batch_index] + hidden_states = self.input_linear(expert_inputs, expert_size) + chunked_hidden_states = hidden_states.chunk(2, dim=-1) + hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1] + expert_outputs = self.output_linear(hidden_states, expert_size) + + expert_outputs = expert_outputs * batch_gates[:, None] + + zeros = mint.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype) + layer_output = zeros.index_add(0, batch_index, expert_outputs) + layer_output = layer_output.view(bsz, length, self.input_size) + return layer_output, router_logits + + +class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + # Either attention or mamba will be initialized, depending on the layer type. + self.self_attn = None + if config.num_local_experts > 0: + self.block_sparse_moe = GraniteMoeHybridMoE(config) + self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + self.shared_mlp = GraniteMoeHybridMLP(config) + self.mamba = None + + if config.layers_block_type[layer_idx] == "mamba": + self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx) + else: + self.self_attn = GraniteMoeHybridAttention(config, layer_idx) + self.layer_type = config.layers_block_type[layer_idx] + + # Accept 0 experts: skip MoE if num_local_experts == 0 + self.has_experts = getattr(config, "num_local_experts", 0) > 0 + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[ms.Tensor] = None, + output_router_logits: Optional[bool] = False, + position_embeddings: Optional[tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> tuple[ms.Tensor, Optional[tuple[ms.Tensor, ms.Tensor]]]: + """ + Args: + hidden_states (`ms.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`ms.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + past_key_value (`Tuple(ms.Tensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`ms.Tensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. 
+ position_embeddings (`tuple[ms.Tensor, ms.Tensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs.Can be used to provide `GraniteFlashAttentionKwargs` for + padding-free training and/or improve torch.compile performance. + """ + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.mamba is not None: + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_position=cache_position, + cache_params=past_key_value, + attention_mask=attention_mask, + **kwargs, + ) + # No attention weights for state space layers + self_attn_weights = None + else: + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + if self.has_experts: + moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + else: + hidden_states = self.shared_mlp(hidden_states) + router_logits = None + + hidden_states = residual + hidden_states * self.residual_multiplier + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +class GraniteMoeHybridPreTrainedModel(PreTrainedModel): + config: GraniteMoeHybridConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GraniteMoeHybridDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = False + _is_stateful = True + + def _init_weights(self, module): + pass + + +class GraniteMoeHybridRotaryEmbedding(nn.Cell): + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config) + self.inv_freq = inv_freq + self.original_inv_freq = self.inv_freq + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def construct(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1)) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).swapaxes(1, 2) + emb = mint.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel): + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = mint.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.CellList( + [GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.embedding_multiplier = config.embedding_multiplier + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + self.position_embedding_type = config.position_embedding_type + self.rotary_emb = GraniteMoeHybridRotaryEmbedding(config) if self.position_embedding_type == "rope" else None + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + def construct( + self, + input_ids: ms.Tensor = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Union[Cache, list[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + **kwargs: Unpack[GraniteFlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + # overwritten because `HybridMambaAttentionDynamicCache` is needed + if use_cache and past_key_values is None: + logger.warning_once( + "GraniteMoeHybrid requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. " + "Because one was not provided, no cache will be returned." 
+ ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = mint.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + # embed positions + hidden_states = inputs_embeds + + position_embeddings = None + # create position embeddings to be shared across the decoder layers + if self.rotary_emb is not None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + output_router_logits=output_router_logits, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + if layer_outputs[-1] is not None: + # append router logits only of expert layers. Regular MLP layers return `None` as the router logits + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def _update_causal_mask( + self, + attention_mask: ms.Tensor, + input_tensor: ms.Tensor, + cache_position: ms.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
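+        # Otherwise the code below builds a 4D additive mask of shape (batch_size, 1, query_length, key_value_length),
+        # with 0.0 at positions that may be attended to and the dtype minimum at masked positions.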
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, ms.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if self.config._attn_implementation == "sdpa" and attention_mask is not None and not output_attentions: + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = dtype_to_min(dtype) + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: ms.Tensor, + sequence_length: int, + target_length: int, + dtype: ms.Type, + cache_position: ms.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`ms.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`ms.Type`): + The dtype to use for the 4D attention mask. + cache_position (`ms.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`ms.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
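+            # i.e. it is expected to already be a 4D additive mask of shape
+            # (batch_size, 1, query_length, key_value_length) with masked positions set to the dtype minimum.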
+ causal_mask = attention_mask + else: + min_dtype = dtype_to_min(dtype) + causal_mask = ops.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype) + if sequence_length != 1: + causal_mask = mint.triu(causal_mask, diagonal=1) + causal_mask *= mint.arange(target_length) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1)) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and mint.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +def load_balancing_loss_func( + gate_logits: Union[ms.Tensor, tuple[ms.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[ms.Tensor] = None, +) -> Union[ms.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in MindSpore. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`ms.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
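+
+    Note:
+        In the unmasked case this roughly reduces to `num_experts * sum_e(f_e * P_e)` (cf. equations (4)-(6) of
+        the paper), where `f_e` is the fraction of tokens routed to expert `e` (accumulated over the `top_k`
+        routing slots) and `P_e` is the mean router probability assigned to expert `e`.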
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + concatenated_gate_logits = mint.cat([layer_gate for layer_gate in gate_logits], dim=0) + + routing_weights = mint.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = mint.topk(routing_weights, top_k, dim=-1) + + expert_mask = mint.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = mint.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = mint.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .broadcast_to((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = mint.sum(expert_mask.float() * expert_attention_mask, dim=0) / mint.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .broadcast_to((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = mint.sum(routing_weights * router_per_expert_attention_mask, dim=0) / mint.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = mint.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GraniteMoeHybridConfig): + super().__init__(config) + self.model = GraniteMoeHybridModel(config) + self.vocab_size = config.vocab_size + self.lm_head = mint.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def construct( + self, + input_ids: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + position_ids: Optional[ms.Tensor] = None, + past_key_values: Optional[Union[Cache, list[ms.Tensor]]] = None, + inputs_embeds: Optional[ms.Tensor] = None, + labels: Optional[ms.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[ms.Tensor] = None, + logits_to_keep: Union[int, ms.Tensor] = 0, + **kwargs, + ) -> Union[tuple, MoeCausalLMOutputWithPast]: + r""" + labels (`ms.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> from mindone.transformers import GraniteMoeHybridForCausalLM + >>> import mindspore as ms + + >>> model = GraniteMoeHybridForCausalLM.from_pretrained("ibm/PowerMoE-3b") + >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="np") + >>> inputs = {k: ms.tensor(v) for k, v in inputs.items()} + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + # Only compute necessary logits + hidden_states = outputs[0] + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = logits / self.config.logits_scaling + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Flatten the tokens + loss = self.loss_function( + logits, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is 
None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif use_cache: + past_key_values = HybridMambaAttentionDynamicCache(self.config, input_ids.shape[0], self.dtype) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cache_position": cache_position, + } + ) + return model_inputs + + +__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"] diff --git a/mindone/transformers/models/hgnet_v2/__init__.py b/mindone/transformers/models/hgnet_v2/__init__.py new file mode 100644 index 0000000000..57a9b2b52c --- /dev/null +++ b/mindone/transformers/models/hgnet_v2/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_hgnet_v2 import * diff --git a/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py b/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py new file mode 100644 index 0000000000..fcf5701aa2 --- /dev/null +++ b/mindone/transformers/models/hgnet_v2/modeling_hgnet_v2.py @@ -0,0 +1,480 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/hgnet_v2/modular_hgnet_v2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_hgnet_v2.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Baidu Inc and The HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from transformers import HGNetV2Config + +import mindspore +from mindspore import Parameter, Tensor, mint, nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils.backbone_utils import BackboneMixin + + +class HGNetV2PreTrainedModel(PreTrainedModel): + config: HGNetV2Config + base_model_prefix = "hgnetv2" + main_input_name = "pixel_values" + _no_split_modules = ["HGNetV2BasicLayer"] + + +class HGNetV2LearnableAffineBlock(nn.Cell): + def __init__(self, scale_value: float = 1.0, bias_value: float = 0.0): + super().__init__() + self.scale = Parameter(Tensor([scale_value]), requires_grad=True) + self.bias = Parameter(Tensor([bias_value]), requires_grad=True) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.scale * hidden_state + self.bias + return hidden_state + + +class HGNetV2ConvLayer(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + activation: str = "relu", + use_learnable_affine_block: bool = False, + ): + super().__init__() + self.convolution = mint.nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + groups=groups, + padding=(kernel_size - 1) // 2, + bias=False, + ) + self.normalization = mint.nn.BatchNorm2d(out_channels) + self.activation = ACT2FN[activation] if activation is not None else mint.nn.Identity() + if activation and use_learnable_affine_block: + self.lab = HGNetV2LearnableAffineBlock() + else: + self.lab = mint.nn.Identity() + + def construct(self, input: Tensor) -> Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.lab(hidden_state) + return hidden_state + + +class HGNetV2ConvLayerLight(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, use_learnable_affine_block: bool = False): + super().__init__() + self.conv1 = HGNetV2ConvLayer( + in_channels, + out_channels, + kernel_size=1, + activation=None, + use_learnable_affine_block=use_learnable_affine_block, + ) + self.conv2 = HGNetV2ConvLayer( + out_channels, + out_channels, + kernel_size=kernel_size, + groups=out_channels, + use_learnable_affine_block=use_learnable_affine_block, + ) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.conv1(hidden_state) + hidden_state = self.conv2(hidden_state) + return hidden_state + + +class HGNetV2Embeddings(nn.Cell): + def __init__(self, config: HGNetV2Config): + 
super().__init__() + + self.stem1 = HGNetV2ConvLayer( + config.stem_channels[0], + config.stem_channels[1], + kernel_size=3, + stride=2, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem2a = HGNetV2ConvLayer( + config.stem_channels[1], + config.stem_channels[1] // 2, + kernel_size=2, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem2b = HGNetV2ConvLayer( + config.stem_channels[1] // 2, + config.stem_channels[1], + kernel_size=2, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem3 = HGNetV2ConvLayer( + config.stem_channels[1] * 2, + config.stem_channels[1], + kernel_size=3, + stride=2, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + self.stem4 = HGNetV2ConvLayer( + config.stem_channels[1], + config.stem_channels[2], + kernel_size=1, + stride=1, + activation=config.hidden_act, + use_learnable_affine_block=config.use_learnable_affine_block, + ) + + self.pool = mint.nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True) + self.num_channels = config.num_channels + + def construct(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.stem1(pixel_values) + embedding = mint.nn.functional.pad(embedding, (0, 1, 0, 1)) + emb_stem_2a = self.stem2a(embedding) + emb_stem_2a = mint.nn.functional.pad(emb_stem_2a, (0, 1, 0, 1)) + emb_stem_2a = self.stem2b(emb_stem_2a) + pooled_emb = self.pool(embedding.float()).to(embedding.dtype) + embedding = mint.cat([pooled_emb, emb_stem_2a], dim=1) + embedding = self.stem3(embedding) + embedding = self.stem4(embedding) + return embedding + + +class HGNetV2BasicLayer(nn.Cell): + def __init__( + self, + in_channels: int, + middle_channels: int, + out_channels: int, + layer_num: int, + kernel_size: int = 3, + residual: bool = False, + light_block: bool = False, + drop_path: float = 0.0, + use_learnable_affine_block: bool = False, + ): + super().__init__() + self.residual = residual + + self.layers = nn.CellList() + for i in range(layer_num): + temp_in_channels = in_channels if i == 0 else middle_channels + if light_block: + block = HGNetV2ConvLayerLight( + in_channels=temp_in_channels, + out_channels=middle_channels, + kernel_size=kernel_size, + use_learnable_affine_block=use_learnable_affine_block, + ) + else: + block = HGNetV2ConvLayer( + in_channels=temp_in_channels, + out_channels=middle_channels, + kernel_size=kernel_size, + use_learnable_affine_block=use_learnable_affine_block, + stride=1, + ) + self.layers.append(block) + + # feature aggregation + total_channels = in_channels + layer_num * middle_channels + aggregation_squeeze_conv = HGNetV2ConvLayer( + total_channels, + out_channels // 2, + kernel_size=1, + stride=1, + use_learnable_affine_block=use_learnable_affine_block, + ) + aggregation_excitation_conv = HGNetV2ConvLayer( + out_channels // 2, + out_channels, + kernel_size=1, + stride=1, + use_learnable_affine_block=use_learnable_affine_block, + ) + self.aggregation = nn.SequentialCell( + aggregation_squeeze_conv, + aggregation_excitation_conv, + ) + self.drop_path = mint.nn.Dropout(drop_path) if drop_path else mint.nn.Identity() + + def construct(self, hidden_state: Tensor) -> Tensor: + 
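+        # Keep the input and every intermediate block output, concatenate them along the channel axis, and fuse
+        # them with the two 1x1 aggregation convolutions; the residual branch is only active when
+        # `residual=True` (every block of a stage except the first one).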
identity = hidden_state + output = [hidden_state] + for layer in self.layers: + hidden_state = layer(hidden_state) + output.append(hidden_state) + hidden_state = mint.cat(output, dim=1) + hidden_state = self.aggregation(hidden_state) + if self.residual: + hidden_state = self.drop_path(hidden_state) + identity + return hidden_state + + +class HGNetV2Stage(nn.Cell): + def __init__(self, config: HGNetV2Config, stage_index: int, drop_path: float = 0.0): + super().__init__() + in_channels = config.stage_in_channels[stage_index] + mid_channels = config.stage_mid_channels[stage_index] + out_channels = config.stage_out_channels[stage_index] + num_blocks = config.stage_num_blocks[stage_index] + num_layers = config.stage_numb_of_layers[stage_index] + downsample = config.stage_downsample[stage_index] + light_block = config.stage_light_block[stage_index] + kernel_size = config.stage_kernel_size[stage_index] + use_learnable_affine_block = config.use_learnable_affine_block + + if downsample: + self.downsample = HGNetV2ConvLayer( + in_channels, in_channels, kernel_size=3, stride=2, groups=in_channels, activation=None + ) + else: + self.downsample = mint.nn.Identity() + + blocks_list = [] + for i in range(num_blocks): + blocks_list.append( + HGNetV2BasicLayer( + in_channels if i == 0 else out_channels, + mid_channels, + out_channels, + num_layers, + residual=False if i == 0 else True, + kernel_size=kernel_size, + light_block=light_block, + drop_path=drop_path, + use_learnable_affine_block=use_learnable_affine_block, + ) + ) + self.blocks = nn.CellList(blocks_list) + + def construct(self, hidden_state: Tensor) -> Tensor: + hidden_state = self.downsample(hidden_state) + for block in self.blocks: + hidden_state = block(hidden_state) + return hidden_state + + +class HGNetV2Encoder(nn.Cell): + def __init__(self, config: HGNetV2Config): + super().__init__() + self.stages = nn.CellList([]) + for stage_index in range(len(config.stage_in_channels)): + resnet_stage = HGNetV2Stage(config, stage_index) + self.stages.append(resnet_stage) + + def construct( + self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True + ) -> BaseModelOutputWithNoAttention: + hidden_states = () if output_hidden_states else None + + for stage in self.stages: + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + hidden_state = stage(hidden_state) + + if output_hidden_states: + hidden_states = hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_state, + hidden_states=hidden_states, + ) + + +class HGNetV2Backbone(HGNetV2PreTrainedModel, BackboneMixin): + def __init__(self, config: HGNetV2Config): + super().__init__(config) + super()._init_backbone(config) + self.depths = config.depths + self.num_features = [config.embedding_size] + config.hidden_sizes + self.embedder = HGNetV2Embeddings(config) + self.encoder = HGNetV2Encoder(config) + + # initialize weights and apply final processing + self.post_init() + + def construct( + self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None + ) -> BackboneOutput: + r""" + Examples: + + ```python + >>> from transformers import HGNetV2Config + >>> from mindone.transformers import HGNetV2Backbone + >>> import mindspore as ms + + >>> config = HGNetV2Config() + >>> model = HGNetV2Backbone(config) + + >>> pixel_values = ms.mint.randn(1, 3, 224, 224) + + >>> 
outputs = model(pixel_values) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + embedding_output = self.embedder(pixel_values) + + outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True) + + hidden_states = outputs.hidden_states + + feature_maps = () + for idx, stage in enumerate(self.stage_names): + if stage in self.out_features: + feature_maps += (hidden_states[idx],) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=None, + ) + + +class HGNetV2ForImageClassification(HGNetV2PreTrainedModel): + def __init__(self, config: HGNetV2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.embedder = HGNetV2Embeddings(config) + self.encoder = HGNetV2Encoder(config) + self.avg_pool = mint.nn.AdaptiveAvgPool2d((1, 1)) + self.flatten = nn.Flatten() + self.fc = ( + mint.nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else mint.nn.Identity() + ) + + # classification head + self.classifier = nn.CellList([self.avg_pool, self.flatten]) + + # initialize weights and apply final processing + self.post_init() + + def construct( + self, + pixel_values: Optional[Tensor] = None, + labels: Optional[Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> ImageClassifierOutputWithNoAttention: + r""" + labels (`ms.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ + Examples: + ```python + >>> import mindspore as ms + >>> import requests + >>> from mindone.transformers import HGNetV2ForImageClassification, AutoImageProcessor + >>> from PIL import Image + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> model = HGNetV2ForImageClassification.from_pretrained("ustc-community/hgnet-v2") + >>> processor = AutoImageProcessor.from_pretrained("ustc-community/hgnet-v2") + + >>> inputs = processor(images=image, return_tensors="np") + >>> inputs = {k: ms.tensor(v) for k, v in inputs.items()} + >>> outputs = model(**inputs) + >>> outputs.logits.shape + (1, 2) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + embedding_output = self.embedder(pixel_values) + outputs = self.encoder(embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict) + last_hidden_state = outputs[0] + for layer in self.classifier: + last_hidden_state = layer(last_hidden_state) + logits = self.fc(last_hidden_state) + loss = None + + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == mindspore.long or labels.dtype == mindspore.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + if self.config.problem_type == "regression": + loss_fct = mint.nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = mint.nn.CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = mint.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return (loss,) + output if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["HGNetV2Backbone", "HGNetV2PreTrainedModel", "HGNetV2ForImageClassification"] diff --git a/mindone/transformers/models/rt_detr/__init__.py b/mindone/transformers/models/rt_detr/__init__.py index c5dc41ce66..4ab2d201e4 100644 --- a/mindone/transformers/models/rt_detr/__init__.py +++ b/mindone/transformers/models/rt_detr/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .image_processing_rt_detr import * from .modeling_rt_detr import RTDetrForObjectDetection, RTDetrModel, RTDetrPreTrainedModel from .modeling_rt_detr_resnet import RTDetrResNetBackbone, RTDetrResNetPreTrainedModel diff --git a/mindone/transformers/models/rt_detr/image_processing_rt_detr.py b/mindone/transformers/models/rt_detr/image_processing_rt_detr.py new file mode 100644 index 0000000000..37794427ee --- /dev/null +++ b/mindone/transformers/models/rt_detr/image_processing_rt_detr.py @@ -0,0 +1,1086 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for RT-DETR.""" + +import pathlib +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import numpy as np + +from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import BaseImageProcessor, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + pad, + rescale, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_preprocess_arguments, +) +from ...utils import ( + filter_out_non_signature_kwargs, + is_mindspore_available, + is_mindspore_tensor, + logging, + requires_backends, +) +from ...utils.generic import TensorType + +if is_mindspore_available(): + import mindspore as ms + from mindspore import mint + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,) + + +# Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + +# Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, tuple[int, int], list[int]], + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. If the desired output size + is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output + image size is computed by keeping the aspect ratio of the input image size. 
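+    For example (illustrative):
+        - input_size: (480, 640), size: 800, max_size: None -> output_size: (800, 1066)
+        - input_size: (480, 640), size: (512, 512) -> output_size: (512, 512)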
+ + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `tuple[int, int]` or `list[int]`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + if isinstance(size, (list, tuple)): + return size + + return get_size_with_aspect_ratio(image_size, size, max_size) + + +# Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width +def get_image_size_for_max_height_width( + input_image: np.ndarray, + max_height: int, + max_width: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + Args: + input_image (`np.ndarray`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +# Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn +def get_numpy_to_framework_fn(arr) -> Callable: + """ + Returns a function that converts a numpy array to the framework of the input array. + + Args: + arr (`np.ndarray`): The array to convert. + """ + if isinstance(arr, np.ndarray): + return np.array + + if is_mindspore_available() and is_mindspore_tensor(arr): + import mindspore as ms + + return ms.tensor + + raise ValueError(f"Cannot convert arrays of type {type(arr)}") + + +# Copied from transformers.models.detr.image_processing_detr.safe_squeeze +def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """ + Squeezes an array, but only if the axis specified has dim 1. 
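+    For example (illustrative): an array of shape `(1, 3, 4)` squeezed on `axis=0` becomes `(3, 4)`, while an
+    array of shape `(2, 3, 4)` is returned unchanged instead of raising an error.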
+ """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +# Copied from transformers.models.detr.image_processing_detr.normalize_annotation +def normalize_annotation(annotation: dict, image_size: tuple[int, int]) -> dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.detr.image_processing_detr.max_across_indices +def max_across_indices(values: Iterable[Any]) -> list[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.detr.image_processing_detr.get_max_height_width +def get_max_height_width( + images: list[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> list[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by RTDETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. 
+ annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + return new_target + + +# Copied from transformers.models.detr.image_processing_detr.resize_annotation +def resize_annotation( + annotation: dict[str, Any], + orig_size: tuple[int, int], + target_size: tuple[int, int], + threshold: float = 0.5, + resample: PILImageResampling = PILImageResampling.NEAREST, +): + """ + Resizes an annotation to a target size. + + Args: + annotation (`dict[str, Any]`): + The annotation dictionary. + orig_size (`tuple[int, int]`): + The original size of the input image. + target_size (`tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`): + The resampling filter to use when resizing the masks. + """ + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size)) + ratio_height, ratio_width = ratios + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = np.array([resize(mask, target_size, resample=resample) for mask in masks]) + masks = masks.astype(np.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + +class RTDetrImageProcessor(BaseImageProcessor): + r""" + Constructs a RT-DETR image processor. 
+
+    Args:
+        format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
+            Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
+            overridden by the `do_resize` parameter in the `preprocess` method.
+        size (`dict[str, int]`, *optional*, defaults to `{"height": 640, "width": 640}`):
+            Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
+            in the `preprocess` method. Available options are:
+                - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+                  Do NOT keep the aspect ratio.
+                - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+                  the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+                  less or equal to `longest_edge`.
+                - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+                  aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+                  `max_width`.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            Resampling filter to use if resizing the image.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+            `do_rescale` parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+            `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+            `preprocess` method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+            Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
+            channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+            Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
+            for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_annotations (`bool`, *optional*, defaults to `True`):
+            Controls whether to convert the annotations to the format expected by the DETR model. Converts the
+            bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+            Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+        do_pad (`bool`, *optional*, defaults to `False`):
+            Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
+            method. If `True`, padding will be applied to the bottom and right of the image with zeros.
+            If `pad_size` is provided, the image will be padded to the specified dimensions.
+            Otherwise, the image will be padded to the maximum height and width of the batch.
+        pad_size (`dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+            provided for preprocessing.
If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = False, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_convert_annotations: bool = True, + do_pad: bool = False, + pad_size: Optional[dict[str, int]] = None, + **kwargs, + ) -> None: + size = size if size is not None else {"height": 640, "width": 640} + size = get_size_dict(size, default_to_square=False) + + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + + def prepare_annotation( + self, + image: np.ndarray, + target: dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> dict: + """ + Prepare an annotation for feeding into RTDETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. 
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation + def normalize_annotation(self, annotation: dict, image_size: tuple[int, int]) -> dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image + def _update_annotation_for_padded_image( + self, + annotation: dict, + input_image_size: tuple[int, int], + output_image_size: tuple[int, int], + padding, + update_bboxes, + ) -> dict: + """ + Update the annotation for a padded image. + """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image + def _pad_image( + self, + image: np.ndarray, + output_size: tuple[int, int], + annotation: Optional[dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. 
+ """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad + def pad( + self, + images: list[np.ndarray], + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (list[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.MINDSPORE` or `'ms'`: Return a batch of type `ms.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. 
+ """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, list[AnnotationType]]] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[dict[str, int]] = None, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `list[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`list[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`list[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. 
+ return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `list[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `list[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. + return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=True) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "ms.Tensor, tf.Tensor or jax.ndarray." + ) + + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError(f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match.") + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + images = make_list_of_images(images) + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "ms.Tensor, tf.Tensor or jax.ndarray." + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+ input_data_format = infer_channel_dimension_format(images[0]) + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + # transformations + if do_resize: + if annotations is not None: + resized_images, resized_annotations = [], [] + for image, target in zip(images, annotations): + orig_size = get_image_size(image, input_data_format) + resized_image = self.resize( + image, size=size, resample=resample, input_data_format=input_data_format + ) + resized_annotation = self.resize_annotation( + target, orig_size, get_image_size(resized_image, input_data_format) + ) + resized_images.append(resized_image) + resized_annotations.append(resized_annotation) + images = resized_images + annotations = resized_annotations + del resized_images, resized_annotations + else: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + encoded_inputs = self.pad( + images, + annotations=annotations, + return_pixel_mask=True, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=do_convert_annotations, + return_tensors=return_tensors, + pad_size=pad_size, + ) + else: + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + + return encoded_inputs + + def post_process_object_detection( + self, + outputs, + threshold: float = 0.5, + target_sizes: Union[TensorType, list[tuple]] = None, + use_focal_loss: bool = True, + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports MindSpore. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + Score threshold to keep object detection predictions. + target_sizes (`ms.Tensor` or `list[tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. 
+ use_focal_loss (`bool` defaults to `True`): + Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied + to compute the scores of each detection, otherwise, a softmax function is used. + + Returns: + `list[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + requires_backends(self, ["mindspore"]) + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + # convert from relative cxcywh to absolute xyxy + boxes = center_to_corners_format(out_bbox) + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if isinstance(target_sizes, list): + img_h, img_w = ms.tensor(target_sizes).unbind(1) + else: + img_h, img_w = target_sizes.unbind(1) + scale_fct = mint.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + num_top_queries = out_logits.shape[1] + num_classes = out_logits.shape[2] + + if use_focal_loss: + scores = mint.nn.functional.sigmoid(out_logits) + scores, index = mint.topk(scores.flatten(1), num_top_queries, dim=-1) + labels = index % num_classes + index = index // num_classes + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).tile((1, 1, boxes.shape[-1]))) + else: + scores = mint.nn.functional.softmax(out_logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > num_top_queries: + scores, index = mint.topk(scores, num_top_queries, dim=-1) + labels = mint.gather(labels, dim=1, index=index) + boxes = mint.gather(boxes, dim=1, index=index.unsqueeze(-1).tile((1, 1, boxes.shape[-1]))) + + results = [] + for score, label, box in zip(scores, labels, boxes): + results.append( + { + "scores": score[score > threshold], + "labels": label[score > threshold], + "boxes": box[score > threshold], + } + ) + + return results + + +__all__ = ["RTDetrImageProcessor"] diff --git a/tests/transformers_tests/models/d_fine/__init__.py b/tests/transformers_tests/models/d_fine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/d_fine/test_modeling_d_fine.py b/tests/transformers_tests/models/d_fine/test_modeling_d_fine.py new file mode 100644 index 0000000000..9b2df8fe7b --- /dev/null +++ b/tests/transformers_tests/models/d_fine/test_modeling_d_fine.py @@ -0,0 +1,323 @@ +# coding = utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
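Before the new test modules, here is a minimal usage sketch for the RTDetrImageProcessor added above. It is not part of this diff: it assumes the class is importable from `mindone.transformers` (as the package-level export in this PR suggests), uses a random image and a single made-up COCO box, and the commented `post_process_object_detection` call uses a placeholder `outputs`; only the shapes and keys are meaningful.

import numpy as np
import mindspore as ms

from mindone.transformers import RTDetrImageProcessor  # assumes the package-level export added in this PR

processor = RTDetrImageProcessor()  # defaults: resize to 640x640, rescale by 1/255, no padding

# One random "image" with a single COCO-style annotation; "bbox" is [x, y, width, height].
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
annotations = {
    "image_id": 0,
    "annotations": [
        {"category_id": 1, "bbox": [10.0, 20.0, 100.0, 80.0], "area": 8000.0, "iscrowd": 0},
    ],
}

inputs = processor(images=image, annotations=annotations, return_tensors="ms")
print(inputs["pixel_values"].shape)  # (1, 3, 640, 640) with the default 640x640 size
print(inputs["labels"][0]["boxes"])  # boxes converted to relative (center_x, center_y, width, height)

# After a detection forward pass, post_process_object_detection maps the raw logits /
# pred_boxes back to absolute (x_min, y_min, x_max, y_max) boxes. `outputs` here is a
# placeholder for such a model output, so the call is left commented.
# results = processor.post_process_object_detection(
#     outputs, threshold=0.5, target_sizes=ms.tensor([[480, 640]])
# )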
+"""Testing suite for the MindSpore D-FINE model.""" + +import inspect +import math + +import numpy as np +import pytest +import torch +from transformers import DFineConfig, HGNetV2Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class DFineModelTester: + def __init__( + self, + batch_size=3, + is_training=True, + use_labels=True, + n_targets=3, + num_labels=10, + initializer_range=0.02, + layer_norm_eps=1e-5, + batch_norm_eps=1e-5, + # backbone + backbone_config=None, + # encoder HybridEncoder + encoder_hidden_dim=32, + encoder_in_channels=[128, 256, 512], + feat_strides=[8, 16, 32], + encoder_layers=1, + encoder_ffn_dim=64, + encoder_attention_heads=2, + dropout=0.0, + activation_dropout=0.0, + encode_proj_layers=[2], + positional_encoding_temperature=10000, + encoder_activation_function="gelu", + activation_function="silu", + eval_size=None, + normalize_before=False, + # decoder DFineTransformer + d_model=32, + num_queries=30, + decoder_in_channels=[32, 32, 32], + decoder_ffn_dim=64, + num_feature_levels=3, + decoder_n_points=[3, 6, 3], + decoder_n_levels=3, + decoder_layers=2, + decoder_attention_heads=2, + decoder_activation_function="relu", + attention_dropout=0.0, + num_denoising=0, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learn_initial_query=False, + anchor_image_size=None, + image_size=64, + disable_custom_kernels=True, + with_box_refine=True, + decoder_offset_scale=0.5, + eval_idx=-1, + layer_scale=1, + reg_max=32, + reg_scale=4.0, + depth_mult=0.34, + hidden_expansion=0.5, + ): + self.batch_size = batch_size + self.num_channels = 3 + self.is_training = is_training + self.use_labels = use_labels + self.n_targets = n_targets + self.num_labels = num_labels + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.batch_norm_eps = batch_norm_eps + self.backbone_config = backbone_config + self.encoder_hidden_dim = encoder_hidden_dim + self.encoder_in_channels = encoder_in_channels + self.feat_strides = feat_strides + self.encoder_layers = encoder_layers + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.dropout = dropout + self.activation_dropout = activation_dropout + self.encode_proj_layers = encode_proj_layers + self.positional_encoding_temperature = positional_encoding_temperature + self.encoder_activation_function = encoder_activation_function + self.activation_function = activation_function + self.eval_size = eval_size + self.normalize_before = normalize_before + self.d_model = d_model + self.num_queries = num_queries + self.decoder_in_channels = decoder_in_channels + self.decoder_ffn_dim = decoder_ffn_dim + self.num_feature_levels = num_feature_levels + self.decoder_n_points = decoder_n_points + self.decoder_n_levels = decoder_n_levels + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.decoder_activation_function = decoder_activation_function + self.attention_dropout = attention_dropout + self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale + self.reg_max = reg_max + self.reg_scale = reg_scale + self.depth_mult = depth_mult + self.num_denoising = num_denoising + self.label_noise_ratio = 
label_noise_ratio + self.box_noise_scale = box_noise_scale + self.learn_initial_query = learn_initial_query + self.anchor_image_size = anchor_image_size + self.image_size = image_size + self.disable_custom_kernels = disable_custom_kernels + self.with_box_refine = with_box_refine + self.hidden_expansion = hidden_expansion + + self.encoder_seq_length = math.ceil(self.image_size / 32) * math.ceil(self.image_size / 32) + + def prepare_config_and_inputs(self): + pixel_values = floats_numpy([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + pixel_mask = np.ones([self.batch_size, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + # labels is a list of Dict (each Dict being the labels for a given example in the batch) + labels = [] + for i in range(self.batch_size): + target = {} + target["class_labels"] = np.random.randint(low=0, high=self.num_labels, size=(self.n_targets,)) + target["boxes"] = np.random.rand(self.n_targets, 4) + labels.append(target) + + config = self.get_config() + config.num_labels = self.num_labels + return config, pixel_values, pixel_mask, labels + + def get_config(self): + hidden_sizes = [64, 128, 256, 512] + backbone_config = HGNetV2Config( + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + embeddings_size=10, + hidden_sizes=hidden_sizes, + depths=[1, 1, 2, 1], + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + stem_channels=[3, 16, 16], + use_lab=True, + ) + return DFineConfig.from_backbone_configs( + backbone_config=backbone_config, + encoder_hidden_dim=self.encoder_hidden_dim, + encoder_in_channels=self.encoder_in_channels, + feat_strides=self.feat_strides, + encoder_layers=self.encoder_layers, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + dropout=self.dropout, + activation_dropout=self.activation_dropout, + encode_proj_layers=self.encode_proj_layers, + positional_encoding_temperature=self.positional_encoding_temperature, + encoder_activation_function=self.encoder_activation_function, + activation_function=self.activation_function, + eval_size=self.eval_size, + normalize_before=self.normalize_before, + d_model=self.d_model, + num_queries=self.num_queries, + decoder_in_channels=self.decoder_in_channels, + decoder_ffn_dim=self.decoder_ffn_dim, + num_feature_levels=self.num_feature_levels, + decoder_n_points=self.decoder_n_points, + decoder_n_levels=self.decoder_n_levels, + decoder_layers=self.decoder_layers, + decoder_attention_heads=self.decoder_attention_heads, + decoder_activation_function=self.decoder_activation_function, + decoder_offset_scale=self.decoder_offset_scale, + eval_idx=self.eval_idx, + layer_scale=self.layer_scale, + reg_max=self.reg_max, + reg_scale=self.reg_scale, + depth_mult=self.depth_mult, + attention_dropout=self.attention_dropout, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, + learn_initial_query=self.learn_initial_query, + anchor_image_size=self.anchor_image_size, + image_size=self.image_size, + disable_custom_kernels=self.disable_custom_kernels, + with_box_refine=self.with_box_refine, + ) + + +model_tester = DFineModelTester() +config, pixel_values, pixel_mask, labels = 
model_tester.prepare_config_and_inputs() +DFINE_CASES = [ + [ + "DFineModel", + "transformers.DFineModel", + "mindone.transformers.DFineModel", + (config,), + {}, + (pixel_values,), + { + "pixel_mask": pixel_mask, + "labels": labels, + }, + { + "last_hidden_state": 0, + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in DFINE_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) diff --git a/tests/transformers_tests/models/efficientloftr/__init__.py b/tests/transformers_tests/models/efficientloftr/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/efficientloftr/test_modeling_efficientloftr.py b/tests/transformers_tests/models/efficientloftr/test_modeling_efficientloftr.py new file mode 100644 index 0000000000..eb7b7faa36 --- /dev/null +++ b/tests/transformers_tests/models/efficientloftr/test_modeling_efficientloftr.py @@ -0,0 +1,185 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
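As an aside to the D-FINE parity test above, the tiny configuration built by `DFineModelTester` can also be exercised directly with the MindSpore port. This is a hedged sketch, not part of the test suite: it imports the tester from the module path added in this diff, runs a randomly initialized `DFineModel` once, and reads the output that the parity test maps to index 0; the dtype cast and the expected shape are assumptions.

import mindspore as ms

from mindone.transformers import DFineModel
from tests.transformers_tests.models.d_fine.test_modeling_d_fine import DFineModelTester

tester = DFineModelTester()
config, pixel_values, pixel_mask, labels = tester.prepare_config_and_inputs()

model = DFineModel(config)
model.set_train(False)  # inference mode; weights are randomly initialized

# pixel_mask and labels are optional; the parity test above passes them as kwargs.
outputs = model(ms.tensor(pixel_values, dtype=ms.float32))

# The test's outputs_map associates "last_hidden_state" with output index 0.
last_hidden_state = outputs[0]
print(last_hidden_state.shape)  # expected (batch_size, num_queries, d_model) = (3, 30, 32)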
+import inspect + +import numpy as np +import pytest +import torch +from transformers import EfficientLoFTRConfig + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class EfficientLoFTRModelTester: + def __init__( + self, + batch_size=2, + image_width=80, + image_height=60, + stage_num_blocks: list[int] = [1, 1, 1], + out_features: list[int] = [32, 32, 64], + stage_stride: list[int] = [2, 1, 2], + q_aggregation_kernel_size: int = 1, + kv_aggregation_kernel_size: int = 1, + q_aggregation_stride: int = 1, + kv_aggregation_stride: int = 1, + num_attention_layers: int = 2, + num_attention_heads: int = 8, + hidden_size: int = 64, + coarse_matching_threshold: float = 0.0, + fine_kernel_size: int = 2, + coarse_matching_border_removal: int = 0, + ): + self.batch_size = batch_size + self.image_width = image_width + self.image_height = image_height + + self.stage_num_blocks = stage_num_blocks + self.out_features = out_features + self.stage_stride = stage_stride + self.q_aggregation_kernel_size = q_aggregation_kernel_size + self.kv_aggregation_kernel_size = kv_aggregation_kernel_size + self.q_aggregation_stride = q_aggregation_stride + self.kv_aggregation_stride = kv_aggregation_stride + self.num_attention_layers = num_attention_layers + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.coarse_matching_threshold = coarse_matching_threshold + self.coarse_matching_border_removal = coarse_matching_border_removal + self.fine_kernel_size = fine_kernel_size + + def prepare_config_and_inputs(self): + # EfficientLoFTR expects a grayscale image as input + pixel_values = floats_numpy([self.batch_size, 2, 3, self.image_height, self.image_width]) + config = self.get_config() + return config, pixel_values + + def get_config(self): + return EfficientLoFTRConfig( + stage_num_blocks=self.stage_num_blocks, + out_features=self.out_features, + stage_stride=self.stage_stride, + q_aggregation_kernel_size=self.q_aggregation_kernel_size, + kv_aggregation_kernel_size=self.kv_aggregation_kernel_size, + q_aggregation_stride=self.q_aggregation_stride, + kv_aggregation_stride=self.kv_aggregation_stride, + num_attention_layers=self.num_attention_layers, + num_attention_heads=self.num_attention_heads, + hidden_size=self.hidden_size, + coarse_matching_threshold=self.coarse_matching_threshold, + coarse_matching_border_removal=self.coarse_matching_border_removal, + fine_kernel_size=self.fine_kernel_size, + ) + + +model_tester = EfficientLoFTRModelTester() +config, pixel_values = model_tester.prepare_config_and_inputs() +EFFICIENTLOFTR_CASES = [ + [ + "EfficientLoFTRModel", + "transformers.EfficientLoFTRModel", + "mindone.transformers.EfficientLoFTRModel", + (config,), + {}, + (pixel_values,), + {}, + {}, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in EFFICIENTLOFTR_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + 
ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs)["feature_maps"][-1] + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs)["feature_maps"][-1] + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) diff --git a/tests/transformers_tests/models/hgnet_v2/__init__.py b/tests/transformers_tests/models/hgnet_v2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py b/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py new file mode 100644 index 0000000000..a38858391b --- /dev/null +++ b/tests/transformers_tests/models/hgnet_v2/test_modeing_hgnet_v2.py @@ -0,0 +1,223 @@ +"""Adapted from https://github.com/huggingface/transformers/tree/main/tests/models/hgnet_v2/test_modeling_hgnet_v2.py.""" + +# This module contains test cases that are defined in the `.test_cases.py` file, structured as lists or tuples like +# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs, outputs_map]. +# +# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective +# initialization parameters and inputs for the forward. The testing framework adopted here is designed to generically +# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks. +# +# In cases where models have unique initialization procedures or require testing with specialized output formats, +# it is necessary to develop distinct, dedicated test cases. 
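Similarly, the EfficientLoFTR parity test above feeds a batch of image pairs shaped `(batch_size, 2, 3, height, width)` and compares the last entry of the `feature_maps` output. The sketch below is an illustration, not part of the diff: it reuses the tester from the module added above, and the dtype cast is an assumption.

import mindspore as ms

from mindone.transformers import EfficientLoFTRModel
from tests.transformers_tests.models.efficientloftr.test_modeling_efficientloftr import EfficientLoFTRModelTester

tester = EfficientLoFTRModelTester()
config, pixel_values = tester.prepare_config_and_inputs()  # shape (2, 2, 3, 60, 80): a batch of image pairs

model = EfficientLoFTRModel(config)
model.set_train(False)  # inference mode; weights are randomly initialized

outputs = model(ms.tensor(pixel_values, dtype=ms.float32))
last_feature_map = outputs["feature_maps"][-1]  # the same entry the parity test compares
print(last_feature_map.shape)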
+import inspect + +import numpy as np +import pytest +import torch +from transformers import HGNetV2Config + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy, ids_numpy + +# ms.nn.MaxPool2d does not support bf16 inputs +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3} +MODES = [1] + + +class HGNetV2ModelTester: + def __init__( + self, + batch_size=3, + image_size=32, + num_channels=3, + embeddings_size=10, + hidden_sizes=[64, 128, 256, 512], + stage_in_channels=[16, 64, 128, 256], + stage_mid_channels=[16, 32, 64, 128], + stage_out_channels=[64, 128, 256, 512], + stage_num_blocks=[1, 1, 2, 1], + stage_downsample=[False, True, True, True], + stage_light_block=[False, False, True, True], + stage_kernel_size=[3, 3, 5, 5], + stage_numb_of_layers=[3, 3, 3, 3], + stem_channels=[3, 16, 16], + depths=[1, 1, 2, 1], + is_training=True, + use_labels=True, + hidden_act="relu", + num_labels=3, + scope=None, + out_features=["stage2", "stage3", "stage4"], + out_indices=[2, 3, 4], + ): + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.embeddings_size = embeddings_size + self.hidden_sizes = hidden_sizes + self.stage_in_channels = stage_in_channels + self.stage_mid_channels = stage_mid_channels + self.stage_out_channels = stage_out_channels + self.stage_num_blocks = stage_num_blocks + self.stage_downsample = stage_downsample + self.stage_light_block = stage_light_block + self.stage_kernel_size = stage_kernel_size + self.stage_numb_of_layers = stage_numb_of_layers + self.stem_channels = stem_channels + self.depths = depths + self.is_training = is_training + self.use_labels = use_labels + self.hidden_act = hidden_act + self.num_labels = num_labels + self.scope = scope + self.num_stages = len(hidden_sizes) + self.out_features = out_features + self.out_indices = out_indices + + def prepare_config_and_inputs(self): + pixel_values = floats_numpy([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_numpy([self.batch_size], self.num_labels) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return HGNetV2Config( + num_channels=self.num_channels, + embeddings_size=self.embeddings_size, + hidden_sizes=self.hidden_sizes, + stage_in_channels=self.stage_in_channels, + stage_mid_channels=self.stage_mid_channels, + stage_out_channels=self.stage_out_channels, + stage_num_blocks=self.stage_num_blocks, + stage_downsample=self.stage_downsample, + stage_light_block=self.stage_light_block, + stage_kernel_size=self.stage_kernel_size, + stage_numb_of_layers=self.stage_numb_of_layers, + stem_channels=self.stem_channels, + depths=self.depths, + hidden_act=self.hidden_act, + num_labels=self.num_labels, + out_features=self.out_features, + out_indices=self.out_indices, + ) + + +model_tester = HGNetV2ModelTester() +config, pixel_values, labels = model_tester.prepare_config_and_inputs() + + +TEST_CASES = [ + [ + "HGNetV2Backbone", + "transformers.HGNetV2Backbone", + "mindone.transformers.HGNetV2Backbone", + (config,), + {}, + (pixel_values, None), + {}, + { + "feature_maps": "feature_maps", + }, + ], + [ + "HGNetV2ForImageClassification", + "transformers.HGNetV2ForImageClassification", + "mindone.transformers.HGNetV2ForImageClassification", + (config,), + {}, + 
(pixel_values, labels), + {}, + { + "loss": "loss", + "logits": "logits", + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in TEST_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )
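Finally, a hedged sketch of running the tiny HGNetV2 configuration from `HGNetV2ModelTester` outside the parity harness. It is not part of the test file above: the import uses the module name exactly as added in this diff, the label dtype cast is an assumption, and since the weights are randomly initialized only the output shapes are meaningful.

import mindspore as ms

from mindone.transformers import HGNetV2Backbone, HGNetV2ForImageClassification
# Module name spelled exactly as the file added in this diff.
from tests.transformers_tests.models.hgnet_v2.test_modeing_hgnet_v2 import HGNetV2ModelTester

tester = HGNetV2ModelTester()
config, pixel_values, labels = tester.prepare_config_and_inputs()
pixel_values = ms.tensor(pixel_values, dtype=ms.float32)

backbone = HGNetV2Backbone(config)
backbone.set_train(False)
backbone_outputs = backbone(pixel_values)
# "feature_maps" is the key the parity test reads; one map per requested out_feature (stage2/3/4).
print([tuple(fmap.shape) for fmap in backbone_outputs["feature_maps"]])

classifier = HGNetV2ForImageClassification(config)
classifier.set_train(False)
# Labels are passed positionally, mirroring the (pixel_values, labels) inputs in TEST_CASES.
cls_outputs = classifier(pixel_values, ms.tensor(labels, dtype=ms.int32))
print(cls_outputs["logits"].shape)  # (batch_size, num_labels) = (3, 3)
print(cls_outputs["loss"])          # scalar classification loss on the random labels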