[feat] Track entropy and MI of routing distribution for topk MoE #188

Open · wants to merge 18 commits into base: main
fast_llm/engine/base_model/base_model.py (4 additions, 0 deletions)

@@ -136,6 +136,10 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]:
     def loss_defs(self) -> list[LossDef]:
         pass
 
+    @property
Collaborator: This loss/metric split is way more complicated than needed. How about having a single entry and using an `is_metric` flag in `LossDef` (or a derived class) to distinguish? Then no change is needed other than extracting metrics from the context before returning from `run_step`.

Contributor (author): This would be nice! Maybe better to leave it for a separate PR? It would make this one larger, since it would also require changing the interfaces of the models' forward functions (which expect losses and metrics), as well as making sure that metrics are only calculated when `return_metrics` is True.

Collaborator: There isn't much change needed, actually: just add `kwargs["return_metrics"]`. I would prefer doing this here so we don't grow `ScheduleRunner` too much.
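For illustration, the single-list design being discussed might look roughly like this (a hypothetical sketch, not code from this PR; any `LossDef` field beyond `name`, `formatted_name`, `count`, and `dtype` is an assumption):

    import dataclasses

    import torch


    @dataclasses.dataclass
    class LossDef:
        name: str
        formatted_name: str
        count: int
        dtype: torch.dtype = torch.float32
        is_metric: bool = False  # Reduced like a loss, but reported separately.


    # run_step would then reduce everything once and split at the end
    # (sketch only; `self` would be the ScheduleRunner):
    # reduced = self._reduce_losses(context)
    # losses = {k: v for k, v in reduced.items() if not self._loss_defs[k].is_metric}
    # metrics = {k: v for k, v in reduced.items() if self._loss_defs[k].is_metric}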

+    def metric_defs(self) -> list[LossDef]:
+        return []
+
     def add_preprocessor(self, preprocessor: Preprocessor):
         # TODO: Generalize preprocessors.
         raise NotImplementedError()
fast_llm/engine/schedule/runner.py (35 additions, 7 deletions)

@@ -3,6 +3,7 @@
 import logging
 import time
 import typing
+from typing import Callable
 
 import torch
 import torch.cuda
…

@@ -94,6 +95,7 @@ def __init__(
         self._tied_parameters = self._multi_stage.tied_parameters
         self._num_stages = len(self._stages)
         self._loss_defs = {loss_def.name: loss_def for loss_def in self._multi_stage.base_model.loss_defs}
+        self._metric_defs = {metric_def.name: metric_def for metric_def in self._multi_stage.base_model.metric_defs}
 
     def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
         assert not self._is_setup
…

@@ -264,27 +266,53 @@ def run_step(
         self._record_event(context, EventType.batch_end, None)
         self._handle_events(context)
 
-        if metrics is not None:
-            metrics["loss_scale"] = self._optimizer.grad_scale
-
         if self._multi_stage.config.multi_stage.debug_activation_memory:
             log_pipeline_parallel_main_rank(
                 lambda: log_memory_usage(f"End of {context.phase.value} iteration {iteration}", str)
             )
+        # All metrics coming out of the forward pass are reduced by default.
+        if metrics is not None:
+            metrics = self._reduce_metrics(context)
+            metrics["loss_scale"] = self._optimizer.grad_scale
 
-        return self._reduce_losses(context), update_successful, metrics
+        return (
+            self._reduce_losses(context),
+            update_successful,
+            metrics,
+        )

     def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(
+            context,
+            lambda name: self._loss_defs[name],
+            "losses",
+        )
+
+    def _reduce_metrics(self, context: BatchContext) -> dict[str, float | int]:
+        return self._reduce_metric_or_loss(
+            context, lambda name: self._metric_defs[name], "metrics", lambda x: x in self._metric_defs
+        )
+
+    def _reduce_metric_or_loss(
+        self,
+        context: BatchContext,
+        get_def: Callable[[str], "LossDef"],
+        reduce_attr: str = "losses",
+        check_reduce: Callable[[str], bool] = lambda _: True,
+    ) -> dict[str, float | int]:
         reduced_losses = {}
         num_inputs = self._distributed_config.data_parallel * context.schedule.batch_config.num_inputs
-        for name, losses in context.losses.items():
+        for name, losses in getattr(context, reduce_attr).items():
+            if not check_reduce(name):
+                reduced_losses[name] = losses
+                continue
             if losses or self._distributed.pipeline_group:
                 if losses:
-                    reduced_loss = torch.stack(losses).sum() / num_inputs / self._loss_defs[name].count
+                    reduced_loss = torch.stack(losses).sum() / num_inputs / get_def(name).count
                     if self._distributed.data_group:
                         all_reduce(reduced_loss, group=self._distributed.data_group)
                 else:
-                    reduced_loss = torch.zeros([1], dtype=self._loss_defs[name].dtype, device=self._distributed.device)
+                    reduced_loss = torch.zeros([1], dtype=get_def(name).dtype, device=self._distributed.device)
                 if self._distributed.pipeline_group:
                     all_reduce(reduced_loss, group=self._distributed.pipeline_group)
             else:
…
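To make the reduction concrete: every contributing forward pass appends one value per layer, and the final scalar is the sum divided by `num_inputs` and by the def's `count`. A minimal standalone sketch with assumed numbers and no distributed groups:

    import torch

    num_inputs = 4  # data_parallel * micro-batches per batch (assumed)
    count = 2       # LossDef.count, e.g. one value per MoE layer (assumed)

    # 2 layers x 4 micro-batches -> 8 appended values of 0.9 each.
    values = [torch.tensor(0.9)] * (num_inputs * count)
    reduced = torch.stack(values).sum() / num_inputs / count
    print(reduced)  # tensor(0.9000): the mean over layers and micro-batches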
fast_llm/layers/transformer/config.py (10 additions, 0 deletions)

@@ -93,6 +93,11 @@ class TransformerLossNames:
     router_z_loss = "router_z_loss"
 
 
+class TransformerRoutingMetrics:
+    normalized_average_entropy = "normalized_average_entropy"
+    mutual_info = "mutual_info"
+
+
 class RotaryEmbeddingType(str, enum.Enum):
     none = "none"
     default = "default"
…

@@ -656,6 +661,11 @@ class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
" Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.",
hint=FieldHint.expert,
)
calculate_moe_metrics: bool = Field(
default=True,
desc="If 'True', will calculate the MoE metrics (entropy and MI) at each logging step.",
hint=FieldHint.logging,
)

def _validate(self) -> None:
if self.init_method_std is None:
Expand Down
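These metrics can be turned off by setting `calculate_moe_metrics` to `False` in the transformer section of the model config (the exact path depends on the usual Fast-LLM config nesting, so treat this as an assumption).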
fast_llm/layers/transformer/mixture_of_experts.py (55 additions, 1 deletion)

@@ -17,6 +17,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.mlp import MLPBase
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
…

@@ -26,6 +27,41 @@
 logger = logging.getLogger(__name__)


+@torch.compile
+def calculate_normalized_average_entropy(probs: torch.Tensor) -> torch.Tensor:

Collaborator: Could try @torch.compile on these for a free performance boost.

+    """
+    Calculates the routing entropy for each token, then averages over all tokens.
+    A low value means most of the probability mass falls on a single expert for all tokens,
+    which can indicate collapse or specialization.
+    """
+    n_experts = probs.size(-1)
+    entropy_values = calculate_entropy(probs)
+    average_entropy = entropy_values.mean()  # Average over batch and tokens
+    return average_entropy / torch.log(torch.tensor(n_experts, dtype=probs.dtype, device=probs.device))

Collaborator: `average_entropy / math.log(n_experts)` (same elsewhere).

+
+
+@torch.compile
+def calculate_entropy(probs: torch.Tensor) -> torch.Tensor:
+    probs = torch.clamp(probs, min=1e-9)  # Avoid log(0)
+    return -torch.sum(probs * torch.log(probs), dim=-1)


+@torch.compile
+def calculate_mutual_information(probs: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates the difference between the entropy of the average routing distribution and
+    the average routing entropy, averaged across all tokens of all examples in the batch.
+    A low value means the routing is not informative (it barely depends on the token).
+    """
+    n_experts = probs.size(-1)
+    average_routing = torch.mean(probs.view(-1, n_experts), dim=0)  # Average over tokens
+    entropy_avg_routing = calculate_entropy(average_routing) / torch.log(
+        torch.tensor(n_experts, dtype=probs.dtype, device=probs.device)
+    )  # H[E[X]]
+    entropy_routing = calculate_normalized_average_entropy(probs)  # E[H[X]]
+
+    return entropy_avg_routing - entropy_routing
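As a quick sanity check of the two metrics, consider the extreme routing distributions below (a standalone sketch; the import path follows this PR's module layout, and calling the compiled functions requires PyTorch 2.x):

    import torch

    from fast_llm.layers.transformer.mixture_of_experts import (
        calculate_mutual_information,
        calculate_normalized_average_entropy,
    )

    # batch=2, seq=8, 4 experts.
    uniform = torch.full((2, 8, 4), 0.25)  # Router maximally undecided on every token.
    collapsed = torch.zeros(2, 8, 4)
    collapsed[..., 0] = 1.0  # Every token routed to expert 0.
    specialized = torch.nn.functional.one_hot(torch.arange(8) % 4, 4).float()
    specialized = specialized.unsqueeze(0).repeat(2, 1, 1)  # Each token sure of its own expert.

    print(calculate_normalized_average_entropy(uniform))    # ~1.0
    print(calculate_normalized_average_entropy(collapsed))  # ~0.0
    print(calculate_mutual_information(uniform))            # ~0.0: routing ignores the token
    print(calculate_mutual_information(specialized))        # ~1.0: routing determined by the token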


 class MixtureOfExpertMLP(MLPBase):
     """
     MoeLayer following implementation from
…

@@ -48,6 +84,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s
         self._config = config
         self._tensor_space = tensor_space
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
+        self._calculate_moe_metrics = config.calculate_moe_metrics
 
         self._num_experts = config.num_experts
         self._experts_per_token = config.num_experts_per_token
…

@@ -103,7 +140,9 @@ def forward(

         # Routing
         if self._routing_type == RoutingType.topk:
-            scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses)
+            scores, top_experts = self._topk_routing(
+                logits, kwargs.get(TransformerKwargs.grad_output), losses, metrics
+            )
             if self._num_shared_experts > 0:
                 scores, top_experts = self._add_shared_experts(top_experts, scores)
         elif self._routing_type == RoutingType.sinkhorn:
…

@@ -169,11 +208,26 @@ def _topk_routing(
         logits: torch.Tensor,
         grad_scale: float | None = None,
         losses: dict | None = None,
+        metrics: dict | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1)
         scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
         if losses is not None or (self.training and grad_scale is not None):
             probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
 
+            # Calculate and store the routing entropy and mutual information metrics.
+            if metrics is not None and self._calculate_moe_metrics:
+                entropy = calculate_normalized_average_entropy(probs)
+                mutual_info = calculate_mutual_information(probs)
+                if TransformerRoutingMetrics.normalized_average_entropy not in metrics:
+                    metrics[TransformerRoutingMetrics.normalized_average_entropy] = []
+                if TransformerRoutingMetrics.mutual_info not in metrics:
+                    metrics[TransformerRoutingMetrics.mutual_info] = []
+
+                metrics[TransformerRoutingMetrics.normalized_average_entropy].append(entropy.detach())
+                metrics[TransformerRoutingMetrics.mutual_info].append(mutual_info.detach())
+
             mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1)
             # Auxiliary loss, corresponding to the sum of probabilities for the top experts.
             # In the optimal case (uniform distribution), loss = experts_per_token / num_experts.
…
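The accumulation above amounts to the following pattern (standalone sketch; the import follows this PR's module layout, and `dict.setdefault` is an equivalent shorthand for the membership checks in the diff). Each MoE layer appends one detached scalar per forward pass, and `.detach()` keeps these logging-only values from holding onto the autograd graph:

    import torch

    from fast_llm.layers.transformer.mixture_of_experts import calculate_normalized_average_entropy

    metrics: dict[str, list[torch.Tensor]] = {}
    for _layer in range(4):  # Pretend four MoE layers ran in one forward pass.
        probs = torch.softmax(torch.randn(2, 8, 4), dim=-1)
        metrics.setdefault("normalized_average_entropy", []).append(
            calculate_normalized_average_entropy(probs).detach()
        )

    print(len(metrics["normalized_average_entropy"]))  # 4: one entry per layer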
fast_llm/models/gpt/model.py (25 additions, 0 deletions)

@@ -20,6 +20,7 @@
     TransformerDimNames,
     TransformerKwargs,
     TransformerLossNames,
+    TransformerRoutingMetrics,
 )
 from fast_llm.layers.transformer.preprocessing import (
     BackupAttentionPreprocessor,
…

@@ -348,6 +349,7 @@ def loss_defs(self) -> list[LossDef]:
                 count=self._config.transformer.num_layers,
             )
         )
+
         if self._config.logit_z_loss:
             LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1)
…

@@ -361,6 +363,29 @@
             )
         return loss_defs

+    @property
+    def metric_defs(self) -> list[LossDef]:
+        metric_defs = []
+        if (
+            self._config.transformer.num_experts > 1
+            and self._config.transformer.expert_routing_type == RoutingType.topk
+        ):
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.normalized_average_entropy,
+                    formatted_name="Normalized Entropy",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+            metric_defs.append(
+                LossDef(
+                    name=TransformerRoutingMetrics.mutual_info,
+                    formatted_name="Mutual Information",
+                    count=self._config.transformer.num_layers,
+                )
+            )
+        return metric_defs

     def add_preprocessor(self, preprocessor: Preprocessor):
         assert not self._is_setup
         self._preprocessors.append(preprocessor)
…
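Note: the metric defs are registered with `count=self._config.transformer.num_layers`, which matches the accumulation in `_topk_routing`: each MoE layer appends one detached value per micro-batch, and `_reduce_metric_or_loss` divides the summed values by `num_inputs` and by this count to recover the per-layer, per-micro-batch average.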