diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py
index e91c7d1a..8304c3f9 100644
--- a/models/hrm/hrm_act_v1.py
+++ b/models/hrm/hrm_act_v1.py
@@ -110,8 +110,8 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
         embed_init_std = 1.0 / self.embed_scale
 
         self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
-        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
-        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
 
         self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
         if self.config.puzzle_emb_ndim > 0:
@@ -133,9 +133,14 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
         self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
         self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
 
+        # --- CORRECTED CODE BLOCK ---
         # Initial states
-        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
-        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+        h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('H_init', h_init_tensor)
+
+        l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('L_init', l_init_tensor)
+        # --- END OF CORRECTION ---
 
         # Q head special init
         # Init Q to (almost) zero for faster learning during bootstrapping
diff --git a/models/layers.py b/models/layers.py
index 008a172a..62e4b599 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -88,8 +88,11 @@ def __init__(self, dim, max_position_embeddings, base, device=None):
 
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
-        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+        # --- CORRECTED CODE BLOCK ---
+        self.register_buffer('cos_cached', emb.cos(), persistent=False)
+        self.register_buffer('sin_cached', emb.sin(), persistent=False)
+        # --- END OF CORRECTION ---
 
     def forward(self):
         return self.cos_cached, self.sin_cached
@@ -142,7 +145,7 @@ def __init__(self, hidden_size: int, expansion: float):
         inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
 
         self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
-        self.down_proj = CastedLinear(inter, hidden_size, bias=False)
+        self.down_proj = CastedLinear(inter, hidden_size, bias=False)
 
     def forward(self, x):
         gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py
index c701524b..ca64c868 100644
--- a/models/sparse_embedding.py
+++ b/models/sparse_embedding.py
@@ -13,17 +13,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, ini
         super().__init__()
         self.cast_to = cast_to
 
+        # --- CORRECTED CODE BLOCK ---
         # Real Weights
-        # Truncated LeCun normal init
-        self.weights = nn.Buffer(
-            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
-        )
+        weights_tensor = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+        self.register_buffer('weights', weights_tensor)
 
         # Local weights and IDs
-        # Local embeddings, with gradient, not persistent
-        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
-        # Local embedding IDs, not persistent
-        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+        local_weights_tensor = torch.zeros(batch_size, embedding_dim, requires_grad=True)
+        self.register_buffer('local_weights', local_weights_tensor, persistent=False)
+
+        local_ids_tensor = torch.zeros(batch_size, dtype=torch.int32)
+        self.register_buffer('local_ids', local_ids_tensor, persistent=False)
+        # --- END OF CORRECTION ---
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         if not self.training:
@@ -81,7 +82,7 @@ def step(self, closure=None): # type: ignore
         assert local_weights_grad is not None
         assert local_ids is not None
         assert weights is not None
-        
+
         # Apply SignSGD
         # Adam ≈ SignSGD if gradient is very sparse
         _sparse_emb_signsgd_dist(
@@ -112,10 +113,10 @@ def _sparse_emb_signsgd_dist(
 
     if world_size > 1:
         all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
-        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
+        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
 
         dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
-        dist.all_gather_into_tensor(all_ids, local_ids)
+        dist.all_gather_into_tensor(all_ids, local_ids)
 
     # Unique
     grad_ids, inv = all_ids.unique(return_inverse=True)
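
Note on the pattern applied throughout this patch: each corrected hunk replaces the nn.Buffer(...) wrapper, which only exists in newer PyTorch releases, with the long-standing Module.register_buffer(...) API. Either form attaches a non-parameter tensor to the module so it follows .to()/.cuda() and, when persistent, is included in state_dict(). A minimal standalone sketch of the substitution follows; the Scale module is hypothetical and is only there to illustrate the two styles, it is not part of the patched files.

    import torch
    import torch.nn as nn

    class Scale(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            # Style used by this patch: register a non-parameter tensor under the
            # name 'weight'; works on all recent PyTorch versions.
            self.register_buffer('weight', torch.ones(dim), persistent=False)
            # Equivalent newer style (requires a recent PyTorch):
            #   self.weight = nn.Buffer(torch.ones(dim), persistent=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # The buffer reads as a normal attribute and moves with the module.
            return x * self.weight

In both styles the tensor is read the same way (self.weight), so call sites such as forward() are unaffected; the hunks above therefore only need to touch the lines where the buffers are created.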