
Commit 360eb45

init lightllm repo
1 parent d2ccc02 commit 360eb45

107 files changed, +7898 -0 lines changed


.gitignore (+5)

__pycache__/
*.pyc
build
dist
*.egg-info

Dockerfile (+53)

FROM debian:bullseye-slim as pytorch-install

ARG PYTORCH_VERSION=2.0.0
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
ARG TARGETPLATFORM

ENV PATH /opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        curl \
        git && \
        rm -rf /var/lib/apt/lists/*

RUN case ${TARGETPLATFORM} in \
        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
        *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

RUN case ${TARGETPLATFORM} in \
        "linux/arm64") exit 1 ;; \
        *) /opt/conda/bin/conda update -y conda && \
           /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch==$PYTORCH_VERSION "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya

FROM nvidia/cuda:11.8.0-devel-ubuntu20.04 as base

ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        libssl-dev \
        ca-certificates \
        make \
        && rm -rf /var/lib/apt/lists/*

COPY --from=pytorch-install /opt/conda /opt/conda
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt && rm -rf requirements.txt
RUN apt update -y && apt install -y vim wget curl git

README.md (+151)

<div align="center">
<picture>
<img alt="LightLLM" src="assets/lightllm.drawio.png" width=90%>
</picture>
</div>

---
LightLLM is a Python-based LLM (Large Language Model) inference and serving framework, notable for its lightweight design, easy scalability, and high-speed performance. LightLLM harnesses the strengths of numerous well-regarded open-source implementations, including but not limited to FasterTransformer, TGI, vLLM, and FlashAttention.

## Features

- Tri-process asynchronous collaboration: tokenization, model inference, and detokenization are performed asynchronously, leading to a considerable improvement in GPU utilization.
- Nopad (Unpad): supports nopad attention operations across multiple models to efficiently handle requests with large length disparities.
- Dynamic Batch: enables dynamic batch scheduling of requests.
- [FlashAttention](https://github.com/Dao-AILab/flash-attention): incorporates FlashAttention to improve speed and reduce GPU memory footprint during inference.
- Tensor Parallelism: utilizes tensor parallelism over multiple GPUs for faster inference.
- [Token Attention](./docs/TokenAttention.md): implements a token-wise KV cache memory management mechanism, allowing for zero memory waste during inference.
- High-performance Router: collaborates with Token Attention to meticulously manage the GPU memory of each token, thereby optimizing system throughput.

## Supported Model List

- [BLOOM](https://huggingface.co/bigscience/bloom)
- [LLaMA](https://github.com/facebookresearch/llama)
- [LLaMA V2](https://huggingface.co/meta-llama)

## Get started

### Requirements

The code has been tested with PyTorch>=1.3, CUDA 11.8, and Python 3.9. To install the necessary dependencies, please refer to the provided **requirements.txt**:

~~~shell
pip install -r requirements.txt
~~~

A more straightforward approach is to use the official Docker container:

~~~shell
docker build -t image_name .
docker run -it --gpus all -p 8080:80 -v your_local_path:/data/ image_name /bin/bash
~~~

### Installation

- Install from the source code by

~~~shell
python setup.py install
~~~

The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on V100, A100, A800, etc., we recommend using triton==2.0.0.dev20221202. If you are running the code on 4090, H800, etc., it is necessary to compile and install [triton==2.1.0](https://github.com/openai/triton/tree/main) from source. If the code does not work on other GPUs, try modifying the triton kernels used in model inference.

### RUN LLaMA

With its efficient Router and TokenAttention, LightLLM can be deployed as a service and achieve state-of-the-art throughput.

Launch the server:

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7B --tp 1 --max_total_token_num 120000
~~~

The parameter `max_total_token_num` depends on the GPU memory of the deployment environment: a larger value lets more tokens be cached and therefore more requests be served concurrently. For more startup parameters, please refer to [api_server.py](lightllm/server/api_server.py).
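
For intuition about how to pick this value, here is a rough back-of-the-envelope sketch of the KV-cache budget, assuming a LLaMA-7B-style configuration (32 layers, 32 KV heads, head_dim 128, fp16) on an 80 GB card; the numbers are illustrative estimates, not something LightLLM computes for you.

~~~python
# Rough KV-cache sizing sketch (illustrative only, not LightLLM code).
# Each cached token stores one key and one value vector per layer.
layers, kv_heads, head_dim, dtype_bytes = 32, 32, 128, 2           # LLaMA-7B-like, fp16

bytes_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes   # key + value
print(bytes_per_token / 1024)                # 512.0 KiB per cached token

gpu_mem_gib = 80        # e.g. an A800 80G
weights_gib = 13        # roughly 7B parameters in fp16
reserve_gib = 6         # activations, fragmentation, CUDA context (guess)
budget = (gpu_mem_gib - weights_gib - reserve_gib) * 1024**3
print(budget // bytes_per_token)             # ~125,000 tokens -> a plausible max_total_token_num
~~~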
63+
64+
To initiate a query in the shell:
65+
66+
~~~shell
67+
curl 127.0.0.1:8000/generate \
68+
-X POST \
69+
-d '{"inputs":"What is AI?","parameters":{"max_new_tokens":17, "frequency_penalty":1}}' \
70+
-H 'Content-Type: application/json'
71+
~~~
72+
73+
To query from Python:
74+
75+
~~~python
76+
import time
77+
import requests
78+
import json
79+
80+
url = 'http://localhost:8000/generate'
81+
headers = {'Content-Type': 'application/json'}
82+
data = {
83+
'inputs': 'What is AI?',
84+
"parameters": {
85+
'do_sample': False,
86+
'ignore_eos': False,
87+
'max_new_tokens': 1024,
88+
}
89+
}
90+
response = requests.post(url, headers=headers, data=json.dumps(data))
91+
if response.status_code == 200:
92+
print(response.json())
93+
else:
94+
print('Error:', response.status_code, response.text)
95+
~~~
## Performance

### Service Performance

We compared the service performance of LightLLM and vLLM==0.1.2 on LLaMA-7B using an A800 with 80G GPU memory.

To begin, prepare the data as follows:

~~~shell
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
~~~

Launch the service:

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7b --tp 1 --max_total_token_num 121060 --tokenizer_mode auto
~~~

Evaluation:

~~~shell
cd test
python benchmark_serving.py --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200
~~~

The performance comparison results are presented below:

| vLLM                                                 | LightLLM                                              |
| ---------------------------------------------------- | ----------------------------------------------------- |
| Total time: 361.79 s<br/>Throughput: 5.53 requests/s | Total time: 188.85 s<br/>Throughput: 10.59 requests/s |
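
(Both columns are consistent with throughput = num-prompts / total time: 2000 / 361.79 s ≈ 5.53 requests/s for vLLM and 2000 / 188.85 s ≈ 10.59 requests/s for LightLLM.)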
### Static inference performance

For debugging, we offer static performance testing scripts for various models. For instance, you can evaluate the inference performance of the LLaMA model by

~~~shell
cd test/lightllama
python test_model_infer.py
~~~

### FAQ

- If the LLaMA tokenizer fails to load, try running `pip install protobuf==3.20.0`.

## License

This repository is released under the [Apache-2.0](LICENSE) license.

## Acknowledgement

We learned a lot from the following projects when developing LightLLM.
- [Faster Transformer](https://github.com/NVIDIA/FasterTransformer)
- [Text Generation Inference](https://github.com/huggingface/text-generation-inference)
- [vLLM](https://github.com/vllm-project/vllm)
- [Flash Attention 1&2](https://github.com/Dao-AILab/flash-attention)

assets/att.gif (58.2 KB)

assets/lightllm.drawio.png (176 KB)

assets/logo.png (70.8 KB)

benchmark.md (+26)

#### lightllm

#### Launch service

~~~shell
python -m lightllm.server.api_server --model_dir /path/llama-7b --tp 1 --max_total_token_num 121060 --tokenizer_mode auto
~~~

#### Evaluation

~~~shell
python benchmark_serving.py --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200
~~~

#### vllm

#### Launch service

~~~shell
python -m vllm.entrypoints.api_server --model /path/llama-7b --swap-space 16 --disable-log-requests --port 9009
~~~

#### Evaluation

~~~shell
python benchmark_serving_vllm.py --backend vllm --tokenizer /path/llama-7b --dataset /path/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 2000 --request-rate 200 --host 127.0.0.1 --port 9009
~~~

docs/TokenAttention.md (+31)

# TokenAttention

Transformers form the basis of modern large language models. During autoregressive decoding, these models cache the key-value tensors of context tokens in GPU memory to facilitate fast generation of the next token. These caches occupy significant GPU memory, and because the length of each request is unpredictable, so is the cache size, which leads to significant memory fragmentation in the absence of a suitable memory management mechanism.

To alleviate this issue, PagedAttention was proposed to store the KV cache in non-contiguous memory spaces. It partitions the KV cache of each sequence into multiple blocks, with each block containing the keys and values for a fixed number of tokens. This limits memory waste to the last block of each sequence during attention computation. While PagedAttention alleviates memory fragmentation to some extent, it still leaves room for memory waste. Additionally, when handling many concurrent requests, the allocation and deallocation of memory blocks is not efficient enough, leading to suboptimal memory utilization.

To address the above challenges, we introduce TokenAttention, an attention mechanism that manages key and value caching at the token level. Compared to PagedAttention, TokenAttention not only minimizes memory fragmentation and enables efficient memory sharing but also facilitates efficient memory allocation and deallocation. It allows for more precise and fine-grained memory management, thus optimizing memory utilization.

<div align="center">

| Features                                     | PagedAttention | TokenAttention |
| -------------------------------------------- | :------------: | :------------: |
| Low memory fragmentation                     |    &#x2713;    |    &#x2713;    |
| Efficient memory sharing                     |    &#x2713;    |    &#x2713;    |
| Efficient memory allocation and deallocation |    &#x2717;    |    &#x2713;    |
| Fine-grained memory management               |    &#x2717;    |    &#x2713;    |
</div>

The operation mechanism of TokenAttention is illustrated in the figure below:

<div align="center">
<img alt="TokenAtt" src="../assets/att.gif" width=60%>
</div>

During model initialization, the KV cache is pre-allocated based on the user-set **max_total_token_num**, and a Token Table is created to record the actual storage locations of input tokens.

When handling new requests, the system first checks for available contiguous space in the pre-allocated token cache for storing the key-value (KV) cache. TokenAttention prefers to assign contiguous GPU memory to a request in order to minimize memory accesses during inference; only when contiguous space is insufficient does it allocate non-contiguous memory. Since memory management is conducted token by token, TokenAttention achieves nearly zero waste, yielding higher throughput compared to vLLM.

We have implemented an efficient TokenAttention operator using OpenAI Triton. Given a query vector, this operator efficiently retrieves the corresponding KV cache based on the Token Table and conducts the attention computation.

Upon completion of a request, the corresponding GPU memory can be quickly freed by deleting its records in the Token Table, making way for new requests. Because TokenAttention pre-allocates all KV cache space during model initialization, it can efficiently release memory for completed requests and merge different batches of requests during dynamic scheduling, thereby effectively maximizing GPU utilization.
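
To make the bookkeeping concrete, below is a minimal sketch of the request lifecycle described above, written against the `MemoryManager` class shipped in this commit (`lightllm/common/mem_manager.py`). The `token_table` dict and the `admit`/`finish` helpers are illustrative stand-ins for the router's logic, not the actual implementation; running it requires a CUDA device.

~~~python
import torch
from lightllm.common.mem_manager import MemoryManager

# Pre-allocate the whole KV pool once; toy pool size, LLaMA-7B-like head shape.
mgr = MemoryManager(size=1024, dtype=torch.float16, head_num=32, head_dim=128, layer_num=32)
token_table = {}   # request id -> tensor of per-token slot indices (the "Token Table")

def admit(req_id, num_tokens):
    # Prefer a contiguous block; fall back to scattered slots if none is available.
    contiguous = mgr.alloc_contiguous(num_tokens)
    if contiguous is not None:
        idx, _start, _end = contiguous
    else:
        idx = mgr.alloc(num_tokens)
        if idx is None:
            return False       # not enough free tokens: the router would queue the request
    token_table[req_id] = idx
    return True

def finish(req_id):
    # Freeing is just flipping the per-token state bits back to "free".
    mgr.free(token_table.pop(req_id))

admit("req-0", num_tokens=17)
finish("req-0")
~~~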

format.py (+6)

import os
import glob

for filename in glob.glob('./**/*.py', recursive=True):
    print(filename)
    os.system(f"autopep8 --max-line-length 140 --in-place --aggressive --aggressive {filename}")

lightllm/__init__.py

(empty file)

lightllm/common/__init__.py

(empty file)

lightllm/common/configs/__init__.py

(empty file)

lightllm/common/configs/config.py (+6)

_DEFAULT_MAX_INPUT_ADD_OUTPUT_LEN = 1024 * 5

setting = {
    "max_req_total_len": _DEFAULT_MAX_INPUT_ADD_OUTPUT_LEN
}

lightllm/common/gqa_mem_manager.py (+5)

from .mem_manager import MemoryManager

class GQAMemoryManager(MemoryManager):
    def __init__(self, size, dtype, key_value_head_num, head_dim, layer_num):
        super().__init__(size, dtype, key_value_head_num, head_dim, layer_num)
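
Compared with the base `MemoryManager`, the only difference here is that the KV buffers are sized by `key_value_head_num` rather than the full query-head count: grouped-query attention (GQA) models share each key/value head across several query heads, so the per-layer KV cache shrinks proportionally.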

lightllm/common/infer_utils.py (+8)

def init_bloc(b_loc, b_seq_len, max_len_in_batch, alloc_mem_index):
    start_index = 0
    b_seq_len_numpy = b_seq_len.cpu().numpy()
    for i in range(len(b_seq_len)):
        cur_seq_len = b_seq_len_numpy[i]
        b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + cur_seq_len]
        start_index += cur_seq_len
    return
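
A small illustration of what `init_bloc` does (right-aligning each request's allocated slot indices inside its row of `b_loc`); CPU tensors are used here purely so the sketch runs anywhere, while the server itself passes CUDA tensors.

~~~python
import torch
from lightllm.common.infer_utils import init_bloc

b_loc = torch.zeros((2, 4), dtype=torch.long)          # (batch, max_len_in_batch)
b_seq_len = torch.tensor([2, 3])                        # current length of each request
alloc_mem_index = torch.tensor([10, 11, 12, 13, 14])    # slots handed out by the memory manager

init_bloc(b_loc, b_seq_len, max_len_in_batch=4, alloc_mem_index=alloc_mem_index)
print(b_loc)
# tensor([[ 0,  0, 10, 11],
#         [ 0, 12, 13, 14]])
~~~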

lightllm/common/mem_manager.py (+67)

import torch


class MemoryManager:
    def __init__(self, size, dtype, head_num, head_dim, layer_num):
        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
        self._mem_cum_sum = torch.empty((size,), dtype=torch.int32, device="cuda")
        self.indexes = torch.arange(0, size, dtype=torch.long, device="cuda")
        self.can_use_mem_size = size
        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]
        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)]

    @torch.no_grad()
    def alloc(self, need_size):
        if need_size > self.can_use_mem_size:
            print(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
            return None

        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
        select_index = torch.logical_and(self._mem_cum_sum <= need_size, self.mem_state == 1)
        select_index = self.indexes[select_index]
        self.mem_state[select_index] = 0
        self.can_use_mem_size -= len(select_index)
        return select_index

    @torch.no_grad()
    def alloc_contiguous(self, need_size):
        if need_size > self.can_use_mem_size:
            print(f'warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}')
            return None

        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self._mem_cum_sum)
        sum_size = len(self._mem_cum_sum)
        loc_sums = self._mem_cum_sum[need_size - 1:] - self._mem_cum_sum[0:sum_size - need_size + 1] + self.mem_state[0:sum_size - need_size + 1]
        can_used_loc = self.indexes[0:sum_size - need_size + 1][loc_sums == need_size]
        if can_used_loc.shape[0] == 0:
            # print(f'warn no enough cache to contiguous need_size {need_size} left_size {self.can_use_mem_size}')
            return None
        start_loc = can_used_loc[0]
        select_index = self.indexes[start_loc : start_loc + need_size]

        self.mem_state[select_index] = 0
        self.can_use_mem_size -= len(select_index)
        start = start_loc.item()
        end = start + need_size
        return select_index, start, end

    @torch.no_grad()
    def free(self, free_index):
        """Return the token slots referenced by `free_index` (a 1-D index tensor) to the free pool."""
        self.can_use_mem_size += free_index.shape[0]
        self.mem_state[free_index] = 1
        if self.can_use_mem_size == len(self.mem_state):
            print(f"freed all gpu mem size {self.can_use_mem_size}")
        # print(f"free state {self.can_use_mem_size} all {len(self.mem_state)}")
        return

    @torch.no_grad()
    def free_all(self):
        self.can_use_mem_size = len(self.mem_state)
        self.mem_state[:] = 1
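
The `alloc_contiguous` search above is fully vectorized: it looks for a window of `need_size` slots whose free-state sum equals `need_size`, using a difference of cumulative sums. A small CPU-only sketch of that windowed-sum idea (the class itself keeps its state on CUDA, so this re-implements just the selection logic for illustration):

~~~python
import torch

# Free/used map for 8 token slots: 1 = free, 0 = used.
mem_state = torch.tensor([1, 0, 1, 1, 1, 0, 1, 1], dtype=torch.int32)
need_size = 3

cum = torch.cumsum(mem_state, dim=0)             # [1, 1, 2, 3, 4, 4, 5, 6]
n = len(cum)
# Sum of each length-3 window, computed as a difference of cumulative sums.
window_sums = cum[need_size - 1:] - cum[:n - need_size + 1] + mem_state[:n - need_size + 1]
starts = torch.arange(n - need_size + 1)[window_sums == need_size]
print(starts)                                    # tensor([2]) -> slots 2, 3, 4 form a free run
~~~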

lightllm/common/triton_kernel/__init__.py

(empty file)
