diff --git a/atom/config.py b/atom/config.py
index 4d23ffa66..e6fc7a423 100644
--- a/atom/config.py
+++ b/atom/config.py
@@ -359,6 +359,38 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
 }
 
 
+class EngramConfig(PretrainedConfig):
+    model_type = "engram_lm"
+
+    def __init__(
+        self,
+        vocab_size: int = 128,
+        hidden_size: int = 128,
+        num_hidden_layers: int = 4,
+        num_attention_heads: int = 4,
+        max_position_embeddings: int = 512,
+        tie_word_embeddings: bool = True,
+        engram_config: dict | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.max_position_embeddings = max_position_embeddings
+        self.tie_word_embeddings = tie_word_embeddings
+        self.engram_config = engram_config or {}
+
+        model_config = kwargs.get("model_config", {})
+        if model_config:
+            self.vocab_size = model_config.get("vocab_size", self.vocab_size)
+            self.hidden_size = model_config.get("hidden_size", self.hidden_size)
+            self.num_hidden_layers = model_config.get("num_layers", self.num_hidden_layers)
+            self.num_attention_heads = model_config.get("num_heads", self.num_attention_heads)
+            self.max_position_embeddings = model_config.get("max_seq_len", self.max_position_embeddings)
+
+
 def get_hf_config(model: str) -> PretrainedConfig:
     config_dict, _ = PretrainedConfig.get_config_dict(
         model,
@@ -371,6 +403,10 @@ def _get_hf_token() -> str | None:
                 return token
         return None
 
+    # The engram demo model has no config on the Hugging Face Hub, so load our own config class.
+    if model_type == "engram_lm":
+        return EngramConfig.from_pretrained(model)
+
     if model_type in _CONFIG_REGISTRY:
         config_class = AutoConfig.for_model(_CONFIG_REGISTRY[model_type])
         return config_class.from_pretrained(
diff --git a/atom/model_engine/llm_engine.py b/atom/model_engine/llm_engine.py
index 145470c03..19f351de5 100644
--- a/atom/model_engine/llm_engine.py
+++ b/atom/model_engine/llm_engine.py
@@ -164,7 +164,10 @@ def postprocess(self, reqs: List[Sequence]):
         outputs = {}
         for req in reqs:
             self.requests.pop(req.id)
-            output_str = self.tokenizer.decode(req.completion_token_ids)
+            # Keep only valid (non-negative) token ids so the simple trained engram demo model can run.
+            valid_ids = [t for t in req.completion_token_ids if t >= 0]
+            output_str = self.tokenizer.decode(valid_ids)
             req.leave_time = time.time()
 
             # Calculate TTFT (Time To First Token) and TPOT (Time Per Output Token)
diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py
index b1e3f3045..f190f99a9 100644
--- a/atom/model_engine/model_runner.py
+++ b/atom/model_engine/model_runner.py
@@ -6,6 +6,7 @@
 import time
 from typing import Any, Optional, Union
 
+from atom.model_ops.engram import EngramOp
 import numpy as np
 import torch
 import torch.profiler as torch_profiler
@@ -52,6 +53,7 @@
     "DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
     "DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
     "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
+    "EngramForCausalLM": "atom.models.engram.EngramForCausalLM",
 }
 
 # seed = 34567
 # np.random.seed(seed)
@@ -73,15 +75,51 @@ def __init__(self, max_num_batched_tokens: int, device: torch.device):
         # Event on the copy stream so we can synchronize the non-blocking copy.
         self.async_copy_event = torch.cuda.Event()
         self.async_copy_stream = torch.cuda.Stream()
+        self.engram_ops = []
         self.clean()
+
+    @staticmethod
+    def compute_engram_embeddings(
+        input_ids_np: np.ndarray,
+        engram_ops: list,
+        layer_id_to_engram_op: Optional[dict] = None,
+    ) -> dict[int, np.ndarray]:
+        if not engram_ops:
+            return {}
+
+        if layer_id_to_engram_op is None:
+            layer_id_to_engram_op = {op.layer_id: op for op in engram_ops}
+
+        hash_mapping = engram_ops[0].hash_mapping
+
+        hash_results = hash_mapping.hash(input_ids_np)
+
+        embedding_results = {}
+        for engram_op in engram_ops:
+            layer_id = engram_op.layer_id
+            hash_ids = hash_results[layer_id]
+            embeddings = engram_op.multi_head_embedding.forward_on_cpu(hash_ids)
+            # Remove batch dimension if batch size is 1: [1, T, num_heads, D] -> [T, num_heads, D]
+            if embeddings.ndim == 4 and embeddings.shape[0] == 1:
+                embeddings = embeddings[0]
+            embedding_results[layer_id] = embeddings
+
+        return embedding_results
 
-    def send_to_cpu_async(self, gpu_tensor: torch.Tensor):
+    def send_to_cpu_async(self, gpu_tensor: torch.Tensor, batch: Optional[ScheduledBatch] = None):
         default_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self.async_copy_stream):
             self.async_copy_stream.wait_stream(default_stream)
             cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
             self.async_copy_event.record(self.async_copy_stream)
             self.token_ids_cpu.append(cpu_tensor)
+
+        if batch is not None:
+            current_token_ids = gpu_tensor.tolist()
+            req_ids = batch.req_ids
+            current_token_dict = {seq_id: token_id for seq_id, token_id in zip(req_ids, current_token_ids)}
+            # Compute the engram hash only after the token ids have been sent to the CPU.
+            self._prefetch_engram_hash(batch, current_token_dict)
 
     def recv_async_output(self) -> list[int]:
         for _ in self.token_ids_cpu:
@@ -95,6 +133,42 @@ def clean(self):
 
         self.prev_batch: Optional[ScheduledBatch] = None
 
+    def _prefetch_engram_hash(self, batch: ScheduledBatch, token_ids: dict[int, int]):
+        try:
+            engram_ops = self.engram_ops
+            if not engram_ops:
+                return
+
+            import threading
+            hash_mapping = engram_ops[0].hash_mapping
+            next_input_ids_list = []
+            valid_seq_ids = []
+
+            for seq_id in batch.req_ids:
+                if seq_id in token_ids:
+                    next_input_ids_list.append(token_ids[seq_id])
+                    valid_seq_ids.append(seq_id)
+
+            if not next_input_ids_list:
+                return
+
+            next_input_ids = np.array([next_input_ids_list], dtype=np.int32)
+
+            def compute_and_cache_batch_hash_and_embedding():
+                embedding_results = tokenIDProcessor.compute_engram_embeddings(
+                    next_input_ids, engram_ops
+                )
+
+                if engram_ops and embedding_results:
+                    engram_ops[0].multi_head_embedding.save_embedding_results(
+                        embedding_results, valid_seq_ids, hash_mapping.layer_ids
+                    )
+
+            threading.Thread(target=compute_and_cache_batch_hash_and_embedding, daemon=True).start()
+
+        except Exception as e:
+            logger.warning(e)
+
     def prepare_sampled_ids(
         self, batch: ScheduledBatch, sampled_token_ids: torch.Tensor
     ) -> dict[int, int]:
@@ -102,10 +176,11 @@
             token_ids = sampled_token_ids.tolist()
             req_ids = batch.req_ids
             ret = {seq_id: token_id for seq_id, token_id in zip(req_ids, token_ids)}
+            self._prefetch_engram_hash(batch, ret)
             ret[-1] = 0
             return ret
 
         token_ids = self.recv_async_output()
-        self.send_to_cpu_async(sampled_token_ids)
+        self.send_to_cpu_async(sampled_token_ids, batch)
 
         if self.prev_batch is not None:
             req_ids = self.prev_batch.req_ids
@@ -326,6 +401,8 @@ def __init__(self, rank: int, config: Config):
             self.config.max_num_batched_tokens, self.device
         )
         self.sampler = Sampler()
+
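+        # Per-layer CPU/GPU staging buffers for precomputed engram embeddings,
+        # keyed by layer_id; populated in _init_engram_buffers().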
+        self.engram_buffers = {}
         if self.config.speculative_config and get_pp_group().is_last_rank:
             self.drafter = EagleProposer(self.config, self.device, self)
         self.arange_np = np.arange(
@@ -356,9 +433,54 @@ def __init__(self, rank: int, config: Config):
         torch.set_default_device("cpu")
         torch.set_default_dtype(default_dtype)
 
+        self.tokenID_processor.engram_ops = self.get_engram_op()
+
+        self._init_engram_buffers()
         if self.config.compilation_config.level == 1:
             self.model = torch.compile(self.model, fullgraph=True, backend="eager")
+
+    def get_engram_op(self) -> list[EngramOp]:
+        engram_ops = []
+        for module in self.model.modules():
+            if isinstance(module, EngramOp):
+                engram_ops.append(module)
+        return engram_ops
+
+    def _init_engram_buffers(self):
+        engram_ops = self.get_engram_op()
+        if not engram_ops:
+            return
+
+        for engram_op in engram_ops:
+            engram_op.multi_head_embedding.init_cpu_embedding()
+
+        try:
+            hash_mapping = engram_ops[0].hash_mapping
+            max_tokens = self.config.max_num_batched_tokens
+
+            num_hash_heads = (hash_mapping.max_ngram_size - 1) * hash_mapping.n_head_per_ngram
+            embed_dim = hash_mapping.n_embed_per_ngram // hash_mapping.n_head_per_ngram
+
+            for layer_id in hash_mapping.layer_ids:
+                embedding_buffer = CpuGpuBuffer(
+                    max_tokens * num_hash_heads * embed_dim,
+                    dtype=torch.float32,
+                    device=self.device,
+                    with_numpy=True
+                )
+                self.engram_buffers[layer_id] = embedding_buffer
+
+            if not hasattr(self, 'engram_embedding_stream'):
+                self.engram_embedding_stream = torch.cuda.Stream()
+            for engram_op in engram_ops:
+                embedding_buffer = self.engram_buffers[engram_op.layer_id]
+                engram_op.embedding_buffer = embedding_buffer.gpu.reshape(max_tokens, num_hash_heads, embed_dim)
+                engram_op.embedding_stream = self.engram_embedding_stream
+
+        except Exception as e:
+            logger.warning(e)
 
     def is_deepseek_mla(self) -> bool:
         if not hasattr(self.hf_text_config, "model_type"):
@@ -894,6 +1016,106 @@ def prepare_sample(self, batch: ScheduledBatch) -> torch.Tensor:
         buffer = self.forward_vars["temperatures"]
         buffer.np[:bs] = batch.temperatures
         return buffer.copy_to_gpu(bs)
+
+    def prepare_engram_embeddings_to_gpu(self, batch: ScheduledBatch):
+        """Transfer precomputed engram embeddings to the GPU.
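+
+        For each sequence in the batch, embeddings precomputed by the prefetch
+        thread are looked up in the cache; any sequence without a complete
+        cache entry (the new-prefill case) is computed synchronously from the
+        CPU-side input_ids. The concatenated result is staged in the CPU
+        buffer and copied to the GPU on a dedicated stream.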
+ """ + if not self.engram_buffers: + return + + try: + engram_ops = self.get_engram_op() + if not engram_ops: + return + + req_ids = batch.req_ids + hash_mapping = engram_ops[0].hash_mapping + + layer_id_to_engram_op = {op.layer_id: op for op in engram_ops} + + cached_embeddings_per_layer = {layer_id: [] for layer_id in hash_mapping.layer_ids} + seqs_need_compute = [] + + for seq_id in req_ids: + seq_embeddings = {} + + for layer_id in hash_mapping.layer_ids: + engram_op = layer_id_to_engram_op.get(layer_id) + if engram_op is None: + continue + + cached_embedding = engram_op.multi_head_embedding.get_cached_embedding(layer_id, seq_id) + + # print(f"cached_embedding: {cached_embedding.shape}") + if cached_embedding is not None and cached_embedding.size > 0: + seq_embeddings[layer_id] = cached_embedding + + if len(seq_embeddings) == len(hash_mapping.layer_ids): + for layer_id in hash_mapping.layer_ids: + embedding = seq_embeddings[layer_id] + cached_embeddings_per_layer[layer_id].append(embedding) + else: + seqs_need_compute.append(seq_id) + + # It means has no prev batch to precompute, it's new prefill case, compute here + if seqs_need_compute: + # Use cpu input_ids to avoid GPU->CPU copy + total_tokens = batch.total_tokens_num + input_ids_np = self.tokenID_processor.input_ids.np[:total_tokens].reshape(1, -1) + + embedding_results = tokenIDProcessor.compute_engram_embeddings( + input_ids_np, engram_ops, layer_id_to_engram_op + ) + + for layer_id in hash_mapping.layer_ids: + if layer_id not in embedding_results: + continue + + embeddings = embedding_results[layer_id] + cached_embeddings_per_layer[layer_id].append(embeddings) + + first_layer_embeddings = cached_embeddings_per_layer[hash_mapping.layer_ids[0]] + if not first_layer_embeddings or len(first_layer_embeddings) == 0: + return + + for emb in first_layer_embeddings: + if emb is None or emb.size == 0: + return + + default_stream = torch.cuda.current_stream() + total_tokens = sum(e.shape[0] for e in first_layer_embeddings) + + if total_tokens == 0: + return + + with torch.cuda.stream(self.engram_embedding_stream): + self.engram_embedding_stream.wait_stream(default_stream) + + for layer_id in hash_mapping.layer_ids: + if layer_id not in self.engram_buffers: + continue + + layer_embeddings = cached_embeddings_per_layer[layer_id] + if not layer_embeddings or len(layer_embeddings) == 0: + continue + + valid_embeddings = [e for e in layer_embeddings if e is not None and e.size > 0] + if not valid_embeddings: + continue + + embeddings_np = np.concatenate(valid_embeddings, axis=0) + + # Write to CPU buffer and copy to GPU asynchronously + embedding_buffer = self.engram_buffers[layer_id] + num_hash_heads = embeddings_np.shape[1] + embed_dim = embeddings_np.shape[2] + total_elements = total_tokens * num_hash_heads * embed_dim + embedding_buffer.np[:total_elements] = embeddings_np.flatten() + embedding_buffer.copy_to_gpu(total_elements) + + except Exception as e: + logger.warning(e) + def prepare_model(self, batch: ScheduledBatch): total_tokens_num = batch.total_tokens_num @@ -902,7 +1124,8 @@ def prepare_model(self, batch: ScheduledBatch): input_ids = self.tokenID_processor.prepare_input_ids(batch) # if self.rank == 0: # print(f"input_ids: {input_ids}") - + # Prepare engram embeddings using CPU-side input_ids + self.prepare_engram_embeddings_to_gpu(batch) self.prepare_intputs(batch) temperatures = self.prepare_sample(batch) return ( diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 8cbfd8e36..612f445b4 100644 --- 
+++ b/atom/model_loader/loader.py
@@ -87,6 +87,9 @@ def load_model(
                 continue
             if name.endswith("kv_scale"):
                 continue
+            # Skip non-weight buffers; only needed so the simple trained engram demo model can run.
+            if ".causal_mask" in name or ".offsets" in name:
+                continue
             if spec_decode:
                 spec_layer = get_spec_layer_idx_from_weight_name(hf_config, name)
                 if spec_layer is None:
diff --git a/atom/model_ops/engram.py b/atom/model_ops/engram.py
new file mode 100644
index 000000000..58b47f3bb
--- /dev/null
+++ b/atom/model_ops/engram.py
@@ -0,0 +1,552 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+
+
+import math
+import threading
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from sympy import isprime
+
+
+# Global prefetch cache for engram hashes/embeddings:
+# {seq_id: {layer_id: np.ndarray}}
+_global_prefetch_cache: Dict[int, Dict[int, np.ndarray]] = {}
+_global_cache_lock = threading.Lock()
+
+
+def find_next_prime(start, seen_primes):
+    candidate = start + 1
+    while True:
+        if isprime(candidate) and candidate not in seen_primes:
+            return candidate
+        candidate += 1
+
+
+class CompressedTokenizer:
+    def __init__(self, tokenizer_name_or_path: Optional[str] = None):
+        self.tokenizer_name_or_path = tokenizer_name_or_path
+        self.lookup_table = None
+        self.num_new_token = 0
+        if tokenizer_name_or_path is not None:
+            self._build_lookup_table()
+
+    def _build_lookup_table(self):
+        try:
+            from transformers import AutoTokenizer
+            from tokenizers import normalizers
+            from tokenizers.normalizers import Regex
+
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_name_or_path,
+                trust_remote_code=True
+            )
+
+            SENTINEL = "\uE000"
+            self.normalizer = normalizers.Sequence([
+                normalizers.NFKC(),
+                normalizers.NFD(),
+                normalizers.StripAccents(),
+                normalizers.Lowercase(),
+                normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
+                normalizers.Replace(Regex(r"^ $"), SENTINEL),
+                normalizers.Strip(),
+                normalizers.Replace(SENTINEL, " "),
+            ])
+
+            old2new = {}
+            key2new = {}
+            new_tokens = []
+
+            vocab_size = len(self.tokenizer)
+            for tid in range(vocab_size):
+                text = self.tokenizer.decode([tid], skip_special_tokens=False)
+
+                # Handle special tokens
+                if "�" in text:
+                    key = self.tokenizer.convert_ids_to_tokens(tid)
+                else:
+                    norm = self.normalizer.normalize_str(text)
+                    key = norm if norm else text
+
+                nid = key2new.get(key)
+                if nid is None:
+                    nid = len(new_tokens)
+                    key2new[key] = nid
+                    new_tokens.append(key)
+                old2new[tid] = nid
+
+            # Create numpy lookup array
+            lookup = np.empty(vocab_size, dtype=np.int64)
+            for tid in range(vocab_size):
+                lookup[tid] = old2new[tid]
+
+            self.lookup_table = lookup
+            self.num_new_token = len(new_tokens)
+
+        except Exception as e:
+            print(f"Warning: Failed to build compressed tokenizer lookup table: {e}")
+            self.lookup_table = None
+            self.num_new_token = 128000  # Default
+
+    def __len__(self):
+        return self.num_new_token
+
+    def _compress(self, input_ids: np.ndarray) -> np.ndarray:
+        if self.lookup_table is None:
+            return input_ids
+
+        arr = np.asarray(input_ids, dtype=np.int64)
+        pos_mask = arr >= 0
+        out = arr.copy()
+        valid_ids = arr[pos_mask]
+        out[pos_mask] = self.lookup_table[valid_ids]
+        return out
+
+    def __call__(self, input_ids: np.ndarray) -> np.ndarray:
+        return self._compress(input_ids)
+
+
+@dataclass
+class EngramConfig:
+    """Configuration for the Engram module."""
+    engram_vocab_size: List[int] = field(default_factory=lambda: [129280*5, 129280*5])
+    max_ngram_size: int = 3
+    n_embed_per_ngram: int = 512
+    n_head_per_ngram: int = 8
+    layer_ids: List[int] = field(default_factory=lambda: [1, 3])
+    pad_id: int = 0
+    seed: int = 42
+    kernel_size: int = 4
+    tokenizer_name_or_path: Optional[str] = None  # For CompressedTokenizer
+
+    @classmethod
+    def from_dict(cls, d: Dict[str, Any]) -> "EngramConfig":
+        valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
+        filtered_dict = {k: v for k, v in d.items() if k in valid_fields}
+        return cls(**filtered_dict)
+
+
+@dataclass
+class BackBoneConfig:
+    hidden_size: int = 1024
+    hc_mult: int = 4
+    vocab_size: int = 129280
+    num_layers: int = 30
+
+
+backbone_config = BackBoneConfig()
+
+
+class NgramHashMapping:
+    def __init__(
+        self,
+        engram_vocab_size: List[int],
+        max_ngram_size: int,
+        n_embed_per_ngram: int,
+        n_head_per_ngram: int,
+        layer_ids: List[int],
+        tokenizer_name_or_path: Optional[str],
+        pad_id: int,
+        seed: int,
+    ):
+        self.vocab_size_per_ngram = engram_vocab_size
+        self.max_ngram_size = max_ngram_size
+        self.n_embed_per_ngram = n_embed_per_ngram
+        self.n_head_per_ngram = n_head_per_ngram
+        self.pad_id = pad_id
+        self.layer_ids = layer_ids
+
+        self.compressed_tokenizer = CompressedTokenizer(
+            tokenizer_name_or_path=tokenizer_name_or_path
+        )
+        self.tokenizer_vocab_size = len(self.compressed_tokenizer)
+
+        max_long = np.iinfo(np.int64).max
+        M_max = int(max_long // max(self.tokenizer_vocab_size, 1))
+        half_bound = max(1, M_max // 2)
+        PRIME_1 = 10007
+
+        # One odd multiplier per n-gram offset, derived from a per-layer seed.
+        self.layer_multipliers = {}
+        for layer_id in self.layer_ids:
+            base_seed = int(seed + PRIME_1 * int(layer_id))
+            g = np.random.default_rng(base_seed)
+            r = g.integers(
+                low=0,
+                high=half_bound,
+                size=(self.max_ngram_size,),
+                dtype=np.int64
+            )
+            multipliers = r * 2 + 1
+            self.layer_multipliers[layer_id] = multipliers
+
+        self.vocab_size_across_layers = self.calculate_vocab_size_across_layers()
+
+    def calculate_vocab_size_across_layers(self) -> Dict[int, List[List[int]]]:
+        seen_primes = set()
+        vocab_size_across_layers = {}
+
+        for layer_id in self.layer_ids:
+            all_ngram_vocab_sizes = []
+            for ngram in range(2, self.max_ngram_size + 1):
+                current_ngram_heads_sizes = []
+
+                vocab_size = self.vocab_size_per_ngram[ngram - 2]
+                num_head = self.n_head_per_ngram
+                current_prime_search_start = vocab_size - 1
+
+                for _ in range(num_head):
+                    found_prime = find_next_prime(
+                        current_prime_search_start,
+                        seen_primes
+                    )
+                    seen_primes.add(found_prime)
+                    current_ngram_heads_sizes.append(found_prime)
+                    current_prime_search_start = found_prime
+
+                all_ngram_vocab_sizes.append(current_ngram_heads_sizes)
+            vocab_size_across_layers[layer_id] = all_ngram_vocab_sizes
+
+        return vocab_size_across_layers
+
+    def _get_ngram_hashes(
+        self,
+        input_ids: np.ndarray,
+        layer_id: int,
+    ) -> np.ndarray:
+        x = np.asarray(input_ids, dtype=np.int64)
+        B, T = x.shape
+        multipliers = self.layer_multipliers[layer_id]
+
+        def shift_k(k: int) -> np.ndarray:
+            if k == 0:
+                return x
+            shifted = np.pad(
+                x, ((0, 0), (k, 0)),
+                mode='constant',
+                constant_values=self.pad_id
+            )[:, :T]
+            return shifted
+
+        base_shifts = [shift_k(k) for k in range(self.max_ngram_size)]
+        all_hashes = []
+
+        for n in range(2, self.max_ngram_size + 1):
+            n_gram_index = n - 2
+            tokens = base_shifts[:n]
+
+            mix = tokens[0] * multipliers[0]
+            for k in range(1, n):
+                mix = np.bitwise_xor(mix, tokens[k] * multipliers[k])
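+            # mix now holds, per position, a hash of the n-gram ending there:
+            # the XOR of each shifted token id times its per-offset odd multiplier.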
+
+            head_vocab_sizes = self.vocab_size_across_layers[layer_id][n_gram_index]
+
+            for j in range(self.n_head_per_ngram):
+                mod = int(head_vocab_sizes[j])
+                head_hash = mix % mod
+                all_hashes.append(head_hash.astype(np.int64, copy=False))
+
+        return np.stack(all_hashes, axis=2)
+
+    def hash(self, input_ids: np.ndarray) -> Dict[int, np.ndarray]:
+        input_ids = self.compressed_tokenizer(input_ids)
+        hash_ids_for_all_layers = {}
+        for layer_id in self.layer_ids:
+            hash_ids_for_all_layers[layer_id] = self._get_ngram_hashes(
+                input_ids, layer_id=layer_id
+            )
+        return hash_ids_for_all_layers
+
+    def hash_single_layer(self, input_ids: np.ndarray, layer_id: int) -> np.ndarray:
+        input_ids = self.compressed_tokenizer(input_ids)
+        return self._get_ngram_hashes(input_ids, layer_id)
+
+
+class MultiHeadEmbedding(nn.Module):
+
+    def __init__(self, list_of_N: List[int], D: int):
+        super().__init__()
+        self.num_heads = len(list_of_N)
+        self.D = D
+
+        # Per-head offsets into one flat embedding table.
+        offsets = [0]
+        for n in list_of_N[:-1]:
+            offsets.append(offsets[-1] + n)
+        self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long))
+
+        total_vocab = sum(list_of_N)
+        self.embedding = nn.Embedding(total_vocab, D)
+        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)
+
+    def init_cpu_embedding(self):
+        self.cpu_embedding = self.embedding.weight.detach().cpu().to(torch.float32).numpy()
+        self.cpu_offsets = self.offsets.cpu().numpy()
+
+    def forward_on_cpu(self, hash_ids: np.ndarray) -> np.ndarray:
+        shifted_ids = hash_ids + self.cpu_offsets[None, None, :]
+        # shifted_ids = np.clip(shifted_ids, 0, self.embedding.num_embeddings - 1)
+        embeddings = self.cpu_embedding[shifted_ids]
+        return embeddings.astype(np.float32)
+
+    def forward(self, hash_ids: torch.Tensor) -> torch.Tensor:
+        shifted_ids = hash_ids + self.offsets
+        # TODO: remove the clamp in the real model
+        # shifted_ids = shifted_ids.clamp(0, self.embedding.num_embeddings - 1)
+        return self.embedding(shifted_ids)
+
+    def save_embedding_results(self, embeddings: Dict[int, np.ndarray], seq_ids: List[int], layer_ids: List[int]):
+        global _global_prefetch_cache, _global_cache_lock
+        with _global_cache_lock:
+            for i, seq_id in enumerate(seq_ids):
+                if seq_id not in _global_prefetch_cache:
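+                    # First result for this sequence: create its cache slot.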
+                    _global_prefetch_cache[seq_id] = {}
+                for layer_id in layer_ids:
+                    # embeddings[layer_id] shape: [B, T, num_heads, D]
+                    # e.g. (B, T, 16, 64): 16 = 8 + 8 heads for the 2-gram and 3-gram
+                    # Extract the i-th sequence: [1, T, num_heads, D]
+                    seq_embedding = embeddings[layer_id][i:i+1]
+                    _global_prefetch_cache[seq_id][layer_id] = seq_embedding
+
+    def get_cached_embedding(self, layer_id: int, seq_id: int) -> Optional[np.ndarray]:
+        global _global_prefetch_cache, _global_cache_lock
+        with _global_cache_lock:
+            if seq_id in _global_prefetch_cache and layer_id in _global_prefetch_cache[seq_id]:
+                embedding = _global_prefetch_cache[seq_id][layer_id]
+                # Remove batch dimension: [1, T, num_heads, D] -> [T, num_heads, D]
+                if embedding.ndim == 4 and embedding.shape[0] == 1:
+                    embedding = embedding[0]
+                return embedding
+            return None
+
+
+class ShortConv(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        kernel_size: int = 4,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            in_channels=hidden_size,
+            out_channels=hidden_size,
+            kernel_size=kernel_size,
+            groups=hidden_size,  # Depthwise
+            bias=False,
+            padding=(kernel_size - 1) * dilation,
+            dilation=dilation,
+        )
+        self.norm = nn.LayerNorm(hidden_size)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, T, D = x.shape
+        x_norm = self.norm(x)
+        x_bct = x_norm.transpose(1, 2)
+        y_bct = self.conv(x_bct)
+        # Trim the causal padding so the output length matches the input.
+        y_bct = y_bct[..., :T]
+        y = self.act(y_bct.transpose(1, 2))
+        return y
+
+
+# class ShortConv(nn.Module):
+#     def __init__(
+#         self,
+#         hidden_size: int,
+#         kernel_size: int = 4,
+#         dilation: int = 1,
+#         norm_eps: float = 1e-5,
+#         hc_mult: int = 4,
+#         activation: bool = True,
+#     ):
+#         super().__init__()
+#         self.hc_mult = hc_mult
+#         self.activation = activation
+#
+#         total_channels = hidden_size * hc_mult
+#         self.conv = nn.Conv1d(
+#             in_channels=total_channels,
+#             out_channels=total_channels,
+#             kernel_size=kernel_size,
+#             groups=total_channels,
+#             bias=False,
+#             padding=(kernel_size - 1) * dilation,
+#             dilation=dilation,
+#         )
+#
+#         self.norms = nn.ModuleList([
+#             nn.LayerNorm(hidden_size)
+#             for _ in range(hc_mult)
+#         ])
+#
+#         if self.activation:
+#             self.act_fn = nn.SiLU()
+#
+#     def forward(self, x: torch.Tensor) -> torch.Tensor:
+#         """
+#         Input: (B,L,HC_MULT,D)
+#         Output: (B,L,HC_MULT,D)
+#         """
+#         B, T, G, C = x.shape
+#
+#         assert G == self.hc_mult, f"Input groups {G} != hc_mult {self.hc_mult}"
+#
+#         normed_chunks = []
+#         for i in range(G):
+#             chunk = x[:, :, i, :]
+#             normed_chunks.append(self.norms[i](chunk))
+#
+#         x_norm = torch.cat(normed_chunks, dim=-1)
+#         x_bct = x_norm.transpose(1, 2)
+#         y_bct = self.conv(x_bct)
+#         y_bct = y_bct[..., :T]
+#
+#         if self.activation:
+#             y_bct = self.act_fn(y_bct)
+#         y = y_bct.transpose(1, 2).view(B, T, G, C).contiguous()
+#
+#         return y
+
+
+class EngramOp(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        hidden_size: int,
+        config: EngramConfig,
+    ):
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.config = config
+
+        self.hash_mapping = NgramHashMapping(
+            engram_vocab_size=config.engram_vocab_size,
+            max_ngram_size=config.max_ngram_size,
+            n_embed_per_ngram=config.n_embed_per_ngram,
+            n_head_per_ngram=config.n_head_per_ngram,
+            layer_ids=config.layer_ids,
+            tokenizer_name_or_path=config.tokenizer_name_or_path,
+            pad_id=config.pad_id,
+            seed=config.seed,
+        )
+
+        self.multi_head_embedding = MultiHeadEmbedding(
+            list_of_N=[x for y in self.hash_mapping.vocab_size_across_layers[layer_id] for x in y],
+            D=config.n_embed_per_ngram // config.n_head_per_ngram,
+        )
+
+        self.short_conv = ShortConv(
+            hidden_size=hidden_size,
+            kernel_size=config.kernel_size,
+            dilation=config.max_ngram_size,
+        )
+
+        # self.short_conv = ShortConv(
+        #     hidden_size = backbone_config.hidden_size,
+        #     kernel_size = config.kernel_size,
+        #     dilation = config.max_ngram_size,
+        #     hc_mult = backbone_config.hc_mult,
+        # )
+
+        engram_hidden_size = (config.max_ngram_size - 1) * config.n_embed_per_ngram
+        self.value_proj = nn.Linear(engram_hidden_size, hidden_size)
+        self.key_proj = nn.Linear(engram_hidden_size, hidden_size)
+
+        self.key_norm = nn.LayerNorm(hidden_size)
+        self.query_norm = nn.LayerNorm(hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [B, T, D]
+            input_ids: [B, T]
+        Returns:
+            output: [B, T, D]
+        """
+        B, T = input_ids.shape
+        num_tokens = B * T
+
+        if hasattr(self, 'embedding_buffer') and self.embedding_buffer is not None:
+            if hasattr(self, 'embedding_stream') and self.embedding_stream is not None:
+                # Make sure the async embedding copy has finished before reading the buffer.
+                current_stream = torch.cuda.current_stream()
+                current_stream.wait_stream(self.embedding_stream)
+
+            # embedding_buffer shape: [max_tokens, num_heads, D]
+            embeddings_from_buffer = self.embedding_buffer[:num_tokens]
+            num_heads = embeddings_from_buffer.shape[1]
+            embed_dim = embeddings_from_buffer.shape[2]
+            # Reshape to [B, T, num_heads, D]
+            embeddings_from_buffer = embeddings_from_buffer.reshape(B, T, num_heads, embed_dim)
+            # Flatten to [B, T, num_heads * D]
+            embeddings = embeddings_from_buffer.flatten(start_dim=-2)
+            embeddings = embeddings.to(hidden_states.dtype)
+        else:
+            # Fallback: hash on the CPU, then embed on the GPU.
+            input_ids_np = input_ids.cpu().numpy()
+            hash_ids_np = self.hash_mapping.hash(input_ids_np)[self.layer_id]
+            hash_ids = torch.from_numpy(hash_ids_np).to(hidden_states.device)
+            embeddings = self.multi_head_embedding(hash_ids).flatten(start_dim=-2)
+
+        key = self.key_norm(self.key_proj(embeddings))
+        value = self.value_proj(embeddings)
+
+        query = self.query_norm(hidden_states)
+
+        # Scaled dot-product gate with a signed square root to soften its magnitude.
+        gate = (key * query).sum(dim=-1) / math.sqrt(self.hidden_size)
+        gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign()
+        gate = gate.sigmoid().unsqueeze(-1)
+
+        gated_value = gate * value
+        output = gated_value + self.short_conv(gated_value)
+
+        return output
diff --git a/atom/models/engram.py b/atom/models/engram.py
new file mode 100644
index 000000000..c78162b21
--- /dev/null
+++ b/atom/models/engram.py
@@ -0,0 +1,379 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
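+#
+# A compact decoder-only transformer whose selected layers are augmented with
+# EngramOp n-gram hash-embedding memories (see atom/model_ops/engram.py).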
+ +""" +Engram Language Model for ATOM Inference +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Any, Dict, List, Optional, Union + +from aiter.dist.parallel_state import get_pp_group +from atom.config import Config, QuantizationConfig +from atom.model_ops.engram import EngramOp, EngramConfig +from atom.models.utils import ( + IntermediateTensors, + make_empty_intermediate_tensors_factory, + PPMissingLayer, +) +from atom.utils.decorators import support_torch_compile + +EngramModuleConfig = EngramConfig + +class EngramAttention(nn.Module): + """Causal self-attention for Engram model.""" + def __init__( + self, + hidden_size: int, + num_heads: int, + max_seq_len: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + assert hidden_size % num_heads == 0 + + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.hidden_size = hidden_size + + self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size) + self.out_proj = nn.Linear(hidden_size, hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, T, D = x.shape + + qkv = self.qkv_proj(x) + q, k, v = qkv.split(D, dim=-1) + + q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) + + out = F.scaled_dot_product_attention(q, k, v, is_causal=True) + + out = out.transpose(1, 2).contiguous().view(B, T, D) + return self.out_proj(out) + +class EngramMLP(nn.Module): + """Feed-forward network for Engram model.""" + def __init__( + self, + hidden_size: int, + intermediate_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + if intermediate_size is None: + intermediate_size = 4 * hidden_size + + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, hidden_size) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(self.act(self.fc1(x))) + + +class EngramDecoderLayer(nn.Module): + def __init__( + self, + layer_id: int, + hidden_size: int, + num_heads: int, + max_seq_len: int, + engram_config: EngramModuleConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.layer_id = layer_id + self.hidden_size = hidden_size + + self.norm1 = nn.LayerNorm(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size) + + self.attn = EngramAttention( + hidden_size=hidden_size, + num_heads=num_heads, + max_seq_len=max_seq_len, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + self.ffn = EngramMLP( + hidden_size=hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.ffn", + ) + + self.engram = None + if layer_id in engram_config.layer_ids: + self.engram = EngramOp(layer_id, hidden_size, engram_config) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Match original TransformerBlock logic exactly: + - Engram augmentation (before attention) + - Attention with residual + - FFN with residual + """ + x = hidden_states + + if self.engram is not None: + engram_out = self.engram(x, input_ids) + x = x + engram_out + + x = x + self.attn(self.norm1(x)) + + x = x + self.ffn(self.norm2(x)) + return x, None + + +@support_torch_compile +class EngramModel(nn.Module): + 
"""Engram language model backbone.""" + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + + self.hidden_size = getattr(config, 'hidden_size', 128) + self.num_layers = getattr(config, 'num_hidden_layers', 4) + self.num_heads = getattr(config, 'num_attention_heads', 4) + self.vocab_size = getattr(config, 'vocab_size', 128) + self.max_seq_len = getattr(config, 'max_position_embeddings', 256) + + engram_config_dict = getattr(config, 'engram_config', None) + if engram_config_dict: + self.engram_config = EngramModuleConfig.from_dict(engram_config_dict) + else: + self.engram_config = EngramModuleConfig() + + if get_pp_group().is_first_rank: + self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_size) + else: + self.token_embedding = PPMissingLayer() + + self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_size) + + self.blocks = nn.ModuleList([ + EngramDecoderLayer( + layer_id=i, + hidden_size=self.hidden_size, + num_heads=self.num_heads, + max_seq_len=self.max_seq_len, + engram_config=self.engram_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(self.num_layers) + ]) + self.start_layer = 0 + self.end_layer = self.num_layers + + if get_pp_group().is_last_rank: + self.norm = nn.LayerNorm(self.hidden_size) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.hidden_size + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.token_embedding(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if input_ids.dim() == 2: + input_ids = input_ids.flatten() + if positions.dim() == 2: + positions = positions.flatten() + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + # Use clamped values for embedding lookup without modifying original tensor + vocab_size = self.token_embedding.num_embeddings + input_ids_clamped = torch.clamp(input_ids, 0, vocab_size - 1) + hidden_states = self.get_input_embeddings(input_ids_clamped) # [T, D] + + max_pos = self.position_embedding.num_embeddings + positions_clamped = torch.clamp(positions, 0, max_pos - 1) + pos_emb = self.position_embedding(positions_clamped) # [T, D] + hidden_states = hidden_states + pos_emb + + hidden_states = hidden_states.unsqueeze(0) + input_ids = input_ids.unsqueeze(0) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for layer in self.blocks[self.start_layer:self.end_layer]: + hidden_states, _ = layer(hidden_states, input_ids) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": None + }) + + hidden_states = self.norm(hidden_states) + if hidden_states.dim() == 3: + hidden_states = hidden_states.squeeze(0) + + return hidden_states + +class EngramForCausalLM(nn.Module): + """Engram model for causal language modeling.""" + packed_modules_mapping = {} + + def __init__( + self, + atom_config: Config, + prefix: str = "", + ): + super().__init__() + config = atom_config.hf_config + self.config = config + + hidden_size = getattr(config, 'hidden_size', 128) + num_layers = getattr(config, 
+        num_heads = getattr(config, 'num_attention_heads', 4)
+        vocab_size = getattr(config, 'vocab_size', 128)
+        max_seq_len = getattr(config, 'max_position_embeddings', 256)
+
+        # Load engram config
+        engram_config_dict = getattr(config, 'engram_config', None)
+        if engram_config_dict:
+            engram_config = EngramModuleConfig.from_dict(engram_config_dict)
+        else:
+            engram_config = EngramModuleConfig()
+
+        # Token embedding (directly on this class, no 'model.' prefix)
+        if get_pp_group().is_first_rank:
+            self.token_embedding = nn.Embedding(vocab_size, hidden_size)
+        else:
+            self.token_embedding = PPMissingLayer()
+
+        self.position_embedding = nn.Embedding(max_seq_len, hidden_size)
+
+        self.blocks = nn.ModuleList([
+            EngramDecoderLayer(
+                layer_id=i,
+                hidden_size=hidden_size,
+                num_heads=num_heads,
+                max_seq_len=max_seq_len,
+                engram_config=engram_config,
+                prefix=f"{prefix}.blocks.{i}" if prefix else f"blocks.{i}",
+            )
+            for i in range(num_layers)
+        ])
+        self.start_layer = 0
+        self.end_layer = num_layers
+
+        if get_pp_group().is_last_rank:
+            self.norm = nn.LayerNorm(hidden_size)
+        else:
+            self.norm = PPMissingLayer()
+
+        if get_pp_group().is_last_rank:
+            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
+
+            tie_weights = getattr(config, 'tie_word_embeddings', True)
+            if tie_weights and hasattr(self, 'token_embedding') and not isinstance(self.token_embedding, PPMissingLayer):
+                self.lm_head.weight = self.token_embedding.weight
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+            ["hidden_states", "residual"], hidden_size
+        )
+
+        self._hidden_size = hidden_size
+        self._max_seq_len = max_seq_len
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.token_embedding(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        # Flatten to 1D if needed
+        if input_ids.dim() == 2:
+            input_ids = input_ids.flatten()
+        if positions.dim() == 2:
+            positions = positions.flatten()
+
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                vocab_size = self.token_embedding.num_embeddings
+                input_ids_clamped = torch.clamp(input_ids, 0, vocab_size - 1)
+                hidden_states = self.get_input_embeddings(input_ids_clamped)  # [T, D]
+
+            max_pos = self.position_embedding.num_embeddings
+            positions_clamped = torch.clamp(positions, 0, max_pos - 1)
+            pos_emb = self.position_embedding(positions_clamped)  # [T, D]
+            hidden_states = hidden_states + pos_emb
+
+            hidden_states = hidden_states.unsqueeze(0)
+            input_ids = input_ids.unsqueeze(0)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+
+        for layer in self.blocks[self.start_layer:self.end_layer]:
+            hidden_states, _ = layer(hidden_states, input_ids)
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": None
+            })
+
+        hidden_states = self.norm(hidden_states)
+        if hidden_states.dim() == 3:
+            hidden_states = hidden_states.squeeze(0)
+
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        from atom.utils.forward_context import get_forward_context
+        forward_ctx = get_forward_context()
+
+        if forward_ctx and forward_ctx.context and forward_ctx.context.is_prefill:
+            attn_meta = forward_ctx.attn_metadata
+            if attn_meta and hasattr(attn_meta, 'cu_seqlens_q') and attn_meta.cu_seqlens_q is not None:
+                # Keep only the last position of each prefill sequence.
+                last_indices = attn_meta.cu_seqlens_q[1:] - 1
+                # Clamp indices to the valid range
+                last_indices = last_indices.clamp(0, hidden_states.shape[0] - 1)
+                hidden_states = hidden_states[last_indices].contiguous()
+
+        logits = self.lm_head(hidden_states)
+        return logits
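+
+# Illustrative prefill example (assumed metadata values, not taken from a real
+# run): with two packed sequences of lengths 3 and 2, cu_seqlens_q = [0, 3, 5],
+# so last_indices = [2, 4] and compute_logits keeps only the final hidden state
+# of each sequence before projecting through lm_head.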