
Add test for backward step spatial parallelism#9

Open
peterdschwartz wants to merge 4 commits into main from sp-backward-step

Conversation

@peterdschwartz

use graph-aware all_reduce inside spatial mean


use graph-aware all_reduce inside spatial mean
@peterdschwartz
Author

@mahf708 first try

@peterdschwartz
Author

peterdschwartz commented Mar 10, 2026

The regression test is able to pass with these tolerances against the NonDistributed baseline:

    # 1) Loss close to baseline. The default tight tolerance is too strict
    # for cross-backend comparison, so pick something reasonable.
    torch.testing.assert_close(
        loss,
        baseline_loss,
        rtol=5e-4,
        atol=1e-5,
        msg=f"Spatial-parallel loss deviates from baseline: \n{baseline_loss} <======> {loss}",
    )

    # 2) Gradients finite and close to baseline for each parameter.
    for name, base_g in baseline_grads.items():
        assert name in grads, f"Missing gradient for parameter {name}"
        g = grads[name]
        assert torch.isfinite(g).all(), f"Non-finite gradient detected for {name}"
        try:
            torch.testing.assert_close(
                g,
                base_g,
                rtol=1e-3,
                atol=1e-5,
                msg=f"Gradient for {name} deviates from baseline",
            )
        except AssertionError as e:
            # Extract some simple summary stats for debugging
            diff = (g - base_g).abs()
            max_diff = diff.max().item()
            mean_diff = diff.mean().item()
            raise AssertionError(
                f"Gradient for {name} deviates from baseline:\n"
                f"  max |Δ| = {max_diff:.3e}, mean |Δ| = {mean_diff:.3e}\n"
                f"  grad sample (first 5): {g.flatten()[:5].tolist()}\n"
                f"  baseline sample (first 5): {base_g.flatten()[:5].tolist()}\n"
                f"  original error: {e}"
            ) from e

Currently the test needs to be run with the NonDistributed backend first to generate the baseline:
pytest fme/core/distributed/parallel_tests/test_backward_step.py::test_spatial_parallel_backward_step

I added a temporary helper script that can be used for the spatial-parallel version:
H=2 W=2 ./scripts/testing/test_spatial.sh
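For reference, the on-disk baseline is just a torch checkpoint. A minimal sketch of the save/load pattern, where the helper names and the exact dict layout are hypothetical, not the PR's actual code:

```python
import torch


def save_baseline(loss: torch.Tensor, grads: dict, path: str) -> None:
    # Hypothetical helper: persist the NonDistributed loss and per-parameter
    # gradients so a later spatial-parallel run can compare against them.
    torch.save(
        {
            "loss": loss.detach().cpu(),
            "grads": {name: g.detach().cpu() for name, g in grads.items()},
        },
        path,
    )


def load_baseline(path: str):
    data = torch.load(path)
    return data["loss"], data["grads"]
```

Detaching before saving matters here: otherwise the checkpoint would drag the autograd graph along with it.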

EDIT: Oops, the baseline grads are automatically zeroed out somewhere so it trivially passes. Checking....

@peterdschwartz
Author

Ok, I had to change the optimizer options so that it would update the gradients, and save the gradients before they are zeroed out in the step_weights function.

So, having said that, there really are differences due to Conv2d not being aware of spatial sharding, as @mahf708 predicted.
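The Conv2d boundary effect can be reproduced in a few lines without any distributed setup. This is a self-contained illustration of the problem, not the PR's test code: a padded conv applied to width-sharded halves disagrees with the full conv at the seam, because zero padding stands in for the neighbor's halo values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 1, 4, 8)

with torch.no_grad():
    full = conv(x)
    # Shard the width dimension as spatial parallelism would, then apply
    # the same conv to each shard. Zero padding at the artificial seam
    # replaces what should be halo values from the neighboring shard.
    left, right = x[..., :4], x[..., 4:]
    sharded = torch.cat([conv(left), conv(right)], dim=-1)

# Interior columns agree exactly; the two columns adjacent to the seam do not.
print(torch.allclose(full, sharded))  # False
```

Fixing this properly would require halo exchange between shards before the convolution, which is outside the scope of this test.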


Copilot AI left a comment


Pull request overview

Adds a new spatial-parallel backward/optimizer-step regression test and updates the spatial reduction primitive to use an autograd-aware all-reduce, with a helper script to run the test under torchrun.

Changes:

  • Switch ModelTorchDistributed.spatial_reduce_sum to torch.distributed.nn.functional.all_reduce to preserve autograd graphs through spatial reductions.
  • Add a new parallel test exercising forward/backward/step with LatLonOperations.area_weighted_mean under spatial parallelism.
  • Add a scripts/testing/test_spatial.sh helper to run the new test via torchrun.
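The key difference between the two reduction calls: torch.distributed.all_reduce mutates its input in place and returns a result with no grad_fn, while the functional variant records the collective in the autograd graph so gradients flow back through the reduction. A minimal sketch of the switch (not the actual ModelTorchDistributed code):

```python
import torch
from torch.distributed.nn.functional import all_reduce as autograd_all_reduce


def spatial_reduce_sum(x: torch.Tensor, group=None) -> torch.Tensor:
    # torch.distributed.all_reduce(x, group=group) would produce a result
    # detached from the autograd graph, silently dropping cross-rank
    # gradient flow through a spatial mean. The functional version returns
    # a new tensor with an all_reduce node recorded for backward.
    return autograd_all_reduce(x, group=group)
```

Note that the functional all_reduce also runs an all_reduce on the gradients during backward, which is what produces the rank-count scaling discussed later in this thread.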

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

Changed files:

  • scripts/testing/test_spatial.sh: Adds a torchrun-based entrypoint for running the spatial-parallel backward-step test.
  • fme/core/distributed/parallel_tests/test_backward_step.py: New spatial-parallel backward/step regression test using an on-disk baseline for loss/grad checks.
  • fme/core/distributed/model_torch_distributed.py: Uses graph-aware all-reduce for spatial sum reductions to support backprop through spatial means.


return

if not BASELINE_FILE.exists():
    assert not spatial_parallel, "Generate Baseline using NonDistributed backend first!"

Copilot AI Mar 10, 2026


The test depends on testdata/backward_step_baseline.pt existing when run with the spatial backend, but that baseline artifact is not added in this PR. As written, running under FME_DISTRIBUTED_BACKEND=model will fail with the "Generate Baseline" assertion. Please add the baseline file under fme/core/distributed/parallel_tests/testdata/ (or rework the test to be analytic/deterministic like other parallel tests).

Suggested change
assert not spatial_parallel, "Generate Baseline using NonDistributed backend first!"
if spatial_parallel:
    pytest.skip(
        "Baseline file missing; run test with NonDistributed backend "
        "first to generate backward_step_baseline.pt."
    )

Comment on lines +47 to +50
) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]:
"""
Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
a simple optimizer. Also returns LatLonOperations for computing a global loss.

Copilot AI Mar 10, 2026


The return type annotation for _build_model_and_optimizer says it returns a torch.optim.Optimizer, but the function actually returns an OptimizationConfig as the second element. This mismatch can break type checking and makes the helper misleading; please update the annotation/docstring to match the actual return values (or return the optimizer if that’s what you intended).

Suggested change
) -> tuple[nn.Module, torch.optim.Optimizer, LatLonOperations]:
"""
Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
a simple optimizer. Also returns LatLonOperations for computing a global loss.
) -> tuple[nn.Module, OptimizationConfig, LatLonOperations]:
"""
Build a DDP-wrapped TinyConvNet under ModelTorchDistributed and
an optimization config. Also returns LatLonOperations for computing a global loss.

Comment on lines +153 to +157
# Wrap with DDP only if spatial distributed backend is active.
spatial_parallel = isinstance(dist._distributed, ModelTorchDistributed)
if spatial_parallel:
    model = dist._distributed.wrap_module(model)


Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test reaches into dist._distributed (a private implementation detail) to detect the backend and to wrap the module. To avoid coupling the test to internal attributes, prefer using the public Distributed.wrap_module(model) API and derive whether spatial parallelism is active from public properties (e.g., compare world_size vs total_data_parallel_ranks) instead of dist._distributed type checks.

@peterdschwartz
Author

So I did refactor the test to use SHT/iSHT instead. Currently, the gradients are off relative to the baseline by exactly the number of nodes, so I'm investigating the source of that.

backwards step

Add test that verifies consistency between NonDistributed and
ModelTorchDistributed for loss and gradient calculation using simple
SHT/iSHT transforms
@peterdschwartz
Author

OK!! Not so simple after all. So the gradient norms were an exact multiple of the number of ranks because of this:
pytorch/pytorch#58005 (comment)

It is considered expected and correct behavior. So I had to implement a new autograd wrapper so that it wouldn't sum twice during the backward pass:

import torch
import torch.distributed as dist


class SpatialReplicatedSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group):
        ctx.group = group
        y = x.clone()
        dist.all_reduce(y, op=dist.ReduceOp.SUM, group=group)
        return y

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Important: do NOT all-reduce again here.
        # The forward result is a replicated view of one logical global value.
        # The None is the gradient for the non-tensor `group` argument.
        return grad_output, None
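The factor-of-world-size effect can be emulated in a single process, with no process group: after a graph-aware all-reduce, every rank holds its own copy of the replicated sum, and every rank's backward contributes the same gradient, which is exactly what the extra all-reduce in the backward pass accumulates. A toy reproduction (not the PR's code):

```python
import torch

WORLD = 3  # emulated number of ranks

# Baseline: the true global computation done once, in one process.
x = torch.arange(1.0, WORLD + 1, requires_grad=True)  # one element per "rank"
x.sum().backward()
baseline_grad = x.grad.clone()  # all ones

# Emulated graph-aware all_reduce: each rank ends up with its own copy of
# the replicated sum and each rank backprops its copy. Accumulating those
# WORLD identical gradients is what the backward-pass all_reduce does.
x2 = torch.arange(1.0, WORLD + 1, requires_grad=True)
for _ in range(WORLD):
    x2.sum().backward()

print(x2.grad)  # exactly WORLD times the baseline gradient
```

The identity backward in SpatialReplicatedSum sidesteps this by letting each rank's gradient pass through once, matching the single-process baseline.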


Copilot AI left a comment


Pull request overview

Copilot reviewed 3 out of 4 changed files in this pull request and generated 6 comments.


