Skip to content

Commit ba6159c

Browse files
FrankLeeeeecanghuazyksir
authored
added abstraction for target model backend (#278)
* added abstraction for target model backend Co-authored-by: canghua <[email protected]> Co-authored-by: zyksir <[email protected]> * polish --------- Co-authored-by: canghua <[email protected]> Co-authored-by: zyksir <[email protected]>
1 parent c5549ba commit ba6159c

32 files changed

+1617
-709
lines changed

.github/workflows/test.yaml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,29 @@ jobs:
2424
- name: Checkout code
2525
uses: actions/checkout@v4
2626

27+
- name: Restore cache
28+
run: |
29+
if [ -d /github/home/sf ] && [ ! -z "$(ls -A /github/home/sf/)" ]; then
30+
cp -p -r /github/home/sf/* ./
31+
fi
32+
2733
- name: Install dependencies
34+
shell: bash
2835
run: |
29-
pip install -e .
36+
# if sf venv does not exist, create it
37+
if [ ! -d sf ]; then
38+
uv venv sf -p 3.11
39+
fi
40+
source sf/bin/activate
41+
uv pip install -v . --prerelease=allow
3042
3143
- name: Run test
3244
timeout-minutes: 30
45+
shell: bash
3346
run: |
47+
source sf/bin/activate
3448
python -m unittest discover -s ./tests -p "test_*.py" -v
49+
50+
- name: Save cache
51+
run: |
52+
cp -p -r sf /github/home/

examples/run_llama3_eagle3_online.sh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ torchrun \
1111
$ROOT_DIR/scripts/train_eagle3_online.py \
1212
--target-model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
1313
--draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
14-
--train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
14+
--train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
1515
--output-dir $ROOT_DIR/outputs/llama3-8b-eagle3 \
1616
--num-epochs 2 \
1717
--batch-size 2 \
1818
--learning-rate 1e-4 \
1919
--max-length 2048 \
2020
--chat-template llama3 \
2121
--cache-dir $ROOT_DIR/cache \
22-
--attention-backend flex_attention
22+
--attention-backend sdpa \
23+
--target-model-backend sglang \
24+
--log-interval 10

0 commit comments

Comments
 (0)