
Commit f9d6041

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e1b834f commit f9d6041

File tree (9 files changed, +33 −28 lines):

litgpt/api.py
litgpt/attention.py
litgpt/config.py
litgpt/generate/base.py
litgpt/model.py
tests/generate/test_main.py
tests/test_batch.py
tests/test_chat.py
tests/test_model.py


litgpt/api.py

Lines changed: 9 additions & 3 deletions
@@ -384,7 +384,9 @@ def distribute(
         else:
             kv_cache_size = fixed_kv_cache_size
         model.set_kv_cache(
-            batch_size=1, max_seq_length=kv_cache_size, device=fabric.device,
+            batch_size=1,
+            max_seq_length=kv_cache_size,
+            device=fabric.device,
         )
         self.kv_cache_initialized = True
         self.fixed_kv_cache_size = fixed_kv_cache_size

@@ -513,7 +515,9 @@ def generate(
         else:
             device = self.preprocessor.device
         self.model.set_kv_cache(
-            batch_size=1, max_seq_length=max_returned_tokens, device=device,
+            batch_size=1,
+            max_seq_length=max_returned_tokens,
+            device=device,
         )
         self.kv_cache_initialized = True

@@ -522,7 +526,9 @@ def generate(
             tmp_device = self.model.mha.mask_cache.device
             self.model.clear_kv_cache()
             self.model.set_kv_cache(
-                batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device,
+                batch_size=1,
+                max_seq_length=max_returned_tokens,
+                device=tmp_device,
             )
         else:
             for block in self.model.transformer.h:
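
These api.py hunks only re-wrap existing `set_kv_cache` calls onto one argument per line; behavior is unchanged. For context, a minimal sketch of the KV-cache lifecycle those calls belong to, using an illustrative tiny config (the config values and stand-alone setup are assumptions, not taken from this diff):

import torch
from litgpt import GPT, Config

# Assumed tiny model, purely to illustrate the call pattern shown above.
config = Config(block_size=128, vocab_size=50, n_layer=2, n_head=4, n_embd=64)
model = GPT(config)
device = torch.device("cpu")

# Allocate key/value caches before incremental decoding, as in the hunks above.
model.set_kv_cache(
    batch_size=1,
    max_seq_length=128,  # illustrative; api.py derives this value at runtime
    device=device,
)
# ... prefill and token-by-token generation would run here ...
model.clear_kv_cache()  # drop the caches before re-initializing with a new size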

litgpt/attention.py

Lines changed: 4 additions & 6 deletions
@@ -136,9 +136,9 @@ def __call__(
         if use_mask:
             # Special case requires building a mask. `mask_cache` is only needed
             # then.
-            assert (
-                self.mask_cache is not None
-            ), "mask_cache must be given if sliding window attention is used, or if input_pos given and T > 1"
+            assert self.mask_cache is not None, (
+                "mask_cache must be given if sliding window attention is used, or if input_pos given and T > 1"
+            )
             if is_causal:
                 mask = self.mask_cache[:T, :T].view(1, 1, T, T)
                 is_causal = False

@@ -156,9 +156,7 @@ def __call__(
             nh_k = self.config.n_query_groups
             q_per_kv = nh_q // nh_k
             if q_per_kv > 1:
-                mask = mask.unsqueeze(2).expand(
-                    -1, -1, q_per_kv, -1, -1
-                ).reshape(B, nh_q, T, -1)
+                mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)

         # Efficient attention using Flash Attention CUDA kernels.
         # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled.
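
The second hunk collapses a wrapped expression onto one line; the logic is unchanged. As a shape walk-through of that expression, here is a small sketch with illustrative sizes (the concrete numbers are assumptions, not from the diff):

import torch

# Assumed sizes: batch B=2, query length T=4, key length S=6,
# nh_q=8 query heads sharing nh_k=2 key/value groups, so q_per_kv=4.
B, T, S, nh_q, nh_k = 2, 4, 6, 8, 2
q_per_kv = nh_q // nh_k

mask = torch.ones(B, nh_k, T, S, dtype=torch.bool)  # one mask slice per KV group
# unsqueeze -> (B, nh_k, 1, T, S); expand -> (B, nh_k, q_per_kv, T, S);
# reshape -> (B, nh_q, T, S): every query head in a group reuses its group's mask.
mask = mask.unsqueeze(2).expand(-1, -1, q_per_kv, -1, -1).reshape(B, nh_q, T, -1)
print(mask.shape)  # torch.Size([2, 8, 4, 6])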

litgpt/config.py

Lines changed: 1 addition & 2 deletions
@@ -3,15 +3,14 @@
 from copy import deepcopy
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, Type, Union, List
+from typing import Any, Callable, List, Literal, Optional, Type, Union

 import torch
 import yaml
 from typing_extensions import Self

 from litgpt.utils import find_multiple

-
 # See `Config.start_of_layer_hook`. A start of layer hook is called just before
 # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where
 # `x` is the layer input, `block_idx` the number of the layer, and `input_pos`
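
The comment above documents the start-of-layer hook signature. A hypothetical hook matching that signature (the logging body is an assumption and is not part of this diff; whether a return value is used is not shown here):

from typing import Optional

import torch

def print_activation_norms(x: torch.Tensor, block_idx: int, input_pos: Optional[torch.Tensor]) -> None:
    # Called just before block `block_idx` with its input `x` and the token position(s).
    print(f"block {block_idx}: |x| = {x.norm().item():.3f}, input_pos = {input_pos}")

Such a function could then be passed to `GPT.set_start_of_layer_hook`, whose signature is reformatted in the litgpt/model.py hunks below.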

litgpt/generate/base.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def generate_fn(

     prompt_size = prompt.size(0)
     if prompt_size == 0:
-        raise ValueError(f"prompt must not be empty")
+        raise ValueError("prompt must not be empty")
     sample_kwargs = dict(
         temperature=temperature,
         top_k=top_k,

litgpt/model.py

Lines changed: 7 additions & 3 deletions
@@ -5,9 +5,10 @@
 Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
 https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
 """
+
+from dataclasses import replace
 from functools import partial
 from typing import Any, List, Optional, Tuple, Union
-from dataclasses import replace

 import torch
 import torch.nn as nn

@@ -189,7 +190,8 @@ def reset_parameters(self) -> None:
         self.mha.set_seq_length(self.max_seq_length, device=self.cos.device)

     def set_start_of_layer_hook(
-        self, hook: Optional[StartOfLayerHook],
+        self,
+        hook: Optional[StartOfLayerHook],
     ):
         """
         Sets a function `hook(x, block_idx, input_pos)`, which is called

@@ -452,7 +454,9 @@ def get_kv_cache_params(self) -> Optional[KVCacheParams]:
             batch_size = min(c.batch_size for c in caches)
             cache_length = min(c.cache_length for c in caches)
             params = replace(
-                params, batch_size=batch_size, cache_length=cache_length,
+                params,
+                batch_size=batch_size,
+                cache_length=cache_length,
             )
         return params

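
The last hunk re-wraps a `dataclasses.replace` call. A minimal sketch of that pattern on a stand-in dataclass (the fields shown here are assumptions for illustration, not the real `KVCacheParams` definition):

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class KVCacheParams:
    batch_size: int
    cache_length: int

params = KVCacheParams(batch_size=8, cache_length=4096)
# Keep the most restrictive values found across all layer caches, as above.
params = replace(params, batch_size=1, cache_length=2048)
print(params)  # KVCacheParams(batch_size=1, cache_length=2048)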

tests/generate/test_main.py

Lines changed: 3 additions & 4 deletions
@@ -174,7 +174,8 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
     )
     assert (
         generate_mock.mock_calls
-        == [call(ANY, tensor_like, len_return_value, **sample_kwargs, eos_id=tokenizer_mock.return_value.eos_id)] * num_samples
+        == [call(ANY, tensor_like, len_return_value, **sample_kwargs, eos_id=tokenizer_mock.return_value.eos_id)]
+        * num_samples
     )
     expected_output = "foo bar baz\n" * num_samples
     # Allow for the config to be printed before the expected repeated strings.

@@ -209,9 +210,7 @@ def test_sample(temperature):
     )
     # Note: Both `sample` and `batched_sample` create only 1 sample, not 3.
     # It is like passing `logits[:, 1-:, :]`
-    token = batched_sample(
-        logits, kwargs=dict(temperature=temperature, top_p=0.8)
-    )
+    token = batched_sample(logits, kwargs=dict(temperature=temperature, top_p=0.8))

     assert token.shape == (2, 1)
     # sample is batch size 1 only for now - this should be [0, 1] once batched generation is supported

tests/test_batch.py

Lines changed: 6 additions & 2 deletions
@@ -32,7 +32,9 @@ def create_llm(tmp_path, batch_size, max_seq_length, device) -> tuple[LLM, GPT]:
     )
     model: GPT = llm.model
     model.set_kv_cache(
-        batch_size=batch_size, max_seq_length=max_seq_length, device=device,
+        batch_size=batch_size,
+        max_seq_length=max_seq_length,
+        device=device,
     )

     return llm, model

@@ -89,7 +91,9 @@ def test_batched_equivalence(tmp_path):
     # Switch to batched generation
     model.clear_kv_cache()
     model.set_kv_cache(
-        batch_size=batch_size, max_seq_length=max_seq_length, device=device,
+        batch_size=batch_size,
+        max_seq_length=max_seq_length,
+        device=device,
     )

     toks_1: torch.Tensor = batched_next_token(

tests/test_chat.py

Lines changed: 2 additions & 6 deletions
@@ -47,12 +47,8 @@ def test_generate(monkeypatch, generated, stop_tokens, expected):
     model.config.block_size = 100
     model.max_seq_length = 100
     # Mock methods called during generation
-    monkeypatch.setattr(
-        model, "kv_cache_max_prefill_length", lambda: 80
-    )
-    monkeypatch.setattr(
-        model, "kv_cache_max_tokens_forward", lambda: 20
-    )
+    monkeypatch.setattr(model, "kv_cache_max_prefill_length", lambda: 80)
+    monkeypatch.setattr(model, "kv_cache_max_tokens_forward", lambda: 20)
     it = iter(generated)

     def multinomial(*_, **__):

tests/test_model.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 import random
 from copy import deepcopy
 from functools import partial
-from unittest import mock

 import pytest
 import torch
