
[tx] Implement stacked weights #1018

Open
pcmoritz wants to merge 28 commits into NovaSky-AI:main from pcmoritz:tx-stacked-layers

Conversation


@pcmoritz pcmoritz commented Feb 4, 2026

This is based on all the great work that @raulchen did in #996 and #906; it also fixes the performance regression in decoding relative to the main branch.



@pcmoritz pcmoritz added the tx label Feb 4, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces StackedDecoderLayers to optimize the transformer forward pass using nnx.vmap and jax.lax.scan, which is a significant performance improvement for training and prefill. The changes are well-encapsulated, making the model code cleaner and more efficient. However, I've identified a critical bug in the load_safetensors utility related to how parameter paths are handled, which would prevent it from loading weights correctly for certain layer types.
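For context, the core pattern here (per the Flax NNX transforms guide) looks roughly like the minimal sketch below; Layer, create_stack, and forward are illustrative stand-ins, not this PR's actual code.

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Layer(nnx.Module):
        # Stand-in for a decoder layer; the real layers are far more complex.
        def __init__(self, dim: int, rngs: nnx.Rngs):
            self.linear = nnx.Linear(dim, dim, rngs=rngs)

        def __call__(self, x):
            return x + jax.nn.gelu(self.linear(x))

    num_layers, dim = 4, 16

    # nnx.vmap over layer creation gives every parameter a leading
    # (num_layers,) axis; nnx.split_rngs gives each layer its own seed.
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_stack(rngs: nnx.Rngs):
        return Layer(dim, rngs)

    stack = create_stack(nnx.Rngs(0))

    # nnx.scan (built on jax.lax.scan) threads the activations through the
    # stacked layers as the carry, replacing a Python-level layer loop.
    @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
    def forward(layer: Layer, x):
        return layer(x)

    y = forward(stack, jnp.ones((2, dim)))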

Comment on lines 118 to 133
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
        continue
    if "experts" in path:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))

critical

The checks for parameter types like "lora_A" in path or "experts" in path are incorrect. The path variable is a tuple of nnx.path.PathEntry objects, not strings, so these checks will always evaluate to False. This will prevent weights for LoRA, experts, and projections from being loaded correctly.

To fix this, you should convert the path to a tuple of strings before performing these checks. This will ensure that the logic correctly identifies the parameter types and applies the appropriate loading logic.

Suggested change
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    path_str_tuple = tuple(map(str, path))
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and any(p in path_str_tuple for p in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")):
        continue
    if "experts" in path_str_tuple:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path_str_tuple) >= 2 and path_str_tuple[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))

        return self.get_metadata("_parent")[self.get_metadata("_idx")].shape


class StackedDecoderLayers(nnx.Module):
Collaborator Author


Probably the easiest way to implement DeepSeekV3 is to implement DualStackedDecoderLayers, which has two StackedDecoderLayers as members and the same interface as StackedDecoderLayers (modulo the constructor, which can take two create_layer_fn functions and their respective layer counts as arguments). This could be a separate PR.
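A rough sketch of that idea, assuming StackedDecoderLayers (the class defined in this PR) takes a create_layer_fn, a layer count, and rngs, and exposes a plain __call__(x); the exact signatures in the repo may differ.

    from flax import nnx

    class DualStackedDecoderLayers(nnx.Module):
        # Hypothetical composition of two stacks, e.g. DeepSeekV3's dense
        # layers followed by its MoE layers. Constructor arguments are
        # assumptions, not the repo's actual signature.
        def __init__(self, create_layer_fn_a, num_layers_a,
                     create_layer_fn_b, num_layers_b, rngs: nnx.Rngs):
            self.first = StackedDecoderLayers(create_layer_fn_a, num_layers_a, rngs=rngs)
            self.second = StackedDecoderLayers(create_layer_fn_b, num_layers_b, rngs=rngs)

        def __call__(self, x):
            # Same interface as a single StackedDecoderLayers: run the
            # first stack, then feed its output through the second.
            return self.second(self.first(x))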


pcmoritz commented Feb 5, 2026

/gemini review

@pcmoritz pcmoritz changed the title [WIP] [tx] Implement stacked weights [tx] Implement stacked weights Feb 5, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant and valuable optimization by implementing stacked weights for decoder layers, which should improve training and prefill performance by leveraging jax.lax.scan. The use of StackedDecoderLayers and the ArrayRef helper class to maintain compatibility with standard checkpoint formats is a clever approach.

The changes are well-integrated across the model, layers, and utility functions. However, I've found a critical issue in the save_safetensors function: it was not fully updated to handle the new ArrayRef variable types returned by unstack_state, which will cause model saving to fail. I've also included a medium-severity suggestion to improve the robustness of extract_adapter_state.

Addressing these points will help ensure the new implementation is both performant and correct.


# unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths
# (layers.0.xxx) matching the checkpoint key format used by HuggingFace
model_params = nnx.to_flat_state(unstack_state(model))

critical

While using unstack_state is correct for handling stacked models, it introduces a bug in the subsequent loop (lines 146-160). The param variable in the loop is now an ArrayRef (a type of nnx.Variable), not a raw JAX array. Operations like param.reshape, param.T, or slicing like param[i,:,:] will fail on this object.

Additionally, checks like if "q_proj" in path and if "embed_tokens" in path are incorrect: path is a tuple of strings here, so these membership checks are too broad.

The loop body needs to be updated to correctly handle nnx.Variable objects by first extracting their array value (e.g., value = param[...]) and then performing operations on that value. The checks should also be made more specific, for example by checking path[-2] for layer types or key for embed_tokens.

Here is a suggested correction for the loop body:

    tensors = {}
    for path, param in model_params:
        if "rngs" in path:
            continue
        if filter_fn is not None and not filter_fn(path):
            continue
        key = get_param_key(path, prefix=prefix)
        value = param[...]
        if "experts" in path:
            for i in range(config.get_num_experts()):
                tensors[get_expert_key(path, i)] = value[i, :, :].T
            continue
        if path[-2] in {"q_proj", "k_proj", "v_proj"}:
            value = value.reshape(value.shape[0], -1)
        elif path[-2] == "o_proj":
            value = value.reshape(-1, value.shape[-1])
        tensors[key] = value if "embed_tokens" in key else value.T
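(Note that value = param[...] relies on nnx.Variable delegating indexing to its underlying array; if that delegation is ever in doubt, param.value is the explicit way to pull out the raw array before transposing or slicing.)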

Comment on lines 265 to 268
idx = get_adapter_idx(path, adapter_index)
if key == "lora_A":
    return p[idx][..., :, :rank]
return p[idx][..., :rank, :]

medium

The original assertion on parameter dimensions was correctly removed as it didn't account for stacked weights. However, it would be good to re-introduce a corrected assertion to maintain robustness. You can assert on the shape of the sliced parameter, which should have a consistent number of dimensions (3 for linear, 4 for expert LoRA) regardless of whether the original parameter was stacked.

This helps catch potential issues with parameter shapes early.

Suggested change
idx = get_adapter_idx(path, adapter_index)
sliced_p = p[idx]
assert sliced_p.ndim in {3, 4}, f"LoRA parameters sliced by adapter should have 3 or 4 dimensions, got shape {sliced_p.shape}"
if key == "lora_A":
    return sliced_p[..., :, :rank]
return sliced_p[..., :rank, :]


@devin-ai-integration devin-ai-integration bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no bugs or issues to report.



@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

View 8 additional findings in Devin Review.


@pcmoritz pcmoritz closed this Feb 11, 2026
@pcmoritz pcmoritz reopened this Feb 11, 2026