Replace Gram matrix V projection with factored form #74
stmcgovern wants to merge 3 commits into Red-Hat-AI-Innovation-Team:main from
Conversation
📝 Walkthrough

Replaces the Gram-matrix V projection with a factored all-gather/de-interleave form and adds an opt-in per-module cache for the gathered full V_high.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Grad as Gradient
    participant Proj as project_gradient_to_orthogonal_space
    participant Module as ModuleCacheHolder
    participant Comm as DistributedComm
    participant SVD as LocalProjection
    Grad->>Proj: request projection for module
    Proj->>Module: check OSFT_CACHE_V & _osft_v_high_full
    alt cache hit
        Module-->>Proj: return _osft_v_high_full
    else cache miss or disabled
        Proj->>Comm: all_gather_into_tensor(local_V_high) (pad + de-interleave if uneven)
        Comm-->>Proj: V_high_full
        Proj->>Module: store _osft_v_high_full (if enabled)
    end
    Proj->>SVD: coeff = local_dV @ V_high_full.T
    SVD->>SVD: local_dV -= coeff @ V_high_full
    SVD-->>Proj: return updated local_dV
    Proj->>Proj: log cache metrics / increment counters
```
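The flow in the diagram can be sketched in plain Python. This is an illustrative stand-in, not the PR's implementation: `gather_fn` replaces the `all_gather_into_tensor` collective, and numpy replaces torch; the flag and attribute names (`OSFT_CACHE_V`, `_osft_v_high_full`) mirror the diagram.

```python
import numpy as np

OSFT_CACHE_V = True  # stand-in for the env-controlled flag (default off upstream)

def project(module, local_dV, gather_fn):
    """Remove the component of local_dV that lies in span(V_high)."""
    V_full = getattr(module, "_osft_v_high_full", None)
    if not (OSFT_CACHE_V and V_full is not None):
        V_full = gather_fn()                   # all-gather + de-interleave stand-in
        if OSFT_CACHE_V:
            module._osft_v_high_full = V_full  # V_high is frozen, so caching is exact
    coeff = local_dV @ V_full.T                # (n, k_high) coefficients
    return local_dV - coeff @ V_full           # factored form: no (M, M) Gram matrix
```

On a cache hit the collective is skipped entirely, which is why the cache is only safe because V_high never changes during training.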
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🧹 Nitpick comments (1)
benchmarks/bench_v_proj.py (1)
22-35: Fail fast on non-divisible benchmark shapes.
`k_high // P` and `k_low // P` silently discard remainder rows, and the slice at Line 34 only covers `P * local_k_high` rows. If this helper gets reused with a non-divisible shape, it will publish timings for the wrong problem size instead of stopping early.

Possible fix
```diff
 def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
     """Benchmark one OSFT target shape. Returns dict of timings."""
+    if k_high % P != 0 or k_low % P != 0:
+        raise ValueError(
+            f"k_high={k_high} and k_low={k_low} must both be divisible by P={P}"
+        )
     local_k_high = k_high // P
     local_k_low = k_low // P
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 22 - 35, The function bench_target silently truncates k_high and k_low by integer division into local_k_high/local_k_low and then slices V_full by rank * local_k_high; add explicit validation at the start of bench_target to check that k_high % P == 0 and k_low % P == 0 (or otherwise that both divide evenly by P), and raise a clear ValueError if not so that the benchmark fails fast; update the error message to reference the offending values (k_high, k_low, P) so callers can correct input shapes and avoid publishing wrong timings caused by the partial-row truncation in the local_V slice.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 605-611: The current all_gather_into_tensor assumes equal per-rank
rows for local_V_high; pad local_V_high to target_rows =
math.ceil(svd_dict["rank_high"] / world_size) before calling
dist.all_gather_into_tensor so each rank provides the same shape (use
torch.nn.functional.pad or torch.zeros on the same device/dtype and
concatenate), perform the all_gather_into_tensor into V_high_full sized
(target_rows * world_size, cols), then slice V_high_full[:svd_dict["rank_high"],
:] to restore the original total rank and assign back into the downstream
variable; ensure import math is present at file top if missing.
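The pad-gather-slice scheme described above can be checked without `torch.distributed`. The sketch below simulates a fixed-shape all-gather with a list of per-rank shards; `gather_uneven`, `target_rows`, and the ceil-based sharding are assumptions following the comment, and the real code would call `dist.all_gather_into_tensor`.

```python
import math
import numpy as np

def gather_uneven(shards, rank_high):
    """Simulate all_gather_into_tensor over shards with unequal row counts.

    Under ceil-based sharding only the last rank can be short, so each rank
    pads its shard to target_rows, the padded copies are concatenated in
    rank order, and the result is sliced back to rank_high rows.
    """
    world_size = len(shards)
    cols = shards[0].shape[1]
    target_rows = math.ceil(rank_high / world_size)
    padded = [
        np.concatenate([s, np.zeros((target_rows - s.shape[0], cols), s.dtype)])
        for s in shards
    ]
    gathered = np.concatenate(padded)  # (target_rows * world_size, cols)
    return gathered[:rank_high]        # drop trailing padding
```

Note the final slice is only correct because ceil sharding puts all padding at the tail of the gathered buffer.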
In `@tests/test_osft.py`:
- Around line 1626-1640: In test_cache_disabled_by_default, pin the module-level
OSFT_CACHE_V constant to False so the test is hermetic: before calling
self._create_simple_osft_model(), use the pytest monkeypatch fixture to
monkeypatch.setattr(<module_that_defines_OSFT_CACHE_V>, "OSFT_CACHE_V", False)
(or equivalent setattr) so the constant is forced off for the duration of the
test; rely on monkeypatch to automatically restore the original value after the
test and keep the rest of the test logic (model.train(), project_gradients(),
and the assertions) unchanged.
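A minimal, pytest-free illustration of the pinning suggested above: `unittest.mock.patch.object` gives the same save/restore semantics as the `monkeypatch` fixture. The `osft` namespace here is a stand-in for whichever module actually defines `OSFT_CACHE_V`.

```python
import types
from unittest import mock

# Stand-in for the real module that defines the OSFT_CACHE_V constant.
osft = types.SimpleNamespace(OSFT_CACHE_V=True)

def run_with_cache_pinned_off():
    # Force the flag off for the duration of the block; the original
    # value is restored automatically when the context exits.
    with mock.patch.object(osft, "OSFT_CACHE_V", False):
        return osft.OSFT_CACHE_V
```

Inside a pytest test the equivalent is `monkeypatch.setattr(osft, "OSFT_CACHE_V", False)`, with restoration handled by the fixture's teardown.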
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 20a6d6ef-db54-486a-a9a4-6bc05401b8a6
📒 Files selected for processing (3)
benchmarks/bench_v_proj.py
src/mini_trainer/osft_utils.py
tests/test_osft.py
Replace dV -= dV @ (V_high^T @ V_high) with the factored form dV -= (dV @ V_high^T) @ V_high. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather — M/k_high fewer bytes (2x for square weights, 7x for down_proj) and all-gather is cheaper per byte. Add opt-in caching of the all-gathered V_high (OSFT_CACHE_V=1, default off). V_high is frozen, so the cache is exact. Includes bench_v_proj.py benchmark and 10 new tests.
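The algebraic identity behind the change is easy to verify numerically. The shapes below are illustrative only (`V_high` is `(k_high, M)`, `dV` is `(n, M)`):

```python
import numpy as np

rng = np.random.default_rng(0)
k_high, M, n = 8, 32, 16
V_high = rng.standard_normal((k_high, M))
dV = rng.standard_normal((n, M))

# Gram form: materializes an (M, M) matrix.
gram = dV - dV @ (V_high.T @ V_high)

# Factored form: only (n, k_high) and (k_high, M) intermediates.
factored = dV - (dV @ V_high.T) @ V_high

assert np.allclose(gram, factored)  # associativity makes the two forms identical
```

The equivalence is just associativity of matrix multiplication; the win is that the factored order never forms the (M, M) Gram matrix, which is what had to be all-reduced under FSDP2.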
0e9b3fe to 6ab859c (Compare)
Codecov Report: ❌ Patch coverage is
🧹 Nitpick comments (2)
src/mini_trainer/osft_utils.py (1)
652-656: Optional: de-interleave can be simplified to a single slice.

Since `all_gather_into_tensor` concatenates data in rank order and padding only appears at the end of the last rank's contribution, the loop can be replaced with a single slice: `V_high_full = gathered[:full_k_high]`. The current approach is correct, but the simpler form is equally valid and avoids the intermediate list allocation.
Simplified de-interleave
```diff
-    parts = []
-    for i in range(world_size):
-        n = min(rows_per_rank, full_k_high - rows_per_rank * i)
-        parts.append(gathered[i * rows_per_rank : i * rows_per_rank + n])
-    V_high_full = torch.cat(parts)
+    V_high_full = gathered[:full_k_high]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/osft_utils.py` around lines 652 - 656, The de-interleaving loop that builds parts and then concatenates into V_high_full (variables: parts, world_size, rows_per_rank, full_k_high, gathered, V_high_full) can be simplified: replace the loop and torch.cat with a single slice of the gathered tensor up to full_k_high (i.e., use gathered[:full_k_high]) to avoid the intermediate list allocation while preserving correctness.

benchmarks/bench_v_proj.py (1)
110-116: Consider passing `rank` as a parameter instead of using global state.

The global `rank` variable works but reduces clarity. Passing it explicitly would make the data flow clearer:

```diff
-def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
+def bench_target(name, k_high, k_low, M, P, rank, dev, n_iters=100):
```

This is minor; the current approach is acceptable for a benchmark script.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 110 - 116, The current code uses a module-level global `rank` and the function `run(rank_, world_size)` which assigns to that global; replace this by passing `rank` explicitly: remove the module-level `rank` global, change `run` to use a local parameter (keep the signature `run(rank, world_size)` or rename consistently), and update any downstream functions or places that reference the global `rank` to accept a `rank` parameter and use the local one instead (search for references to the global `rank` symbol and the `run` function to update callers).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 036317a0-8b97-4a7e-b179-014ab9f3470b
📒 Files selected for processing (3)
benchmarks/bench_v_proj.py
src/mini_trainer/osft_utils.py
tests/test_osft.py
🧹 Nitpick comments (3)
benchmarks/bench_v_proj.py (3)
22-26: Benchmark restricts to even sharding and doesn't exercise the uneven shard path.

The divisibility check ensures the benchmark only tests even sharding. While this avoids complexity, it means the uneven shard de-interleave logic in `osft_utils.py` (which has a bug; see other comment) isn't exercised here. Consider adding a note in the docstring or a TODO for future coverage:
```diff
+    # NOTE: This benchmark requires even sharding (k_high % P == 0).
+    # Uneven sharding paths in osft_utils.py are not covered here.
     if k_high % P != 0 or k_low % P != 0:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 22 - 26, The bench_target function currently raises if k_high or k_low aren't divisible by P, which prevents testing uneven sharding and the related de-interleave code; update the bench_target docstring (function bench_target) to note this limitation and add a TODO comment suggesting adding a non-divisible-shard test path to exercise uneven shard logic (reference variables k_high, k_low, P and local_k_high) so future work will remove the divisibility guard or add separate test coverage for the uneven-shard code path.
149-159: Aggregate calculation is Llama-8B specific; consider noting this.

The calculation assumes 32 layers with 6 square targets + 1 down_proj per layer. This is accurate for Llama-8B but could mislead if someone uses this benchmark for other architectures. The inline comment at line 149 helps, but consider making the model name more prominent in the output:
```diff
-    print("  Aggregate (32 layers x 7 targets = 224):")
+    print("  Aggregate for Llama-8B (32 layers x 7 targets = 224):")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 149 - 159, The aggregate calculation and printed summary in bench_v_proj.py currently hardcodes "32 layers x 7 targets" (variables sq, dp and computed gram_tot, fact_tot, cached_tot) which is Llama-8B specific; update the output to explicitly state the model/assumption and/or compute the aggregate from configurable values instead of fixed literals: surface the model name or a derived string (e.g., model_name or computed n_layers and targets_per_layer) in the print header, or compute layer_count and targets_per_layer from inputs used to derive sq and dp and use those when printing so the message accurately reflects the configuration used by gram_tot, fact_tot, and cached_tot.
110-115: Consider making port configurable to avoid conflicts.

Hardcoded port `29500` could conflict with other distributed jobs or parallel benchmark runs.

```diff
-def run(rank, world_size):
-    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500", ...)
+def run(rank, world_size, port=29500):
+    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port), ...)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 110 - 115, The hardcoded MASTER_PORT value in run() risks conflicts; make the port configurable by accepting a port argument (e.g., port=None) or reading a dedicated env var (e.g., DIST_MASTER_PORT) and defaulting to "29500" if unset, then use that value when setting os.environ["MASTER_PORT"] before calling dist.init_process_group("nccl"); update the run signature and any callers to pass the desired port and document the new parameter so parallel runs can specify different ports.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c3384706-c27c-4475-a91a-1b41bc9bc3e7
📒 Files selected for processing (2)
benchmarks/bench_v_proj.py
src/mini_trainer/osft_utils.py
Fixes #73
Replace dV -= dV @ (V_high^T @ V_high) with the factored form dV -= (dV @ V_high^T) @ V_high. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather — M/k_high fewer bytes (2x for square weights, 7x for down_proj) and all-gather is cheaper per byte.
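The M/k_high byte ratio can be sanity-checked with back-of-envelope arithmetic. The shapes below are hypothetical stand-ins chosen to reproduce the quoted 2x and 7x figures, not measurements from the PR:

```python
def comm_ratio(M, k_high):
    """Payload ratio: (M, M) all-reduce bytes over (k_high, M) all-gather bytes."""
    return (M * M) / (k_high * M)  # simplifies to M / k_high

# Square weight with half the spectrum kept high: k_high = M / 2.
print(comm_ratio(4096, 2048))   # hypothetical shapes -> 2.0

# Wider matrix where k_high is a smaller fraction of M.
print(comm_ratio(14336, 2048))  # hypothetical down_proj-like shape -> 7.0
```

On top of the smaller payload, all-gather avoids the reduction work an all-reduce performs, which is the "cheaper per byte" half of the claim.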
Add opt-in caching of the all-gathered V_high (OSFT_CACHE_V=1, default off). V_high is frozen, so the cache is exact.
Includes bench_v_proj.py benchmark and 10 new tests.