Skip to content

Commit e2cfa34

Browse files
authored
Merge pull request #175 from yanboliang/band-emb
Remove nn.Embedding layer from model size
2 parents 1095a5c + ebd10d3 commit e2cfa34

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

generate.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,18 @@ def _load_model(checkpoint_path, device, precision, use_tp):
243243
model = model.to(device=device, dtype=precision)
244244
return model.eval()
245245

246+
def _get_model_size(model):
247+
model_size = 0
248+
for name, child in model.named_children():
249+
if not isinstance(child, torch.nn.Embedding):
250+
model_size += sum(
251+
[
252+
p.numel() * p.dtype.itemsize
253+
for p in itertools.chain(child.parameters(), child.buffers())
254+
]
255+
)
256+
return model_size
257+
246258
B_INST, E_INST = "[INST]", "[/INST]"
247259

248260
def main(
@@ -299,7 +311,7 @@ def main(
299311
prompt_length = encoded.size(0)
300312

301313
torch.manual_seed(1234)
302-
model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
314+
model_size = _get_model_size(model)
303315
if compile:
304316
if is_speculative and use_tp: # and ("cuda" in device):
305317
torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case

0 commit comments

Comments
 (0)