
Commit 9f23889

JenniferWang authored and facebook-github-bot committed
move weight update validation functions to util (#573)
Summary:

* Fix the weight update test.
* Extract the common logic into separate util functions; see the next diff, D87083010, for how to use them to verify that weights do get updated, as part of infra verification when debugging a buggy run.

Reviewed By: casteryh

Differential Revision: D87005971
1 parent 3d69189 · commit 9f23889

File tree

3 files changed (+285, -69 lines)


src/forge/actors/generator.py

Lines changed: 11 additions & 8 deletions
@@ -579,16 +579,16 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
         await stop_proc_mesh(actor._fetcher_procs)
 
     @endpoint
-    async def _test_save_model_params(self):
-        """Save model parameters before weight update, used for tesing purposes only."""
+    async def save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
         logger.info("[Generator] save model parameters for testing.")
-        await self.worker._test_save_model_params.call()
+        await self.worker.save_model_params.call()
 
     @endpoint
-    async def _test_validate_model_params(self, validate_fn):
+    async def validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
         logger.info("[Generator] start validating model parameters.")
-        return await self.worker._test_validate_model_params.call(validate_fn)
+        return await self.worker.validate_model_params.call(validate_fn)
 
 
 @dataclass
@@ -604,6 +604,9 @@ class GeneratorWorker(ForgeActor):
     # TODO: Remove below param
     _test_prev_params = {}
 
+    def __post_init__(self):
+        super().__init__()
+
     @endpoint
     async def setup(self):
         self.rank = current_rank().rank
@@ -720,8 +723,8 @@ async def update_weights(
         t.stop()
 
     @endpoint
-    async def _test_save_model_params(self):
-        """Save model parameters before weight update, used for tesing purposes only."""
+    async def save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
         logger.info("[GeneratorWorker] save model parameters for testing.")
         for name, param in self.worker.model_runner.model.named_parameters():
             self._test_prev_params[name] = param.detach().cpu()
@@ -731,7 +734,7 @@ async def _test_save_model_params(self):
         )
 
     @endpoint
-    async def _test_validate_model_params(self, validate_fn):
+    async def validate_model_params(self, validate_fn):
         """Validate updated model params using validate_fn."""
         logger.info("[GeneratorWorker] start validating model parameters.")
         return validate_fn(
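
The hunk above cuts off the arguments that the worker forwards to validate_fn. As a rough sketch only (assuming the worker passes the saved parameter dict and the model's current named parameters, which this page does not confirm), a validate_fn could look like this:

# Hypothetical validate_fn for the endpoints above; the (prev_params,
# named_params) signature is an assumption, since the actual call site is
# truncated in this hunk.
import torch

def validate_params_updated(prev_params, named_params, atol=1e-6, rtol=1e-5):
    # Collect every parameter that is still allclose to its saved copy,
    # i.e. that the weight update failed to touch.
    stale = []
    for name, param in named_params:
        prev = prev_params.get(name)
        curr = param.detach().cpu()
        if prev is not None and torch.allclose(prev, curr, atol=atol, rtol=rtol):
            stale.append(name)
    assert not stale, f"parameters not updated: {stale}"

# Driving the endpoints from test code (actor handle name assumed):
#     await generator.save_model_params.call()
#     ... run the weight update under test ...
#     await generator.validate_model_params.call(validate_params_updated)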
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Utilities for verifying model weight updates during training."""

import logging
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn


logger = logging.getLogger(__name__)


@dataclass
class WeightSnapshot:
    """Snapshot of model weights at a specific point in time."""

    params: dict[str, torch.Tensor]
    version: int | None = None
    metadata: dict[str, Any] | None = None

    @classmethod
    def from_model(
        cls, model: nn.Module, version: int | None = None, device: str = "cpu"
    ) -> "WeightSnapshot":
        """Create a snapshot of model parameters.

        Args:
            model: PyTorch model to snapshot
            version: Optional version identifier
            device: Device to store snapshot tensors (default: cpu)

        Returns:
            WeightSnapshot containing detached copies of all parameters
        """
        params = {}
        for name, param in model.named_parameters():
            params[name] = param.detach().to(device).clone()

        return cls(params=params, version=version)


@dataclass
class WeightVerificationResult:
    """Result of weight verification check."""

    weights_changed: bool
    num_params_checked: int
    num_params_changed: int
    num_params_unchanged: int
    num_params_skipped: int
    changed_params: list[str]
    unchanged_params: list[str]
    skipped_params: list[str]
    max_delta: float | None = None
    mean_delta: float | None = None

    def __str__(self) -> str:
        status = "✅ CHANGED" if self.weights_changed else "⚠️ UNCHANGED"
        max_delta = f"{self.max_delta:.6e}" if self.max_delta is not None else "N/A"
        mean_delta = f"{self.mean_delta:.6e}" if self.mean_delta is not None else "N/A"

        return (
            f"Weight Verification {status}:\n"
            f"  Checked: {self.num_params_checked}\n"
            f"  Changed: {self.num_params_changed}\n"
            f"  Unchanged: {self.num_params_unchanged}\n"
            f"  Skipped: {self.num_params_skipped}\n"
            f"  Max delta: {max_delta}\n"
            f"  Mean delta: {mean_delta}"
        )


def verify_weights_changed(
    prev_snapshot: WeightSnapshot,
    current_model: nn.Module,
    atol: float = 1e-6,
    rtol: float = 1e-5,
    skip_non_float: bool = True,
    verbose: bool = False,
) -> WeightVerificationResult:
    """Verify that model weights have changed compared to a previous snapshot.

    This is a more robust verification than simple parameter hashing, as it:
    - Checks each parameter individually
    - Uses proper floating point comparison (torch.allclose)
    - Provides detailed information about which parameters changed
    - Computes statistics about the magnitude of changes

    Args:
        prev_snapshot: Previous weight snapshot to compare against
        current_model: Current model to check
        atol: Absolute tolerance for considering weights unchanged
        rtol: Relative tolerance for considering weights unchanged
        skip_non_float: Whether to skip non-floating point parameters
        verbose: Whether to log detailed information

    Returns:
        WeightVerificationResult with detailed information about changes
    """
    changed_params = []
    unchanged_params = []
    skipped_params = []
    deltas = []

    for name, param in current_model.named_parameters():
        if skip_non_float and not torch.is_floating_point(param):
            skipped_params.append(name)
            if verbose:
                logger.info(f"Skipping non-float param: {name}")
            continue

        if name not in prev_snapshot.params:
            logger.warning(f"Parameter {name} not found in previous snapshot")
            skipped_params.append(name)
            continue

        prev_param = prev_snapshot.params[name]
        curr_param = param.detach().cpu()

        # Check if parameters are close (i.e., unchanged)
        is_close = torch.allclose(prev_param, curr_param, atol=atol, rtol=rtol)

        if is_close:
            unchanged_params.append(name)
        else:
            changed_params.append(name)
            # Compute delta for statistics
            delta = (curr_param - prev_param).abs().max().item()
            deltas.append(delta)

            if verbose:
                logger.info(
                    f"Parameter {name} changed - max delta: {delta:.6e}, "
                    f"mean delta: {(curr_param - prev_param).abs().mean().item():.6e}"
                )

    # Compute statistics
    max_delta = max(deltas) if deltas else 0
    mean_delta = sum(deltas) / len(deltas) if deltas else 0

    result = WeightVerificationResult(
        weights_changed=len(changed_params) > 0,
        num_params_checked=len(changed_params) + len(unchanged_params),
        num_params_changed=len(changed_params),
        num_params_unchanged=len(unchanged_params),
        num_params_skipped=len(skipped_params),
        changed_params=changed_params,
        unchanged_params=unchanged_params,
        skipped_params=skipped_params,
        max_delta=max_delta,
        mean_delta=mean_delta,
    )

    logger.info(str(result))

    return result


def verify_weights_all_zeros(
    current_model: nn.Module,
    atol: float = 1e-4,
    rtol: float = 1e-3,
    skip_non_float: bool = True,
    verbose: bool = False,
) -> tuple[bool, list[str], list[str]]:
    """Verify that all model parameters are zero.

    Args:
        current_model: Model to check
        atol: Absolute tolerance
        rtol: Relative tolerance
        skip_non_float: Whether to skip non-floating point parameters
        verbose: Whether to log detailed information

    Returns:
        Tuple of (all_zeros, zero_params, non_zero_params)
    """
    zero_params = []
    non_zero_params = []

    for name, param in current_model.named_parameters():
        if skip_non_float and not torch.is_floating_point(param):
            if verbose:
                logger.info(f"Skipping non-float param: {name}")
            continue

        param_cpu = param.detach().cpu()
        is_zero = torch.allclose(
            torch.zeros_like(param_cpu), param_cpu, atol=atol, rtol=rtol
        )

        if is_zero:
            zero_params.append(name)
        else:
            non_zero_params.append(name)
            if verbose:
                logger.info(
                    f"Parameter {name} is not zero - "
                    f"max: {param_cpu.abs().max().item():.6e}, "
                    f"mean: {param_cpu.abs().mean().item():.6e}"
                )

    all_zeros = len(non_zero_params) == 0

    logger.info(
        f"Zero check: {'✅ PASS' if all_zeros else '⚠️ FAIL'} - "
        f"{len(zero_params)} zero, {len(non_zero_params)} non-zero"
    )

    return all_zeros, zero_params, non_zero_params
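
For reference, a minimal usage sketch of the helpers above on a toy model. The import is commented out and hypothetical, since this page capture does not show the new file's path in the tree:

import logging

import torch
import torch.nn as nn

# Hypothetical import path; the capture does not show where the util lives.
# from forge.util.weight_verification import (
#     WeightSnapshot,
#     verify_weights_changed,
#     verify_weights_all_zeros,
# )

logging.basicConfig(level=logging.INFO)  # surface the helpers' logger.info output

model = nn.Linear(4, 2)

# Snapshot the weights before the update under test.
snapshot = WeightSnapshot.from_model(model, version=0)

# Simulate a weight update (a real run would push new weights instead).
with torch.no_grad():
    for param in model.parameters():
        param.add_(0.1)

# Every floating-point parameter should now differ from the snapshot.
result = verify_weights_changed(snapshot, model, verbose=True)
assert result.weights_changed and result.num_params_unchanged == 0

# Complementary check, e.g. to catch weights that were allocated but never
# loaded: a freshly initialized and perturbed layer should not be all zeros.
all_zeros, zero_params, non_zero_params = verify_weights_all_zeros(model)
assert not all_zeros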
