60 changes: 54 additions & 6 deletions mindone/transformers/masking_utils.py
@@ -168,6 +168,51 @@ def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable
return mask_function


# A patch used as a substitute for `mindspore.vmap`
def _vmap_patch(
    mask_function: Callable,
    batch_size: ms.Tensor,
    head_dim: ms.Tensor,
    cache_position: ms.Tensor,
    kv_range: ms.Tensor,
    bh_indices: bool = True,
) -> ms.Tensor:
    """
    Substitute for vmapping our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally,
    the batch and head indices are also expanded if `bh_indices=True`.
    Explicit loops replace `mindspore.vmap` here, while keeping a single set of primitive mask functions
    between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).

    Args:
        mask_function (`Callable`):
            The mask_function to apply.
        batch_size (`ms.Tensor`):
            The range of batch indices, of shape (batch_size,).
        head_dim (`ms.Tensor`):
            The range of head indices, of shape (num_heads,).
        cache_position (`ms.Tensor`):
            The query positions, of shape (query_length,).
        kv_range (`ms.Tensor`):
            The key/value positions, of shape (kv_length,).
        bh_indices (`bool`, optional):
            Whether to expand over the batch and head indices as well, or only q and kv indices.

    Returns:
        causal_mask (`ms.Tensor` of dtype `ms.bool_`):
            The boolean mask, of shape (batch_size, num_heads, query_length, kv_length) if `bh_indices=True`,
            otherwise (query_length, kv_length).
    """
    bs = batch_size.shape[0]
    h = head_dim.shape[0]
    q_len = cache_position.shape[0]
    kv_len = kv_range.shape[0]
    if bh_indices:
        causal_mask = mint.zeros((bs, h, q_len, kv_len), dtype=ms.bool_)
        for i in range(bs):
            for j in range(kv_len):
                causal_mask[i, :, :, j] = mask_function(batch_size[i], head_dim, cache_position, kv_range[j].item())
    else:
        causal_mask = mint.zeros((q_len, kv_len), dtype=ms.bool_)
        for i in range(kv_len):
            causal_mask[:, i] = mask_function(batch_size, head_dim, cache_position, kv_range[i].item())

    return causal_mask

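For orientation, a minimal usage sketch of the patch (the toy predicate, sizes, and import below are illustrative assumptions, not part of the diff):

from mindspore import mint

from mindone.transformers.masking_utils import _vmap_patch  # assuming the patched module is importable

# Toy causal predicate: a query position may attend to itself and to earlier kv positions.
def toy_causal(batch_idx, head_idx, q_idx, kv_idx):
    return q_idx >= kv_idx  # q_idx is the tensor of query positions, kv_idx a Python int

batch_arange = mint.arange(2)    # two sequences in the batch
head_arange = mint.arange(1)     # the mask is broadcast over heads
cache_position = mint.arange(4)  # four query positions
kv_arange = mint.arange(4)       # four key/value positions

mask = _vmap_patch(toy_causal, batch_arange, head_arange, cache_position, kv_arange)
print(mask.shape)  # (2, 1, 4, 4); each [i, 0] slice is a lower-triangular boolean matrix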

def prepare_padding_mask(
attention_mask: Optional[ms.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[ms.Tensor]:
@@ -304,20 +349,21 @@ def sdpa_mask_recent_torch(

# Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = mint.arange(kv_length, device=cache_position.device)
kv_arange = mint.arange(kv_length)
kv_arange += kv_offset

# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

batch_arange = mint.arange(batch_size, device=cache_position.device)
head_arange = mint.arange(1, device=cache_position.device)
batch_arange = mint.arange(batch_size)
head_arange = mint.arange(1)
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it)
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
# with TransformGetItemToIndex():
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
# TODO: there is a compile problem when using `mindspore.vmap`, so a patch is used as a substitute for this operator
causal_mask = _vmap_patch(mask_function, batch_arange, head_arange, cache_position, kv_arange)

return causal_mask

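The kv_arange construction in the hunk above avoids passing data-dependent start/end arguments to arange; a minimal sketch of the equivalence, assuming standard mint.arange semantics:

from mindspore import mint

kv_length, kv_offset = 4, 2
direct = mint.arange(kv_offset, kv_offset + kv_length)  # data-dependent arguments to arange
shifted = mint.arange(kv_length)                        # static-length range...
shifted += kv_offset                                    # ...shifted afterwards
# Both evaluate to [2, 3, 4, 5]; only the second keeps the arange arguments independent of the offset.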
@@ -383,7 +429,8 @@ def sdpa_mask_older_torch(
# as vmap cannot handle slicing a tensor from a scalar tensor (it internally calls `.item()` which vmap does not allow)
# However, in more recent versions of PyTorch, a trick was introduced to handle it - which is the reason we have
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
# TODO: there is a compile problem when using `mindspore.vmap`, so a patch is used as a substitute for this operator
causal_mask = _vmap_patch(mask_function, None, None, cache_position, kv_arange, bh_indices=False)
causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, -1, -1, -1))
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
@@ -436,7 +483,8 @@ def _ignore_causal_mask_sdpa(

# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
# (especially mask_function indexing a tensor, such as the padding mask function)
sdpa_mask = sdpa_mask_older_torch  # TODO: use sdpa_mask_recent_torch or sdpa_mask_older_torch?
# TODO: unlike transformers, we do not select sdpa_mask based on the torch version; we use `sdpa_mask_recent_torch` directly
sdpa_mask = sdpa_mask_recent_torch


def eager_mask(
25 changes: 23 additions & 2 deletions mindone/transformers/modeling_utils.py
@@ -2454,6 +2454,18 @@ def from_pretrained(
if "attn_implementation" in kwargs:
config._attn_implementation = kwargs.pop("attn_implementation")

transformers_explicit_filename = getattr(config, "transformers_weights", None)

if transformers_explicit_filename is not None:
if not transformers_explicit_filename.endswith(
".safetensors"
) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
raise ValueError(
"The transformers file in the config seems to be incorrect: it is neither a safetensors file "
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
f"{transformers_explicit_filename}"
)

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
@@ -2469,7 +2481,13 @@
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if is_local:
if from_tf and os.path.isfile(
if transformers_explicit_filename is not None:
# If the filename is explicitly defined, load this by default.
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, transformers_explicit_filename
)
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
Review comment (Contributor, severity: medium):
This logic to determine is_sharded is duplicated on line 2580. To improve maintainability and avoid repetition, consider determining this value once, right after transformers_explicit_filename is retrieved from the config.
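A minimal sketch of the refactor the reviewer suggests (the explicit_is_sharded name is a hypothetical helper, not in the diff): compute the flag once where the filename is read from the config, then reuse it in both branches.

transformers_explicit_filename = getattr(config, "transformers_weights", None)
# Hypothetical helper flag, computed once instead of repeating the suffix check in each branch.
explicit_is_sharded = (
    transformers_explicit_filename is not None
    and transformers_explicit_filename.endswith(".safetensors.index.json")
)

Both places that currently call endswith(".safetensors.index.json") could then simply assign is_sharded = explicit_is_sharded.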
elif from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
# Load from a TF 1.0 checkpoint in priority if from_tf
@@ -2558,7 +2576,10 @@ def from_pretrained(
resolved_archive_file = download_url(pretrained_model_name_or_path)
else:
# set correct filename
if from_tf:
if transformers_explicit_filename is not None:
filename = transformers_explicit_filename
is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
elif from_tf:
filename = TF2_WEIGHTS_NAME
elif from_flax:
filename = FLAX_WEIGHTS_NAME
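A small sketch of what this branch now enables when the config pins an explicit weights file (the config class and filename below are made up for illustration):

class DummyConfig:
    transformers_weights = "custom_weights.safetensors.index.json"  # made-up explicit filename

filename = getattr(DummyConfig, "transformers_weights", None)
if filename is not None:
    is_sharded = filename.endswith(".safetensors.index.json")
    print(filename, is_sharded)  # custom_weights.safetensors.index.json True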