Skip to content

Commit 3f0f132

Browse files
authored
[train] moe support aux_loss (#5187)
1 parent 11a345d commit 3f0f132

File tree

4 files changed

+19
-1
lines changed

4 files changed

+19
-1
lines changed

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
- 🔥report_to: 默认值为`tensorboard`。你也可以指定`--report_to tensorboard wandb swanlab`或`--report_to all`。
163163
- logging_first_step: 是否记录第一个step的日志,默认为True。
164164
- logging_steps: 日志打印间隔,默认为5。
165+
- router_aux_loss_coef: 用于moe模型训练时,设置 aux_loss 的权重。默认为None,使用config中值。若设置为0,则不计算 aux_loss。
165166
- logging_dir: tensorboard日志路径。默认为None,即设置为`f'{self.output_dir}/runs'`
166167
- predict_with_generate: 验证时使用生成式的方式,默认为False。
167168
- metric_for_best_model: 默认为None,即当`predict_with_generate`设置为False时,设置为'loss',否则设置为'rouge-l'(在PPO训练时,不进行默认值设置;GRPO训练设置为'reward')。

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
165165
- 🔥report_to: Default value is `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
166166
- logging_first_step: Whether to log the first step, defaults to True.
167167
- logging_steps: Interval for logging, defaults to 5.
168+
- router_aux_loss_coef: Weight for aux_loss when training MoE models. Defaults to None, meaning the value from the config is used. If set to 0, aux_loss is not computed.
168169
- logging_dir: The path for TensorBoard logs. Defaults to None, which means it is set to `f'{self.output_dir}/runs'`.
169170
- predict_with_generate: Whether to use generative method during validation, default is False.
170171
- metric_for_best_model: Default is None, which means that when predict_with_generate is set to False, it is set to 'loss'; otherwise, it is set to 'rouge-l' (during PPO training, the default value is not set; in GRPO training, it is set to 'reward').

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class TrainArgumentsMixin:
3030
gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
3131
logging_first_step: bool = True
3232
logging_steps: int = 5
33+
router_aux_loss_coef: Optional[float] = None
3334

3435
weight_decay: float = 0.1
3536
adam_beta2: float = 0.95

swift/trainers/trainers.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
# Part of the implementation is borrowed from huggingface/transformers.
3+
import inspect
34
import os
45
from contextlib import contextmanager, nullcontext
56
from functools import wraps
@@ -15,6 +16,7 @@
1516
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
1617
from transformers.utils import is_peft_available
1718

19+
from swift.plugin import MeanMetric
1820
from swift.utils import JsonlWriter, Serializer, gc_collect, get_logger, unwrap_model_for_generation
1921
from .arguments import Seq2SeqTrainingArguments, TrainingArguments
2022
from .mixin import DataLoaderMixin, SwiftMixin
@@ -335,6 +337,16 @@ def _prepare_inputs(self, inputs):
335337
if self.args.tuner_backend == 'unsloth':
336338
inputs['logits_to_keep'] = int(logits_to_keep.sum())
337339

340+
if self.model.model_info.is_moe_model:
341+
base_model = self.template.get_base_model(self.model)
342+
router_aux_loss_coef = self.args.router_aux_loss_coef
343+
if router_aux_loss_coef is None:
344+
router_aux_loss_coef = getattr(base_model.config, 'router_aux_loss_coef', None)
345+
if router_aux_loss_coef is not None:
346+
base_model.config.router_aux_loss_coef = router_aux_loss_coef
347+
if router_aux_loss_coef > 0 and 'output_router_logits' in inspect.signature(
348+
base_model.forward).parameters:
349+
inputs['output_router_logits'] = True
338350
inputs['compute_loss_func'] = compute_loss_func
339351
inputs['loss_kwargs'] = loss_kwargs
340352
return inputs
@@ -346,8 +358,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
346358

347359
if (self.label_smoother is not None or compute_loss_func is not None) and 'labels' in inputs:
348360
labels = inputs.pop('labels')
349-
350361
outputs = model(**inputs)
362+
if getattr(outputs, 'aux_loss', None) is not None:
363+
if 'aux_loss' not in self._custom_metrics:
364+
self._custom_metrics['aux_loss'] = MeanMetric(nan_value=None)
365+
self._custom_metrics['aux_loss'].update(outputs.aux_loss)
351366
# Save past state if it exists
352367
# TODO: this needs to be fixed and made cleaner later.
353368
if self.args.past_index >= 0:

0 commit comments

Comments
 (0)