@@ -99,7 +99,7 @@ def ms_data_collator(features, batch_info):
 model = LlamaForSequenceClassification.from_pretrained(
     args.model_path,
     num_labels=5,
-    use_flash_attention_2=True,
+    attn_implementation="flash_attention_2",
     mindspore_dtype=ms.bfloat16 if args.bf16 else (ms.float16 if args.fp16 else None),
 )
 model.gradient_checkpointing_enable()
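Usage note (a sketch, not part of this diff): use_flash_attention_2 is the older flag, and attn_implementation="flash_attention_2" is its replacement in the transformers-style from_pretrained API. A self-contained call could look like the following; the checkpoint name and the top-level import path are assumptions, not taken from this diff.

import mindspore as ms
from mindone.transformers import LlamaForSequenceClassification  # import path assumed

model = LlamaForSequenceClassification.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    num_labels=5,
    attn_implementation="flash_attention_2",  # replaces the removed use_flash_attention_2=True
    mindspore_dtype=ms.bfloat16,
)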
49 changes: 25 additions & 24 deletions mindone/transformers/mindspore_adapter/train_onestep_wrapper.py
@@ -1,7 +1,8 @@
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Literal, Optional

 import mindspore as ms
 from mindspore import ParallelMode, Tensor, context, nn, ops
+from mindspore.amp import LossScaler
 from mindspore.ops import composite as C

 if TYPE_CHECKING:
@@ -89,15 +90,15 @@ def create_grad_reducer(trainable_parameters):


 class LossWithScaleSense(nn.Cell):
-    def __init__(self, network: nn.Cell) -> None:
+    def __init__(self, network: nn.Cell, scaler: LossScaler) -> None:
         super().__init__(auto_prefix=False)
         self.network = network
+        self.scaler = scaler

-    def construct(self, *args, scale_sense: float = 1.0, **kwargs) -> Tensor:
+    def construct(self, *args, **kwargs) -> Tensor:
         loss = self.network(*args, **kwargs)
-        if isinstance(scale_sense, ms.Tensor):
-            scale_sense = scale_sense.to(loss.dtype)
-        loss = loss * scale_sense
+        scale_sense = self.scaler.scale_value
+        loss = loss * scale_sense.to(loss.dtype)
         return loss

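A minimal sketch of how the refactored LossWithScaleSense could be exercised on its own, assuming a StaticLossScaler from mindspore.amp; ToyLossNet and the input values are illustrative only.

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn
from mindspore.amp import StaticLossScaler

from mindone.transformers.mindspore_adapter.train_onestep_wrapper import LossWithScaleSense


class ToyLossNet(nn.Cell):
    """Illustrative Cell whose construct() returns a scalar loss."""

    def construct(self, x):
        return (x * x).mean()


scaler = StaticLossScaler(scale_value=1024)
loss_cell = LossWithScaleSense(ToyLossNet(), scaler)

# No scale_sense keyword any more: the scale factor is read from the scaler held by the cell.
scaled_loss = loss_cell(Tensor(np.array([1.0, 2.0, 3.0]), ms.float32))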
@@ -118,7 +119,7 @@ def __init__(
         optimizer: nn.Optimizer,
         ema: nn.Cell = None,
         drop_overflow_step: bool = True,
-        scaler: str = "default",
+        scaler: Literal["default", "static", "auto", "dynamic", "none"] = "default",
         scaler_config: Dict = {},
         gradient_accumulation_steps: int = 1,
         clip_grad: str = "none",
@@ -139,8 +140,23 @@ def __init__(
             reducer = create_grad_reducer(network.trainable_params())
             is_zero = False

+        # scaler and reducer
+        assert "ms_loss_scaler" not in scaler_config
+        if scaler.lower() in ("default", "static"):
+            _scaler_config = {"scale_value": 1024}
+            _scaler_config.update(scaler_config)
+            scaler = create_loss_scaler("static", **_scaler_config)
+        elif scaler.lower() in ("auto", "dynamic"):
+            scaler = create_loss_scaler("dynamic", **scaler_config)
+        elif scaler.lower() == "none":
+            scaler = create_loss_scaler("none", **scaler_config)
+        else:
+            raise NotImplementedError
Review comment on lines +145 to +154 by @hadipash (Collaborator), Nov 13, 2025:
Please specify scaler possible values with Literal in the __init__() definition.
+
+        self.scaler = scaler
+
         # wrap network with scale sense
-        network = LossWithScaleSense(network)
+        network = LossWithScaleSense(network, self.scaler)

         # grad accumulation
         assert gradient_accumulation_steps >= 1
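A hedged construction example for the new scaler argument; the wrapper class name (TrainOneStepWrapper) is not shown in this excerpt and is assumed from the module name, and the toy network and hyper-parameters are illustrative only.

import mindspore as ms
from mindspore import nn

from mindone.transformers.mindspore_adapter.train_onestep_wrapper import TrainOneStepWrapper  # class name assumed


class ToyLossNet(nn.Cell):
    """Stand-in for a network whose construct() returns a scalar loss."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(8, 1)

    def construct(self, x):
        return self.dense(x).mean()


network = ToyLossNet()
optimizer = nn.AdamWeightDecay(network.trainable_params(), learning_rate=1e-4)

train_step = TrainOneStepWrapper(
    network,
    optimizer,
    scaler="static",                      # Literal: "default" | "static" | "auto" | "dynamic" | "none"
    scaler_config={"scale_value": 4096},  # overrides the default static scale_value of 1024
    gradient_accumulation_steps=1,
)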
@@ -177,21 +193,6 @@ def construct(self, *args, **kwargs):

         self.optimizer = optimizer
         self.ema = ema
-
-        # scaler and reducer
-        assert "ms_loss_scaler" not in scaler_config
-        if scaler.lower() in ("default", "static"):
-            _scaler_config = {"scale_value": 1024}
-            _scaler_config.update(scaler_config)
-            scaler = create_loss_scaler("static", **_scaler_config)
-        elif scaler.lower() in ("auto", "dynamic"):
-            scaler = create_loss_scaler("dynamic", **scaler_config)
-        elif scaler.lower() == "none":
-            scaler = create_loss_scaler("none", **scaler_config)
-        else:
-            raise NotImplementedError
-
-        self.scaler = scaler
         self.reducer = reducer
         self.is_zero = is_zero
         self.all_finite = ms.amp.all_finite if not _is_cpu() else return_true
@@ -274,7 +275,7 @@ def do_optim(self, loss, grads):
         return loss

     def construct(self, *inputs):
-        loss, grads = self.value_and_grad(*inputs, self.scaler.scale_value)
+        loss, grads = self.value_and_grad(*inputs)
         loss = self.scaler.unscale(loss)
         loss = loss * self.accum_steps

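For readers tracing the change in construct(): the scale factor is now applied inside the wrapped loss cell, so value_and_grad no longer receives scale_value as an extra input, and the scaling is undone with scaler.unscale() afterwards. A self-contained sketch of that pattern (toy network, PyNative mode assumed, not the repository code):

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn
from mindspore.amp import StaticLossScaler

net = nn.Dense(4, 1)
scaler = StaticLossScaler(scale_value=1024)


def scaled_loss_fn(x):
    # Loss is scaled inside the forward function, mirroring LossWithScaleSense.
    loss = net(x).mean()
    return loss * scaler.scale_value.to(loss.dtype)


# Only the model inputs are passed; the scale is no longer an argument.
grad_fn = ms.value_and_grad(scaled_loss_fn, grad_position=None, weights=net.trainable_params())

x = Tensor(np.ones((2, 4)), ms.float32)
scaled_loss, grads = grad_fn(x)
loss = scaler.unscale(scaled_loss)  # report the true (unscaled) loss
grads = scaler.unscale(grads)       # unscale gradients before the optimizer step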
2 changes: 1 addition & 1 deletion mindone/transformers/models/llama/modeling_llama.py
@@ -384,7 +384,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = False  # SDPA, not support yet
     _supports_flex_attn = False  # FlexAttention, not support yet
-    _supports_cache_class = True  # set it True if use DynamicCache
+    _supports_cache_class = False  # set it True if use DynamicCache
     _supports_quantized_cache = False
     _supports_static_cache = False  # StaticCache, not used
     _supports_attention_backend = True