
Replace Gram matrix V projection with factored form #74

Open
stmcgovern wants to merge 3 commits into Red-Hat-AI-Innovation-Team:main from stmcgovern:factored-v-projection

Conversation

@stmcgovern stmcgovern (Contributor) commented Mar 6, 2026

Fixes #73
Replace dV -= dV @ (V_high^T @ V_high) with the factored form dV -= (dV @ V_high^T) @ V_high. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather — M/k_high fewer bytes (2x for square weights, 7x for down_proj) and all-gather is cheaper per byte.

Add opt-in caching of the all-gathered V_high (OSFT_CACHE_V=1, default off). V_high is frozen, so the cache is exact.
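A minimal sketch of what such an opt-in per-module cache might look like — the OSFT_CACHE_V flag and _osft_v_high_full attribute names come from this PR, but gather_fn and the exact signature are hypothetical stand-ins, not the PR's actual code:

```python
import os

def get_v_high_full(module, gather_fn, cache_enabled=None):
    """Return the full V_high, using a per-module cache when enabled.
    Caching is exact because V_high is frozen during training."""
    if cache_enabled is None:
        cache_enabled = os.environ.get("OSFT_CACHE_V", "0") == "1"
    cached = getattr(module, "_osft_v_high_full", None)
    if cache_enabled and cached is not None:
        return cached                      # cache hit: skip the all-gather
    v_full = gather_fn()                   # stand-in for the distributed gather
    if cache_enabled:
        module._osft_v_high_full = v_full  # plain attribute, not a buffer,
    return v_full                          # so it never leaks into state_dict
```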

Includes bench_v_proj.py benchmark and 10 new tests.

Summary by CodeRabbit

  • New Features

    • Optional caching of the full V_high tensor to avoid repeated distributed gathers and speed V-gradient projection
    • Command-line benchmark script to compare Gram, factored, and cached V-projection performance across multi-GPU setups
  • Bug Fixes / Improvements

    • Switched V-projection to a factored update path that handles uneven shard layouts and leverages the cache when available
    • Ensures caches are cleared on reinitialization and not persisted into model state
  • Tests

    • Extensive tests covering caching behavior, uneven-shard de-interleaving, correctness vs. Gram-based projection, cache clearing, and no-leak guarantees

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Replaces the Gram-matrix V projection with a factored all-gather/de‑interleave form, adds an opt-in per-module cache for the gathered full V_high, introduces a 2-GPU benchmark comparing Gram/Factored/Cached modes, and adds tests for caching, uneven-shard de-interleave, and correctness.

Changes

  • Benchmark script — benchmarks/bench_v_proj.py: New distributed 2‑GPU benchmark measuring Gram, Factored, and Cached V‑projection modes with warmup, per‑mode timing, correctness checks, and aggregate reporting.
  • OSFT projection & cache — src/mini_trainer/osft_utils.py: Adds the OSFT_CACHE_V env flag; extends project_gradient_to_orthogonal_space(..., cache_holder=...); replaces the Gram all‑reduce with a factored all‑gather + de‑interleave; implements the per‑module _osft_v_high_full cache, eviction in _reset_osft_metadata, and cache bookkeeping/logging.
  • Tests: cache & uneven shards — tests/test_osft.py: Adds TestVProjectionCache and TestUnevenShardDeinterleave suites covering cache population/shape/size, parity vs the Gram projection, cache vs no‑cache equivalence, cache clearing, non‑leakage into state_dict, orthogonality across steps, and uneven‑shard de‑interleave mocking and caching behavior.

Sequence Diagram

sequenceDiagram
    participant Grad as Gradient
    participant Proj as project_gradient_to_orthogonal_space
    participant Module as ModuleCacheHolder
    participant Comm as DistributedComm
    participant SVD as LocalProjection

    Grad->>Proj: request projection for module
    Proj->>Module: check OSFT_CACHE_V & _osft_v_high_full
    alt cache hit
        Module-->>Proj: return _osft_v_high_full
    else cache miss or disabled
        Proj->>Comm: all_gather_into_tensor(local_V_high) (pad+de‑interleave if uneven)
        Comm-->>Proj: V_high_full
        Proj->>Module: store _osft_v_high_full (if enabled)
    end
    Proj->>SVD: coeff = local_dV @ V_high_full.T
    SVD->>SVD: local_dV -= coeff @ V_high_full
    SVD-->>Proj: return updated local_dV
    Proj->>Proj: log cache metrics / increment counters

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • RobotSail
  • NikhilNayak-debug

Poem

🐰 I gathered V beneath the moonlit sod,
shards untangled, cached against the clod.
No heavy Gram to slow my hop,
I hop, I prod, then hop nonstop.
Small bytes, bright leaps—hooray, we nod.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 75.68%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Description Check ✅ — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ — The title 'Replace Gram matrix V projection with factored form' accurately and concisely summarizes the primary change in the pull request.
  • Linked Issues Check ✅ — The pull request implements all coding requirements from issue #73: replaces the Gram matrix computation with the factored form, changes the all-reduce to an all-gather for V_high, adds optional caching via the OSFT_CACHE_V flag, and includes comprehensive test coverage.
  • Out of Scope Changes Check ✅ — All changes are scoped to issue #73: the osft_utils.py modifications implement the factored V projection and caching logic, test_osft.py adds the required test coverage, and bench_v_proj.py provides the benchmark tool described in the PR objectives.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
benchmarks/bench_v_proj.py (1)

22-35: Fail fast on non-divisible benchmark shapes.

k_high // P and k_low // P silently discard remainder rows, and the slice at Line 34 only covers P * local_k_high rows. If this helper gets reused with a non-divisible shape, it will publish timings for the wrong problem size instead of stopping early.

Possible fix
 def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
     """Benchmark one OSFT target shape.  Returns dict of timings."""
+    if k_high % P != 0 or k_low % P != 0:
+        raise ValueError(
+            f"k_high={k_high} and k_low={k_low} must both be divisible by P={P}"
+        )
     local_k_high = k_high // P
     local_k_low = k_low // P
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 22 - 35, The function bench_target
silently truncates k_high and k_low by integer division into
local_k_high/local_k_low and then slices V_full by rank * local_k_high; add
explicit validation at the start of bench_target to check that k_high % P == 0
and k_low % P == 0 (or otherwise that both divide evenly by P), and raise a
clear ValueError if not so that the benchmark fails fast; update the error
message to reference the offending values (k_high, k_low, P) so callers can
correct input shapes and avoid publishing wrong timings caused by the
partial-row truncation in the local_V slice.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 605-611: The current all_gather_into_tensor assumes equal per-rank
rows for local_V_high; pad local_V_high to target_rows =
math.ceil(svd_dict["rank_high"] / world_size) before calling
dist.all_gather_into_tensor so each rank provides the same shape (use
torch.nn.functional.pad or torch.zeros on the same device/dtype and
concatenate), perform the all_gather_into_tensor into V_high_full sized
(target_rows * world_size, cols), then slice V_high_full[:svd_dict["rank_high"],
:] to restore the original total rank and assign back into the downstream
variable; ensure import math is present at file top if missing.
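The pad → all-gather → slice recipe described above can be simulated without torch (a list-based stand-in for dist.all_gather_into_tensor; shard sizes and shapes are hypothetical, chosen to mimic chunk-style sharding where the first ranks hold the extra rows):

```python
import math

def gather_uneven_rows(local_shards, full_rows):
    """Simulate all_gather_into_tensor over row-sharded matrices whose
    shard sizes may differ under chunk-style sharding: pad every shard
    to a common row count, concatenate in rank order, then slice back
    to the true total row count."""
    world_size = len(local_shards)
    target_rows = math.ceil(full_rows / world_size)
    cols = len(local_shards[0][0])
    gathered = []
    for shard in local_shards:
        pad = [[0.0] * cols] * (target_rows - len(shard))
        gathered.extend(shard + pad)   # every rank contributes target_rows
    # Under chunk sharding the padding only trails the final real rows,
    # so a single slice recovers the original matrix.
    return gathered[:full_rows]

# rank 0 holds 2 rows, rank 1 holds 1 -> total rank_high = 3
shards = [[[1.0, 1.0], [2.0, 2.0]], [[3.0, 3.0]]]
assert gather_uneven_rows(shards, 3) == [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```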

In `@tests/test_osft.py`:
- Around line 1626-1640: In test_cache_disabled_by_default, pin the module-level
OSFT_CACHE_V constant to False so the test is hermetic: before calling
self._create_simple_osft_model(), use the pytest monkeypatch fixture to
monkeypatch.setattr(<module_that_defines_OSFT_CACHE_V>, "OSFT_CACHE_V", False)
(or equivalent setattr) so the constant is forced off for the duration of the
test; rely on monkeypatch to automatically restore the original value after the
test and keep the rest of the test logic (model.train(), project_gradients(),
and the assertions) unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 20a6d6ef-db54-486a-a9a4-6bc05401b8a6

📥 Commits

Reviewing files that changed from the base of the PR and between 4d6dc87 and 0e9b3fe.

📒 Files selected for processing (3)
  • benchmarks/bench_v_proj.py
  • src/mini_trainer/osft_utils.py
  • tests/test_osft.py

@stmcgovern stmcgovern force-pushed the factored-v-projection branch from 0e9b3fe to 6ab859c on April 1, 2026 01:41

codecov bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 92.85714% with 3 lines in your changes missing coverage. Please review.

  • src/mini_trainer/osft_utils.py — patch coverage 92.85%, 3 lines missing ⚠️



@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
src/mini_trainer/osft_utils.py (1)

652-656: Optional: de-interleave can be simplified to a single slice.

Since all_gather_into_tensor concatenates data in rank order and padding only appears at the end of the last rank's contribution, the loop can be replaced with a single slice:

V_high_full = gathered[:full_k_high]

The current approach is correct, but the simpler form is equally valid and avoids the intermediate list allocation.

Simplified de-interleave
-                        parts = []
-                        for i in range(world_size):
-                            n = min(rows_per_rank, full_k_high - rows_per_rank * i)
-                            parts.append(gathered[i * rows_per_rank : i * rows_per_rank + n])
-                        V_high_full = torch.cat(parts)
+                        V_high_full = gathered[:full_k_high]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/osft_utils.py` around lines 652 - 656, The de-interleaving
loop that builds parts and then concatenates into V_high_full (variables: parts,
world_size, rows_per_rank, full_k_high, gathered, V_high_full) can be
simplified: replace the loop and torch.cat with a single slice of the gathered
tensor up to full_k_high (i.e., use gathered[:full_k_high]) to avoid the
intermediate list allocation while preserving correctness.
benchmarks/bench_v_proj.py (1)

110-116: Consider passing rank as a parameter instead of using global state.

The global rank variable works but reduces clarity. Passing it explicitly would make the data flow clearer:

-def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
+def bench_target(name, k_high, k_low, M, P, rank, dev, n_iters=100):

This is minor—the current approach is acceptable for a benchmark script.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 110 - 116, The current code uses a
module-level global `rank` and the function `run(rank_, world_size)` which
assigns to that global; replace this by passing `rank` explicitly: remove the
module-level `rank` global, change `run` to use a local parameter (keep the
signature `run(rank, world_size)` or rename consistently), and update any
downstream functions or places that reference the global `rank` to accept a
`rank` parameter and use the local one instead (search for references to the
global `rank` symbol and the `run` function to update callers).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 036317a0-8b97-4a7e-b179-014ab9f3470b

📥 Commits

Reviewing files that changed from the base of the PR and between 0e9b3fe and 6ab859c.

📒 Files selected for processing (3)
  • benchmarks/bench_v_proj.py
  • src/mini_trainer/osft_utils.py
  • tests/test_osft.py


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (3)
benchmarks/bench_v_proj.py (3)

22-26: Benchmark restricts to even sharding — doesn't exercise uneven shard path.

The divisibility check ensures the benchmark only tests even sharding. While this avoids complexity, it means the uneven shard de-interleave logic in osft_utils.py (which has a bug — see other comment) isn't exercised here.

Consider adding a note in the docstring or a TODO for future coverage:

+    # NOTE: This benchmark requires even sharding (k_high % P == 0).
+    # Uneven sharding paths in osft_utils.py are not covered here.
     if k_high % P != 0 or k_low % P != 0:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 22 - 26, The bench_target function
currently raises if k_high or k_low aren't divisible by P, which prevents
testing uneven sharding and the related de-interleave code; update the
bench_target docstring (function bench_target) to note this limitation and add a
TODO comment suggesting adding a non-divisible-shard test path to exercise
uneven shard logic (reference variables k_high, k_low, P and local_k_high) so
future work will remove the divisibility guard or add separate test coverage for
the uneven-shard code path.

149-159: Aggregate calculation is Llama-8B specific — consider noting this.

The calculation assumes 32 layers with 6 square targets + 1 down_proj per layer. This is accurate for Llama-8B but could mislead if someone uses this benchmark for other architectures.

The inline comment at line 149 helps, but consider making the model name more prominent in the output:

-        print("  Aggregate (32 layers x 7 targets = 224):")
+        print("  Aggregate for Llama-8B (32 layers x 7 targets = 224):")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 149 - 159, The aggregate calculation
and printed summary in bench_v_proj.py currently hardcodes "32 layers x 7
targets" (variables sq, dp and computed gram_tot, fact_tot, cached_tot) which is
Llama-8B specific; update the output to explicitly state the model/assumption
and/or compute the aggregate from configurable values instead of fixed literals:
surface the model name or a derived string (e.g., model_name or computed
n_layers and targets_per_layer) in the print header, or compute layer_count and
targets_per_layer from inputs used to derive sq and dp and use those when
printing so the message accurately reflects the configuration used by gram_tot,
fact_tot, and cached_tot.

110-115: Consider making port configurable to avoid conflicts.

Hardcoded port 29500 could conflict with other distributed jobs or parallel benchmark runs.

+def run(rank, world_size, port=29500):
+    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port), ...)
-def run(rank, world_size):
-    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500", ...)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 110 - 115, The hardcoded MASTER_PORT
value in run() risks conflicts; make the port configurable by accepting a port
argument (e.g., port=None) or reading a dedicated env var (e.g.,
DIST_MASTER_PORT) and defaulting to "29500" if unset, then use that value when
setting os.environ["MASTER_PORT"] before calling
dist.init_process_group("nccl"); update the run signature and any callers to
pass the desired port and document the new parameter so parallel runs can
specify different ports.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c3384706-c27c-4475-a91a-1b41bc9bc3e7

📥 Commits

Reviewing files that changed from the base of the PR and between 6ab859c and 2568e80.

📒 Files selected for processing (2)
  • benchmarks/bench_v_proj.py
  • src/mini_trainer/osft_utils.py
