Add test for backward step spatial parallelism #9
peterdschwartz wants to merge 4 commits into main from
Conversation
use graph-aware all_reduce inside spatial mean
@mahf708 first try
Regression test is able to pass with these tolerances against the NonDistributed baseline:

```python
# Tight tolerance is probably too strict; pick something reasonable.
torch.testing.assert_close(
    loss,
    baseline_loss,
    rtol=5e-4,
    atol=1e-5,
    msg=f"Spatial-parallel loss deviates from baseline: \n{baseline_loss} <======> {loss}",
)

# 2) Gradients finite and close to baseline for each parameter.
for name, base_g in baseline_grads.items():
    assert name in grads, f"Missing gradient for parameter {name}"
    g = grads[name]
    assert torch.isfinite(g).all(), f"Non-finite gradient detected for {name}"
    try:
        torch.testing.assert_close(
            g,
            base_g,
            rtol=1e-3,
            atol=1e-5,
            msg=f"Gradient for {name} deviates from baseline",
        )
    except AssertionError as e:
        # Extract some simple summary stats for debugging
        diff = (g - base_g).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        raise AssertionError(
            f"Gradient for {name} deviates from baseline:\n"
            f"  max |Δ| = {max_diff:.3e}, mean |Δ| = {mean_diff:.3e}\n"
            f"  grad sample (first 5): {g.flatten()[:5].tolist()}\n"
            f"  baseline sample (first 5): {base_g.flatten()[:5].tolist()}\n"
            f"  original error: {e}"
        ) from e
```

Currently the test needs to be run with NonDistributed to generate the baseline. I added a temporary helper script that can be used for the spatial-parallel version.

EDIT: Oops, the baseline grads are automatically zero'd out somewhere, so it trivially passes. Checking....
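For context, the generate-then-compare workflow above can be sketched as a plain `torch.save`/`torch.load` round trip. This is a minimal illustration, not the PR's actual test code: the tiny `nn.Linear` model, the `baseline_path` name, and the dict layout are all placeholder assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical sketch: the NonDistributed run writes loss + grads to disk,
# and the spatial-parallel run later loads them for comparison.
baseline_path = os.path.join(tempfile.gettempdir(), "backward_step_baseline.pt")

torch.manual_seed(0)
model = nn.Linear(4, 2)  # stand-in for the real test model
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Save loss and per-parameter grads BEFORE the optimizer zeroes them.
baseline = {
    "loss": loss.detach().clone(),
    "grads": {name: p.grad.detach().clone() for name, p in model.named_parameters()},
}
torch.save(baseline, baseline_path)

# Later (spatial-parallel run): load and compare against fresh values.
loaded = torch.load(baseline_path)
assert torch.equal(loaded["loss"], baseline["loss"])
```
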
Ok, I had to change the optimizer options so that it would update the gradients, and save the gradients before they are zero'd out in So, having said that, there are real differences due to the
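The timing issue described above (grads must be captured before the optimizer wipes them) can be shown in a few lines. This is a generic sketch, not the PR's code; the model and optimizer here are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)  # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(5, 3)).sum()
loss.backward()

# Clone NOW: after zero_grad(set_to_none=True), p.grad becomes None,
# so any later comparison against a baseline would trivially pass/fail.
saved_grads = {n: p.grad.detach().clone() for n, p in model.named_parameters()}

opt.step()
opt.zero_grad(set_to_none=True)

assert all(p.grad is None for p in model.parameters())
assert all(torch.isfinite(g).all() for g in saved_grads.values())
```
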
Pull request overview
Adds a new spatial-parallel backward/optimizer-step regression test and updates the spatial reduction primitive to use an autograd-aware all-reduce, with a helper script to run the test under torchrun.
Changes:
- Switch `ModelTorchDistributed.spatial_reduce_sum` to `torch.distributed.nn.functional.all_reduce` to preserve autograd graphs through spatial reductions.
- Add a new parallel test exercising forward/backward/step with `LatLonOperations.area_weighted_mean` under spatial parallelism.
- Add a `scripts/testing/test_spatial.sh` helper to run the new test via `torchrun`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| scripts/testing/test_spatial.sh | Adds a torchrun-based entrypoint for running the spatial-parallel backward-step test. |
| fme/core/distributed/parallel_tests/test_backward_step.py | New spatial-parallel backward/step regression test using an on-disk baseline for loss/grad checks. |
| fme/core/distributed/model_torch_distributed.py | Uses graph-aware all-reduce for spatial sum reductions to support backprop through spatial means. |
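The spatial reduction being tested boils down to: each rank holds a shard of the grid, computes a partial weighted sum and partial weight sum, and an all-reduce(SUM) combines them into the global area-weighted mean. The following single-process sketch emulates that arithmetic with a local sum standing in for the all-reduce; the shapes, the 2-way split, and the variable names are illustrative, not the fme implementation.

```python
import torch

torch.manual_seed(0)
field = torch.randn(4, 8)          # stand-in for a full lat-lon field
weights = torch.rand(4, 8) + 0.1   # stand-in for area weights

# Reference: global area-weighted mean on the unsharded grid.
global_mean = (field * weights).sum() / weights.sum()

# Emulate a 2-rank spatial split along latitude.
shards = torch.chunk(field, 2, dim=0)
w_shards = torch.chunk(weights, 2, dim=0)

# Each "rank" computes partial numerator and denominator; the Python sum()
# across shards plays the role of all_reduce(SUM) across ranks.
num = sum((s * w).sum() for s, w in zip(shards, w_shards))
den = sum(w.sum() for w in w_shards)
sharded_mean = num / den

assert torch.allclose(sharded_mean, global_mean)
```

Using a graph-aware all-reduce here matters because both `num` and `den` must stay in the autograd graph for the backward pass through the mean.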
Code context:

```python
        return

    if not BASELINE_FILE.exists():
        assert not spatial_parallel, "Generate Baseline using NonDistributed backend first!"
```
The test depends on testdata/backward_step_baseline.pt existing when run with the spatial backend, but that baseline artifact is not added in this PR. As written, running under FME_DISTRIBUTED_BACKEND=model will fail with the "Generate Baseline" assertion. Please add the baseline file under fme/core/distributed/parallel_tests/testdata/ (or rework the test to be analytic/deterministic like other parallel tests).
Suggested change:

```diff
-    assert not spatial_parallel, "Generate Baseline using NonDistributed backend first!"
+    if spatial_parallel:
+        pytest.skip(
+            "Baseline file missing; run test with NonDistributed backend "
+            "first to generate backward_step_baseline.pt."
+        )
```
Code context:

```python
) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]:
    """
    Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
    a simple optimizer. Also returns LatLonOperations for computing a global loss.
```
The return type annotation for _build_model_and_optimizer says it returns a torch.optim.Optimizer, but the function actually returns an OptimizationConfig as the second element. This mismatch can break type checking and makes the helper misleading; please update the annotation/docstring to match the actual return values (or return the optimizer if that’s what you intended).
Suggested change:

```diff
-) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]:
-    """
-    Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
-    a simple optimizer. Also returns LatLonOperations for computing a global loss.
+) -> tuple[nn.Module, OptimizationConfig, LatLonOperations]:
+    """
+    Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
+    an optimization config. Also returns LatLonOperations for computing a global loss.
```
Code context:

```python
# Wrap with DDP only if spatial distributed backend is active.
spatial_parallel = isinstance(dist._distributed, ModelTorchDistributed)
if spatial_parallel:
    model = dist._distributed.wrap_module(model)
```
This test reaches into dist._distributed (a private implementation detail) to detect the backend and to wrap the module. To avoid coupling the test to internal attributes, prefer using the public Distributed.wrap_module(model) API and derive whether spatial parallelism is active from public properties (e.g., compare world_size vs total_data_parallel_ranks) instead of dist._distributed type checks.
So I did refactor the test to use SHT/iSHT instead. Currently, the gradients are off relative to baseline by exactly the number of nodes, so I'm investigating the source of that.
Add test that verifies consistency between NonDistributed and TorchModelDistributed for loss and gradient calculation in the backwards step, using simple SHT/iSHT transforms
Force-pushed from 6987e71 to 3e3943e
|
OK!! Not so simple after all. So the gradient norms were an exact multiple of the number of ranks because of this: It is considered expected and correct behavior. So I had to implement a new autograd wrapper so that it wouldn't sum twice during the backward pass:

```python
class SpatialReplicatedSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group):
        ctx.group = group
        y = x.clone()
        dist.all_reduce(y, op=dist.ReduceOp.SUM, group=group)
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Important: do NOT all-reduce again here.
        # The forward result is a replicated view of one logical global value.
        return grad_output, None
```
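The factor-of-`world_size` effect can be reproduced on a single process without a real process group. The sketch below is a simulation under stated assumptions: `world_size` is a stand-in constant, the input is treated as replicated on every rank, and the two `autograd.Function` subclasses mimic (a) an all-reduce whose backward also sums gradients across ranks and (b) the `SpatialReplicatedSum`-style fix that passes the gradient through untouched.

```python
import torch

world_size = 4  # simulated number of ranks (assumption for this sketch)

class GraphAwareAllReduceSim(torch.autograd.Function):
    """Mimics an autograd-aware all_reduce on a replicated input:
    forward sums over ranks, and backward ALSO sums grads over ranks."""
    @staticmethod
    def forward(ctx, x):
        return world_size * x  # sum of identical per-rank shards

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank sees the same grad for the replicated loss, so the
        # backward all-reduce yields world_size copies summed together.
        return world_size * grad_output

class ReplicatedSumSim(torch.autograd.Function):
    """Mimics SpatialReplicatedSum: same forward, but backward does NOT
    reduce again, since the output is one logical global value."""
    @staticmethod
    def forward(ctx, x):
        return world_size * x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

x = torch.ones(3, requires_grad=True)
GraphAwareAllReduceSim.apply(x).sum().backward()
double_counted = x.grad.clone()

x.grad = None
ReplicatedSumSim.apply(x).sum().backward()
fixed = x.grad.clone()

# The naive version picks up an extra factor of world_size,
# matching the "exact multiple of the number of ranks" observation.
assert torch.equal(double_counted, world_size * fixed)
```
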
Pull request overview
Copilot reviewed 3 out of 4 changed files in this pull request and generated 6 comments.
use graph-aware all_reduce inside spatial mean
Short description of why the PR is needed and how it satisfies those requirements, in sentence form.

Changes:
- symbol (e.g. `fme.core.my_function`) or script and concise description of changes or added feature
- Can group multiple related symbols on a single bullet

- [ ] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)