Conversation
Code Review
This pull request introduces StackedDecoderLayers to optimize the transformer forward pass using nnx.vmap and jax.lax.scan, which is a significant performance improvement for training and prefill. The changes are well-encapsulated, making the model code cleaner and more efficient. However, I've identified a critical bug in the load_safetensors utility related to how parameter paths are handled, which would prevent it from loading weights correctly for certain layer types.
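For context, the general pattern behind this optimization is Flax NNX's "scan over layers": create the layers under nnx.vmap so their parameters are stacked along a leading axis, then run them with nnx.scan (which wraps jax.lax.scan) so the layer body is traced and compiled only once. The sketch below shows only that general pattern; Block and StackedBlocks are illustrative names and this is not the code from this PR.

import jax
import jax.numpy as jnp
from flax import nnx


class Block(nnx.Module):
    # Illustrative decoder-style block, not the PR's decoder layer.
    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        return jax.nn.relu(self.linear(x))


class StackedBlocks(nnx.Module):
    # Holds num_layers Blocks whose parameters are stacked on a leading axis.
    def __init__(self, dim: int, num_layers: int, *, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
            return Block(dim, rngs=rngs)

        self.blocks = create_block(rngs)

    def __call__(self, x):
        # Scan over the stacked layer axis instead of a Python loop, so the
        # compiled program contains a single copy of the layer body.
        @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
        def forward(block: Block, x):
            return block(x)

        return forward(self.blocks, x)


model = StackedBlocks(16, num_layers=4, rngs=nnx.Rngs(0))
print(model(jnp.ones((2, 16))).shape)  # (2, 16)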
skyrl-tx/tx/utils/models.py
Outdated
updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and ("lora_A" in path or "lora_B" in path or "lora_scaling" in path or "lora_ranks" in path):
        continue
    if "experts" in path:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path) >= 2 and path[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))
The checks for parameter types like "lora_A" in path or "experts" in path are incorrect. The path variable is a tuple of nnx.path.PathEntry objects, not strings, so these checks will always evaluate to False. This will prevent weights for LoRA, experts, and projections from being loaded correctly.
To fix this, you should convert the path to a tuple of strings before performing these checks. This will ensure that the logic correctly identifies the parameter types and applies the appropriate loading logic.
Suggested change:

updates = []
for path, param in nnx.to_flat_state(nnx.state(module)):
    path_str_tuple = tuple(map(str, path))
    if filter_fn is not None and not filter_fn(path):
        continue
    key = key_prefix + get_param_key(path)
    if skip_lora and any(p in path_str_tuple for p in ("lora_A", "lora_B", "lora_scaling", "lora_ranks")):
        continue
    if "experts" in path_str_tuple:
        tensor = np.stack([tensors[key_prefix + get_expert_key(path, i)].T for i in range(config.get_num_experts())], axis=0)
    else:
        tensor = tensors[key] if "embed_tokens" in key else tensors[key].T
    if len(path_str_tuple) >= 2 and path_str_tuple[-2] in {"q_proj", "k_proj", "v_proj", "o_proj"}:
        tensor = tensor.reshape(param.shape)
    assert param.shape == tensor.shape, f"shape mismatch for {key}"
    updates.append((path, jax.device_put(tensor.astype(param.dtype), param.sharding)))
nnx.update(module, nnx.from_flat_state(updates))
        return self.get_metadata("_parent")[self.get_metadata("_idx")].shape


class StackedDecoderLayers(nnx.Module):
Probably the easiest way to implement DeepSeekV3 is to add a DualStackedDecoderLayers class that has two StackedDecoderLayers as members and exposes the same interface as StackedDecoderLayers (modulo the constructor, which can take two create_layer_fn functions and their respective layer counts as arguments). This could be a separate PR.
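A rough sketch of what that could look like (the constructor and call signatures below are assumptions for illustration, not the actual StackedDecoderLayers API from this PR):

from flax import nnx


class DualStackedDecoderLayers(nnx.Module):
    # Hypothetical: chains two StackedDecoderLayers so that, e.g., the dense
    # and MoE decoder layers of DeepSeekV3 each get a homogeneous stack.
    def __init__(self, create_layer_fn_a, num_layers_a, create_layer_fn_b, num_layers_b, *, rngs: nnx.Rngs):
        # Assumes a StackedDecoderLayers(create_layer_fn, num_layers, rngs=...) constructor.
        self.stack_a = StackedDecoderLayers(create_layer_fn_a, num_layers_a, rngs=rngs)
        self.stack_b = StackedDecoderLayers(create_layer_fn_b, num_layers_b, rngs=rngs)

    def __call__(self, hidden_states, *args, **kwargs):
        # Same interface as StackedDecoderLayers: run the first stack, then
        # feed its output into the second.
        hidden_states = self.stack_a(hidden_states, *args, **kwargs)
        return self.stack_b(hidden_states, *args, **kwargs)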
/gemini review |
Code Review
This pull request introduces a significant and valuable optimization by implementing stacked weights for decoder layers, which should improve training and prefill performance by leveraging jax.lax.scan. The use of StackedDecoderLayers and the ArrayRef helper class to maintain compatibility with standard checkpoint formats is a clever approach.
The changes are well-integrated across the model, layers, and utility functions. However, I've found a critical issue in the save_safetensors function: it was not fully updated to handle the new ArrayRef variable types returned by unstack_state, which will cause model saving to fail. I've also included a medium-severity suggestion to improve the robustness of extract_adapter_state.
Addressing these points will help ensure the new implementation is both performant and correct.
# unstack_state converts stacked paths (layers._stacked.xxx) to per-layer paths
# (layers.0.xxx) matching the checkpoint key format used by HuggingFace
model_params = nnx.to_flat_state(unstack_state(model))
While using unstack_state is correct for handling stacked models, it introduces a bug in the subsequent loop (lines 146-160). The param variable in the loop is now an ArrayRef (a type of nnx.Variable), not a raw JAX array. Operations like param.reshape, param.T, or slicing like param[i,:,:] will fail on this object.
Additionally, checks like if "q_proj" in path and if "embed_tokens" in path are incorrect: path is a tuple of strings, so these checks are too broad.
The loop body needs to be updated to correctly handle nnx.Variable objects by first extracting their array value (e.g., value = param[...]) and then performing operations on that value. The checks should also be made more specific, for example by checking path[-2] for layer types or key for embed_tokens.
Here is a suggested correction for the loop body:
tensors = {}
for path, param in model_params:
    if "rngs" in path:
        continue
    if filter_fn is not None and not filter_fn(path):
        continue
    key = get_param_key(path, prefix=prefix)
    value = param[...]
    if "experts" in path:
        for i in range(config.get_num_experts()):
            tensors[get_expert_key(path, i)] = value[i, :, :].T
        continue
    if path[-2] in {"q_proj", "k_proj", "v_proj"}:
        value = value.reshape(value.shape[0], -1)
    elif path[-2] == "o_proj":
        value = value.reshape(-1, value.shape[-1])
    tensors[key] = value if "embed_tokens" in key else value.T
skyrl-tx/tx/utils/models.py
Outdated
idx = get_adapter_idx(path, adapter_index)
if key == "lora_A":
    return p[idx][..., :, :rank]
return p[idx][..., :rank, :]
The original assertion on parameter dimensions was correctly removed as it didn't account for stacked weights. However, it would be good to re-introduce a corrected assertion to maintain robustness. You can assert on the shape of the sliced parameter, which should have a consistent number of dimensions (3 for linear, 4 for expert LoRA) regardless of whether the original parameter was stacked.
This helps catch potential issues with parameter shapes early.
Suggested change:

idx = get_adapter_idx(path, adapter_index)
sliced_p = p[idx]
assert sliced_p.ndim in {3, 4}, f"LoRA parameters sliced by adapter should have 3 or 4 dimensions, got shape {sliced_p.shape}"
if key == "lora_A":
    return sliced_p[..., :, :rank]
return sliced_p[..., :rank, :]
This is based on all the great work that @raulchen did in #996 and #906; it also fixes the performance regression in decoding compared to the main branch.