diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py
index 76da0f9b..a1c6f1f4 100644
--- a/fast_llm/engine/base_model/base_model.py
+++ b/fast_llm/engine/base_model/base_model.py
@@ -136,6 +136,10 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
     def loss_defs(self) -> list[LossDef]:
         pass
 
+    @property
+    def metric_defs(self) -> list[LossDef]:
+        return []
+
     def add_preprocessor(self, preprocessor: Preprocessor):
         # TODO: Generalize preprocessors.
         raise NotImplementedError()
diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py
index 1d4b04c1..0bd7aefc 100644
--- a/fast_llm/engine/schedule/runner.py
+++ b/fast_llm/engine/schedule/runner.py
@@ -3,6 +3,7 @@
 import logging
 import time
 import typing
+from typing import Callable
 
 import torch
 import torch.cuda
@@ -94,6 +95,7 @@ def __init__(
         self._tied_parameters = self._multi_stage.tied_parameters
         self._num_stages = len(self._stages)
         self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
+        self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}
 
     def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
         assert not self._is_setup
@@ -264,27 +266,53 @@ def run_step(
         self._record_event(context, EventType.batch_end, None)
         self._handle_events(context)
 
-        if metrics is not None:
-            metrics["loss_scale"] = self._optimizer.grad_scale
-
         if self._multi_stage.config.multi_stage.debug_activation_memory:
             log_pipeline_parallel_main_rank(
                 lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
            )
+        # Metrics registered in `metric_defs` are reduced across ranks; unregistered ones are passed through as-is.
+        if metrics is not None:
+            metrics = self._reduce_metrics(context)
+            metrics["loss_scale"] = self._optimizer.grad_scale
 
-        return self._reduce_losses(context), update_successful, metrics
+        return (
+            self._reduce_losses(context),
+            update_successful,
+            metrics,
+        )
 
     def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(
+            context,
+            lambda name: self._loss_defs[name].count,
+            "losses",
+        )
+
+    def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(
+            context, lambda name: self._metric_defs[name].count, "metrics", lambda x: x in self._metric_defs
+        )
+
+    def _reduce_metric_or_loss(
+        self,
+        context: BatchContext,
+        check_count: Callable[[str], int],
+        reduce_attr: str = "losses",
+        check_reduce: Callable[[str], bool] = lambda _: True,
+    ) -> dict[str, float | int]:
         reduced_losses = {}
         num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
-        for name, losses in context.losses.items():
+        for name, losses in getattr(context, reduce_attr).items():
+            if not check_reduce(name):
+                reduced_losses[name] = losses
+                continue
             if losses or self._distributed.pipeline_group:
                 if losses:
-                    reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
+                    reduced_loss = torch.stack(losses).sum() / num_inputs / check_count(name)
                     if self._distributed.data_group:
                         all_reduce(reduced_loss, group=self._distributed.data_group)
                 else:
-                    reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
+                    reduced_loss = torch.zeros([1], dtype=(self._loss_defs | self._metric_defs)[name].dtype, device=self._distributed.device)
                 if self._distributed.pipeline_group:
                     all_reduce(reduced_loss, group=self._distributed.pipeline_group)
             else:
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 4806e37e..fcaac2d9 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -93,6 +93,11 @@ class TransformerLossNames:
     router_z_loss = "router_z_loss"
 
 
+class TransformerRoutingMetrics:
+    normalized_average_entropy = "normalized_average_entropy"
+    mutual_info = "mutual_info"
+
+
 class RotaryEmbeddingType(str, enum.Enum):
     none = "none"
     default = "default"
@@ -656,6 +661,11 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
         " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.",
         hint=FieldHint.expert,
     )
+    calculate_moe_metrics: bool = Field(
+        default=True,
+        desc="If 'True', calculate the MoE routing metrics (normalized entropy and mutual information) at each logging step.",
+        hint=FieldHint.logging,
+    )
 
     def _validate(self) -> None:
         if self.init_method_std is None:
diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py
index 85c6686f..fd4cd825 100644
--- a/fast_llm/layers/transformer/mixture_of_experts.py
+++ b/fast_llm/layers/transformer/mixture_of_experts.py
@@ -17,6 +17,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.mlp import MLPBase
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
@@ -26,6 +27,41 @@
 logger = logging.getLogger(__name__)
 
 
+@torch.compile
+def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the routing entropy for each token, then averages over all tokens.
+    A low value means most mass goes to a single expert for every token, which can indicate collapse or specialization.
+    """
+    n_experts = probs.size(-1)
+    entropy_values = calculate_entropy(probs)
+    average_entropy = entropy_values.mean()  # Average over batch and tokens
+    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))
+
+
+@torch.compile
+def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
+    probs = torch.clamp(probs, min=1e-9)  # Avoid log(0)
+    return -torch.sum(probs * torch.log(probs), dim=-1)
+
+
+@torch.compile
+def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the difference between the entropy of the average routing distribution and
+    the average routing entropy, averaged across all tokens of all examples in the batch.
+    A low value means that the routing is not informative.
+ """ + n_experts = probs.size(-1) + average_routing = torch.mean(probs.view(-1, n_experts), dim=0) # Average over tokens + entropy_avg_routing = calculate_entropy(average_routing) / torch.log( + torch.tensor(n_experts, dtype=probs.dtype) + ) # H[E[X]] + entropy_routing = calculate_normalized_average_entropy(probs) # E[H[X]] + + return entropy_avg_routing - entropy_routing + + class MixtureOfExpertMLP(MLPBase): """ MoeLayer following implementation from @@ -48,6 +84,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory + self._calculate_moe_metrics = config.calculate_moe_metrics self._num_experts = config.num_experts self._experts_per_token = config.num_experts_per_token @@ -103,7 +140,9 @@ def forward( # Routing if self._routing_type == RoutingType.topk: - scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) + scores, top_experts = self._topk_routing( + logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics + ) if self._num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) elif self._routing_type == RoutingType.sinkhorn: @@ -169,11 +208,26 @@ def _topk_routing( logits: torch.Tensor, grad_scale: float | None = None, losses: dict | None = None, + metrics: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + + # Store these metrics + if metrics is not None and self._calculate_moe_metrics: + # Calculate and log entropy and mutual information + entropy = calculate_normalized_average_entropy(probs) + mutual_info = calculate_mutual_information(probs) + if TransformerRoutingMetrics.normalized_average_entropy not in metrics: + metrics[TransformerRoutingMetrics.normalized_average_entropy] = [] + if TransformerRoutingMetrics.mutual_info not in metrics: + metrics[TransformerRoutingMetrics.mutual_info] = [] + + metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach()) + metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach()) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. 
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index 7e5c5d33..84288e3e 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -20,6 +20,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.preprocessing import (
     BackupAttentionPreprocessor,
@@ -348,6 +349,7 @@ def loss_defs(self) -> list[LossDef]:
                     count=self._config.transformer.num_layers,
                 )
             )
+
         if self._config.logit_z_loss:
             LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
 
@@ -361,6 +363,29 @@ def loss_defs(self) -> list[LossDef]:
         )
         return loss_defs
 
+    @property
+    def metric_defs(self) -> list[LossDef]:
+        metric_defs = []
+        if (
+            self._config.transformer.num_experts > 1
+            and self._config.transformer.expert_routing_type == RoutingType.topk
+        ):
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.normalized_average_entropy,
+                    formatted_name="Normalized Entropy",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.mutual_info,
+                    formatted_name="Mutual Information",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+        return metric_defs
+
     def add_preprocessor(self, preprocessor: Preprocessor):
         assert not self._is_setup
         self._preprocessors.append(preprocessor)
diff --git a/tests/test_moe_metrics.py b/tests/test_moe_metrics.py
new file mode 100644
index 00000000..a9571775
--- /dev/null
+++ b/tests/test_moe_metrics.py
@@ -0,0 +1,160 @@
+import torch
+
+from fast_llm.layers.transformer.mixture_of_experts import (
+    calculate_mutual_information,
+    calculate_normalized_average_entropy,
+)
+
+
+def test_diversity_entropy():
+    """
+    Collapsed routing should have low normalized entropy and low mutual information.
+    """
+
+    collapsed_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+            ],
+            # Batch 2
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+            ],
+        ]
+    )
+    norm_entropy = calculate_normalized_average_entropy(collapsed_probs)
+    mutual_info = calculate_mutual_information(collapsed_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected ~0.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {mutual_info}"
+
+    # Diverse routing without collapse:
+    # should give low per-token entropy and high mutual information.
+    diverse_probs = torch.tensor(
+        [
+            # Batch 1
+            [
+                [0.99, 0.01, 0.0, 0.0],
+                [0.01, 0.99, 0.0, 0.0],
+                [0.01, 0.01, 0.99, 0.0],
+            ],
+            # Batch 2
+            [
+                [0.01, 0.01, 0.99, 0.0],
+                [0.99, 0.01, 0.0, 0.0],
+                [0.01, 0.01, 0.01, 0.99],
+            ],
+        ]
+    )
+    norm_entropy = calculate_normalized_average_entropy(diverse_probs)
+    mutual_info = calculate_mutual_information(diverse_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-1), f"Expected ~0.0, got {norm_entropy}"
+    assert torch.isclose(mutual_info, torch.tensor(0.9), atol=1e-1), f"Expected ~0.9, got {mutual_info}"
+
+
+def test_calculate_normalized_average_entropy():
+    # AI generated test case
+    # Create a batch of routing probabilities
+    batch_size = 2
+    seq_len = 3
+    n_experts = 4
+
+    # Test 1: Uniform distribution (should give normalized entropy of 1.0)
+    uniform_probs = torch.ones(batch_size, seq_len, n_experts) / n_experts
+    norm_entropy = calculate_normalized_average_entropy(uniform_probs)
+    assert torch.isclose(norm_entropy, torch.tensor(1.0), atol=1e-5), f"Expected 1.0, got {norm_entropy}"
{norm_entropy}" + + # Test 2: One-hot distribution (should give normalized entropy of 0.0) + one_hot = torch.zeros(batch_size, seq_len, n_experts) + for b in range(batch_size): + for s in range(seq_len): + one_hot[b, s, b % n_experts] = 1.0 + norm_entropy = calculate_normalized_average_entropy(one_hot) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected 0.0, got {norm_entropy}" + + # Test 3: Mixed distribution + mixed_probs = torch.tensor( + [ + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.25, 0.25, 0.25, 0.25], # Token 3: uniform + ], + # Batch 2 + [ + [0.4, 0.4, 0.1, 0.1], # Token 1: split between experts 0 and 1 + [0.1, 0.1, 0.4, 0.4], # Token 2: split between experts 2 and 3 + [0.1, 0.1, 0.1, 0.7], # Token 3: mostly expert 3 + ], + ] + ) + norm_entropy = calculate_normalized_average_entropy(mixed_probs) + # The expected value is between 0 and 1 + assert 0.0 < norm_entropy < 1.0, f"Expected value between 0 and 1, got {norm_entropy}" + + +def test_calculate_mutual_information(): + # AI generated test cases + # Create a batch of routing probabilities + batch_size = 2 + seq_len = 3 + n_experts = 4 + + # Test 1: All tokens route to the same expert (low mutual information) + same_expert = torch.zeros(batch_size, seq_len, n_experts) + same_expert[:, :, 0] = 1.0 # All tokens route to expert 0 + mutual_info = calculate_mutual_information(same_expert) + assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" + + # Test 2: Each token routes to a different expert (high mutual information) + different_experts = torch.zeros(batch_size, seq_len, n_experts) + for b in range(batch_size): + for s in range(seq_len): + different_experts[b, s, s % n_experts] = 1.0 + mutual_info = calculate_mutual_information(different_experts) + # The value should be positive and closer to 1 + assert mutual_info > 0.0, f"Expected positive value, got {mutual_info}" + + # Test 3: Mixed routing pattern + mixed_probs = torch.tensor( + [ + # Batch 1 + [ + [0.7, 0.1, 0.1, 0.1], # Token 1: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 2: mostly expert 1 + [0.1, 0.1, 0.7, 0.1], # Token 3: mostly expert 2 + ], + # Batch 2 + [ + [0.1, 0.1, 0.1, 0.7], # Token 1: mostly expert 3 + [0.7, 0.1, 0.1, 0.1], # Token 2: mostly expert 0 + [0.1, 0.7, 0.1, 0.1], # Token 3: mostly expert 1 + ], + ] + ) + mutual_info = calculate_mutual_information(mixed_probs) + # The expected value is between 0 and 1 + assert 0.0 < mutual_info < 1.0, f"Expected value between 0 and 1, got {mutual_info}" + + +def test_small_seq_length_batch_size_probabilities(): + # AI generated test cases + # Test with very small batch and sequence length + tiny_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25]]]) # batch=1, seq_len=1, n_experts=4 + norm_entropy = calculate_normalized_average_entropy(tiny_probs) + mutual_info = calculate_mutual_information(tiny_probs) + assert torch.isclose(norm_entropy, torch.tensor(1.0)), f"Expected 1.0, got {norm_entropy}" + assert torch.isclose(mutual_info, torch.tensor(0.0)), f"Expected 0.0, got {mutual_info}" + + # Test with very small probabilities + small_probs = torch.ones(2, 3, 4) * 1e-8 + small_probs[:, :, 0] = 1.0 - 3e-8 # Make sure they sum to 1 + norm_entropy = calculate_normalized_average_entropy(small_probs) + mutual_info = calculate_mutual_information(small_probs) + assert torch.isclose(norm_entropy, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {norm_entropy}" + assert 
+    assert torch.isclose(mutual_info, torch.tensor(0.0), atol=1e-5), f"Expected ~0.0, got {mutual_info}"
diff --git a/tests/test_runner.py b/tests/test_runner.py
new file mode 100644
index 00000000..1f22141c
--- /dev/null
+++ b/tests/test_runner.py
@@ -0,0 +1,93 @@
+from unittest import mock
+
+import pytest
+import torch
+
+from fast_llm.engine.base_model.base_model import LossDef
+from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
+from fast_llm.engine.schedule.config import ScheduleConfig
+from fast_llm.engine.schedule.runner import BatchContext, ScheduleRunner
+from fast_llm.engine.schedule.schedule import Schedule
+from fast_llm.layers.transformer.config import TransformerRoutingMetrics
+
+
+@pytest.fixture
+def setup_runner():
+    """
+    Fixture to set up the test environment.
+    TODO: Leave it here for now; it may be moved to common.py later.
+    """
+    # Mock objects needed for testing
+    distributed_config = DistributedConfig()
+
+    # Mock MultiStageModel with loss_defs and metric_defs
+    multi_stage = mock.MagicMock(spec=MultiStageModel)
+    multi_stage.base_model.loss_defs = [LossDef(name="test_loss", formatted_name="Test Loss", count=1)]
+    multi_stage.base_model.metric_defs = [
+        LossDef(
+            name=TransformerRoutingMetrics.normalized_average_entropy, formatted_name="Normalized Entropy", count=1
+        ),
+        LossDef(name=TransformerRoutingMetrics.mutual_info, formatted_name="Mutual Information", count=1),
+    ]
+
+    # Create a schedule runner
+    schedule_config = ScheduleConfig()
+    runner = ScheduleRunner(config=schedule_config, multi_stage=multi_stage, distributed_config=distributed_config)
+
+    # Mock distributed object
+    distributed = mock.MagicMock(spec=Distributed)
+    distributed.config = distributed_config
+    distributed.device = torch.device("cpu")
+    distributed.data_group = None
+    distributed.pipeline_group = None
+
+    # Setup the runner
+    runner._distributed = distributed
+    runner.is_initialized = True
+
+    # Create a mock schedule
+    schedule = mock.MagicMock(spec=Schedule)
+    schedule.phase = PhaseType.training
+    schedule.batch_config.num_inputs = 3
+    schedule._schedule_config = schedule_config
+
+    # Create a batch context with metrics and losses
+    context = BatchContext(
+        iteration=1,
+        schedule=schedule,
+    )
+
+    return runner, context, schedule
+
+
+def test_reduce_losses(setup_runner):
+    """Test that _reduce_losses and _reduce_metrics correctly reduce losses and metrics."""
+    runner, context, _ = setup_runner
+
+    # Add test metrics
+    context.metrics = {
+        # Metrics that should be reduced (registered in the base model's metric_defs)
+        TransformerRoutingMetrics.normalized_average_entropy: [
+            torch.tensor(0.5),
+            torch.tensor(0.6),
+            torch.tensor(0.7),
+        ],
+        TransformerRoutingMetrics.mutual_info: [torch.tensor(0.2), torch.tensor(0.3), torch.tensor(0.4)],
+        # Metric that should not be reduced, as it is not registered in the base model's metric_defs
+        "non_reduced_metric": [torch.tensor(1.0), torch.tensor(1.0), torch.tensor(1.0)],
+    }
+
+    # Add test losses
+    context.losses = {"test_loss": [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]}
+
+    reduced_losses = runner._reduce_losses(context)
+    reduced_metrics = runner._reduce_metrics(context)
+
+    assert "test_loss" in reduced_losses
+    assert pytest.approx(reduced_losses["test_loss"], 0.001) == 2.0
+    assert "non_reduced_metric" in reduced_metrics
+    assert pytest.approx(reduced_metrics["normalized_average_entropy"], 0.01) == 0.6
pytest.approx(reduced_metrics["mutual_info"], 0.01) == 0.3 + assert isinstance(reduced_metrics["non_reduced_metric"], list)