Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,38 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:
}


class EngramConfig(PretrainedConfig):
    """Minimal HF-style configuration for the ``engram_lm`` demo model.

    The defaults describe a tiny model. When a ``model_config`` dict is
    supplied through ``**kwargs`` (loaded from our own config file rather
    than the Hugging Face hub), its entries override the standard HF field
    names below. Note that ``model_config`` is also forwarded unchanged to
    ``PretrainedConfig.__init__`` via ``**kwargs``, so it remains available
    as an attribute on the resulting config object.
    """

    model_type = "engram_lm"

    def __init__(
        self,
        vocab_size: int = 128,
        hidden_size: int = 128,
        num_hidden_layers: int = 4,
        num_attention_heads: int = 4,
        max_position_embeddings: int = 512,
        tie_word_embeddings: bool = True,
        engram_config: dict | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.tie_word_embeddings = tie_word_embeddings
        # `or {}` keeps a fresh dict per instance; never share a mutable default.
        self.engram_config = engram_config or {}

        # Our checkpoint stores sizes under its own key names; map them onto
        # the conventional HF attribute names when present.
        model_config = kwargs.get("model_config", {})
        if model_config:
            self.vocab_size = model_config.get("vocab_size", self.vocab_size)
            self.hidden_size = model_config.get("hidden_size", self.hidden_size)
            self.num_hidden_layers = model_config.get("num_layers", self.num_hidden_layers)
            self.num_attention_heads = model_config.get("num_heads", self.num_attention_heads)
            self.max_position_embeddings = model_config.get("max_seq_len", self.max_position_embeddings)


def get_hf_config(model: str) -> PretrainedConfig:
config_dict, _ = PretrainedConfig.get_config_dict(
model,
Expand All @@ -371,6 +403,10 @@ def _get_hf_token() -> str | None:
return token
return None

# Since this model has no config on the Hugging Face hub, load our own config class instead.
if model_type == "engram_lm":
return EngramConfig.from_pretrained(model)

if model_type in _CONFIG_REGISTRY:
config_class = AutoConfig.for_model(_CONFIG_REGISTRY[model_type])
return config_class.from_pretrained(
Expand Down
5 changes: 4 additions & 1 deletion atom/model_engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def postprocess(self, reqs: List[Sequence]):
outputs = {}
for req in reqs:
self.requests.pop(req.id)
output_str = self.tokenizer.decode(req.completion_token_ids)
# Workaround so our simple trained engram demo model can run: drop negative (invalid/padding) token ids before decoding.
valid_ids = [t for t in req.completion_token_ids if t >= 0]
output_str = self.tokenizer.decode(valid_ids)
# output_str = self.tokenizer.decode(req.completion_token_ids)
req.leave_time = time.time()

# Calculate TTFT (Time To First Token) and TPOT (Time Per Output Token)
Expand Down
229 changes: 226 additions & 3 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from typing import Any, Optional, Union

from atom.model_ops.engram import EngramOp
import numpy as np
import torch
import torch.profiler as torch_profiler
Expand Down Expand Up @@ -52,6 +53,7 @@
"DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
"DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
"GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
"EngramForCausalLM": "atom.models.engram.EngramForCausalLM",
}
# seed = 34567
# np.random.seed(seed)
Expand All @@ -73,15 +75,51 @@ def __init__(self, max_num_batched_tokens: int, device: torch.device):
# Event on the copy stream so we can synchronize the non-blocking copy.
self.async_copy_event = torch.cuda.Event()
self.async_copy_stream = torch.cuda.Stream()
self.engram_ops = []
self.clean()

@staticmethod
def compute_engram_embeddings(
input_ids_np: np.ndarray,
engram_ops: list,
layer_id_to_engram_op: dict = None
) -> dict[int, np.ndarray]:
if not engram_ops:
return {}

if layer_id_to_engram_op is None:
layer_id_to_engram_op = {op.layer_id: op for op in engram_ops}

hash_mapping = engram_ops[0].hash_mapping

hash_results = hash_mapping.hash(input_ids_np)

embedding_results = {}
for engram_op in engram_ops:
layer_id = engram_op.layer_id
hash_ids = hash_results[layer_id]
embeddings = engram_op.multi_head_embedding.forward_on_cpu(hash_ids)
# Remove batch dimension if batch size is 1: [1, T, num_heads, D] -> [T, num_heads, D]
if embeddings.ndim == 4 and embeddings.shape[0] == 1:
embeddings = embeddings[0]
embedding_results[layer_id] = embeddings

return embedding_results

def send_to_cpu_async(self, gpu_tensor: torch.Tensor):
    def send_to_cpu_async(self, gpu_tensor: torch.Tensor, batch: Optional[ScheduledBatch] = None):
        """Copy sampled token ids to the CPU on a side stream, then prefetch engram state.

        The D2H copy runs on ``self.async_copy_stream`` (after waiting on the
        default stream) and is fenced by ``self.async_copy_event`` so
        ``recv_async_output`` can synchronize on it later. When ``batch`` is
        given, the freshly sampled ids are also handed to the engram prefetch.
        """
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(self.async_copy_stream):
            # Order the copy after all work already queued on the default stream.
            self.async_copy_stream.wait_stream(default_stream)
            cpu_tensor = gpu_tensor.to("cpu", non_blocking=True)
            # Record so a later wait can tell when this copy has finished.
            self.async_copy_event.record(self.async_copy_stream)
            self.token_ids_cpu.append(cpu_tensor)

        if batch is not None:
            # NOTE(review): .tolist() is a synchronizing GPU->CPU transfer on the
            # caller's stream, which partially defeats the async copy above — confirm
            # this blocking read is intended here.
            current_token_ids = gpu_tensor.tolist()
            req_ids = batch.req_ids
            current_token_dict = {seq_id: token_id for seq_id, token_id in zip(req_ids, current_token_ids)}
            # Kick off the engram hash/embedding prefetch once the token ids
            # have been brought to the CPU.
            self._prefetch_engram_hash(batch, current_token_dict)

def recv_async_output(self) -> list[int]:
for _ in self.token_ids_cpu:
Expand All @@ -95,17 +133,54 @@ def clean(self):

self.prev_batch: Optional[ScheduledBatch] = None

    def _prefetch_engram_hash(self, batch: ScheduledBatch, token_ids: dict[int, int]):
        """Asynchronously precompute and cache engram embeddings for the next step.

        Best-effort: any failure is logged and ignored so decoding is never
        blocked by the prefetch path.

        Args:
            batch: The batch whose ``req_ids`` select which sequences to prefetch.
            token_ids: Mapping of ``seq_id -> next token id`` just sampled.
        """
        try:
            engram_ops = self.engram_ops
            if not engram_ops:
                return

            # Imported lazily; only needed when engram ops are present.
            import threading
            hash_mapping = engram_ops[0].hash_mapping
            next_input_ids_list = []
            valid_seq_ids = []

            # Keep only sequences that actually produced a token this step.
            for seq_id in batch.req_ids:
                if seq_id in token_ids:
                    next_input_ids_list.append(token_ids[seq_id])
                    valid_seq_ids.append(seq_id)

            if not next_input_ids_list:
                return

            # NOTE(review): this packs one token from each *different* sequence
            # into a single row, so the n-gram hash sees them as one contiguous
            # sequence — confirm hash_mapping/save_embedding_results expect
            # exactly this layout.
            next_input_ids = np.array([next_input_ids_list], dtype=np.int32)

            def compute_and_cache_batch_hash_and_embedding():
                # Runs on a background thread; results land in the op's cache
                # and are consumed by prepare_engram_embeddings_to_gpu later.
                embedding_results = tokenIDProcessor.compute_engram_embeddings(
                    next_input_ids, engram_ops
                )

                if engram_ops and embedding_results:
                    engram_ops[0].multi_head_embedding.save_embedding_results(
                        embedding_results, valid_seq_ids, hash_mapping.layer_ids
                    )

            # NOTE(review): spawns a new daemon thread per step (unbounded);
            # a single worker/queue may be preferable — confirm load profile.
            threading.Thread(target=compute_and_cache_batch_hash_and_embedding, daemon=True).start()

        except Exception as e:
            logger.warning(e)

def prepare_sampled_ids(
self, batch: ScheduledBatch, sampled_token_ids: torch.Tensor
) -> dict[int, int]:
if not self.is_deferred_out:
token_ids = sampled_token_ids.tolist()
req_ids = batch.req_ids
ret = {seq_id: token_id for seq_id, token_id in zip(req_ids, token_ids)}
self._prefetch_engram_hash(batch, ret)
ret[-1] = 0
return ret
token_ids = self.recv_async_output()
self.send_to_cpu_async(sampled_token_ids)
self.send_to_cpu_async(sampled_token_ids, batch)

if self.prev_batch is not None:
req_ids = self.prev_batch.req_ids
Expand Down Expand Up @@ -326,6 +401,8 @@ def __init__(self, rank: int, config: Config):
self.config.max_num_batched_tokens, self.device
)
self.sampler = Sampler()

self.engram_buffers = {}
if self.config.speculative_config and get_pp_group().is_last_rank:
self.drafter = EagleProposer(self.config, self.device, self)
self.arange_np = np.arange(
Expand Down Expand Up @@ -356,9 +433,54 @@ def __init__(self, rank: int, config: Config):

torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
self.tokenID_processor.engram_ops = self.get_engram_op()

self._init_engram_buffers()

if self.config.compilation_config.level == 1:
self.model = torch.compile(self.model, fullgraph=True, backend="eager")

def get_engram_op(self) -> list[EngramOp]:
from atom.model_ops.engram import EngramOp
engram_ops = []
for module in self.model.modules():
if isinstance(module, EngramOp):
engram_ops.append(module)
return engram_ops

    def _init_engram_buffers(self):
        """Allocate per-layer CPU/GPU staging buffers and a copy stream for engram ops.

        No-op when the model has no EngramOp modules. Buffer sizing is derived
        from the shared hash mapping; failures are logged and tolerated so the
        runner still works without engram acceleration.
        """
        engram_ops = self.get_engram_op()
        if not engram_ops:
            return

        # Materialize each op's embedding table on the CPU before sizing buffers.
        for engram_op in engram_ops:
            engram_op.multi_head_embedding.init_cpu_embedding()

        try:
            # All ops share one hash mapping (same assumption as elsewhere).
            hash_mapping = engram_ops[0].hash_mapping
            max_tokens = self.config.max_num_batched_tokens

            # Per-token embedding layout: one head per (ngram size - 1) x heads-per-ngram.
            num_hash_heads = (hash_mapping.max_ngram_size - 1) * hash_mapping.n_head_per_ngram
            embed_dim = hash_mapping.n_embed_per_ngram // hash_mapping.n_head_per_ngram

            # Flat float32 buffer per layer, sized for the worst-case batch.
            for layer_id in hash_mapping.layer_ids:
                embedding_buffer = CpuGpuBuffer(
                    max_tokens * num_hash_heads * embed_dim,
                    dtype=torch.float32,
                    device=self.device,
                    with_numpy=True
                )
                self.engram_buffers[layer_id] = embedding_buffer

            # One dedicated stream shared by all engram H2D copies; created once.
            if not hasattr(self, 'engram_embedding_stream'):
                self.engram_embedding_stream = torch.cuda.Stream()
            # Hand each op a [max_tokens, heads, dim] view of its GPU buffer
            # plus the shared stream, so the op can consume staged embeddings.
            for engram_op in engram_ops:
                embedding_buffer = self.engram_buffers[engram_op.layer_id]
                engram_op.embedding_buffer = embedding_buffer.gpu.reshape(max_tokens, num_hash_heads, embed_dim)
                engram_op.embedding_stream = self.engram_embedding_stream

        except Exception as e:
            logger.warning(e)

def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
Expand Down Expand Up @@ -894,6 +1016,106 @@ def prepare_sample(self, batch: ScheduledBatch) -> torch.Tensor:
buffer = self.forward_vars["temperatures"]
buffer.np[:bs] = batch.temperatures
return buffer.copy_to_gpu(bs)

def prepare_engram_embeddings_to_gpu(self, batch: ScheduledBatch):
"""Transfer precomputed embeddings to GPU.
"""
if not self.engram_buffers:
return

try:
engram_ops = self.get_engram_op()
if not engram_ops:
return

req_ids = batch.req_ids
hash_mapping = engram_ops[0].hash_mapping

layer_id_to_engram_op = {op.layer_id: op for op in engram_ops}

cached_embeddings_per_layer = {layer_id: [] for layer_id in hash_mapping.layer_ids}
seqs_need_compute = []

for seq_id in req_ids:
seq_embeddings = {}

for layer_id in hash_mapping.layer_ids:
engram_op = layer_id_to_engram_op.get(layer_id)
if engram_op is None:
continue

cached_embedding = engram_op.multi_head_embedding.get_cached_embedding(layer_id, seq_id)

# print(f"cached_embedding: {cached_embedding.shape}")
if cached_embedding is not None and cached_embedding.size > 0:
seq_embeddings[layer_id] = cached_embedding

if len(seq_embeddings) == len(hash_mapping.layer_ids):
for layer_id in hash_mapping.layer_ids:
embedding = seq_embeddings[layer_id]
cached_embeddings_per_layer[layer_id].append(embedding)
else:
seqs_need_compute.append(seq_id)

# It means has no prev batch to precompute, it's new prefill case, compute here
if seqs_need_compute:
# Use cpu input_ids to avoid GPU->CPU copy
total_tokens = batch.total_tokens_num
input_ids_np = self.tokenID_processor.input_ids.np[:total_tokens].reshape(1, -1)

embedding_results = tokenIDProcessor.compute_engram_embeddings(
input_ids_np, engram_ops, layer_id_to_engram_op
)

for layer_id in hash_mapping.layer_ids:
if layer_id not in embedding_results:
continue

embeddings = embedding_results[layer_id]
cached_embeddings_per_layer[layer_id].append(embeddings)

first_layer_embeddings = cached_embeddings_per_layer[hash_mapping.layer_ids[0]]
if not first_layer_embeddings or len(first_layer_embeddings) == 0:
return

for emb in first_layer_embeddings:
if emb is None or emb.size == 0:
return

default_stream = torch.cuda.current_stream()
total_tokens = sum(e.shape[0] for e in first_layer_embeddings)

if total_tokens == 0:
return

with torch.cuda.stream(self.engram_embedding_stream):
self.engram_embedding_stream.wait_stream(default_stream)

for layer_id in hash_mapping.layer_ids:
if layer_id not in self.engram_buffers:
continue

layer_embeddings = cached_embeddings_per_layer[layer_id]
if not layer_embeddings or len(layer_embeddings) == 0:
continue

valid_embeddings = [e for e in layer_embeddings if e is not None and e.size > 0]
if not valid_embeddings:
continue

embeddings_np = np.concatenate(valid_embeddings, axis=0)

# Write to CPU buffer and copy to GPU asynchronously
embedding_buffer = self.engram_buffers[layer_id]
num_hash_heads = embeddings_np.shape[1]
embed_dim = embeddings_np.shape[2]
total_elements = total_tokens * num_hash_heads * embed_dim
embedding_buffer.np[:total_elements] = embeddings_np.flatten()
embedding_buffer.copy_to_gpu(total_elements)

except Exception as e:
logger.warning(e)


def prepare_model(self, batch: ScheduledBatch):
total_tokens_num = batch.total_tokens_num
Expand All @@ -902,7 +1124,8 @@ def prepare_model(self, batch: ScheduledBatch):
input_ids = self.tokenID_processor.prepare_input_ids(batch)
# if self.rank == 0:
# print(f"input_ids: {input_ids}")

# Prepare engram embeddings using CPU-side input_ids
self.prepare_engram_embeddings_to_gpu(batch)
self.prepare_intputs(batch)
temperatures = self.prepare_sample(batch)
return (
Expand Down
3 changes: 3 additions & 0 deletions atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def load_model(
continue
if name.endswith("kv_scale"):
continue
# Skip non-weight buffers (causal_mask / offsets) so our simple trained engram demo model can load.
continue
if spec_decode:
spec_layer = get_spec_layer_idx_from_weight_name(hf_config, name)
if spec_layer is None:
Expand Down
Loading