Commit 54ba27e

Add example to fine-tune StarCoder for chat-based applications (#17)
* Add StarChat files
* Clean up
* Fix readme
* Tweak
* Clean up
* Final polish
* Fix steps
* Final tweaks
* Delete dead code
* Fix typo
1 parent bdc636e commit 54ba27e

11 files changed, +1277 -0 lines changed

.gitignore

+163
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data/
wandb/

README.md

+4
@@ -5,6 +5,10 @@
# What is this about?
💫 StarCoder is a language model (LM) trained on source code and natural language text. Its training data incorporates more than 80 different programming languages as well as text extracted from GitHub issues and commits and from notebooks. This repository showcases how we get an overview of this LM's capabilities.

+# News
+
+* **May 9, 2023:** We've fine-tuned StarCoder to act as a helpful coding assistant 💬! Check out the `chat/` directory for the training code and play with the model [here](https://huggingface.co/spaces/HuggingFaceH4/starchat-playground).
+
# Disclaimer

Before you can use the model, go to `hf.co/bigcode/starcoder` and accept the agreement. Also make sure you are logged into the Hugging Face Hub with:

chat/README.md

+111
@@ -0,0 +1,111 @@
# Fine-tuning StarCoder for chat-based applications

This is a fully working example of fine-tuning `StarCoder` on a corpus of multi-turn dialogues to create a coding assistant that is chatty and helpful. The resulting model is quite good at generating code for plots and other programming tasks. For example, given the prompt:

```
Draw me a map of the world using geopandas. Make it so that only Germany and Spain are colored red.
```

the model will produce:

```python
import pandas as pd
from matplotlib import pyplot as plt
import geopandas as gpd
from shapely.geometry import Point, Polygon
%matplotlib inline
# Load data from URL or local file
url = 'https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson'
df = gpd.read_file(url)

# Create plot
fig, ax = plt.subplots()
ax.set_aspect('equal')
world = gpd.GeoDataFrame({'name': ['World'],
                          'geometry': [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]})
germany = df[df['ADMIN'] == "Germany"]
spain = df[df['ADMIN'] == "Spain"]
other = df[(df['ADMIN']!= "Germany") & (df['ADMIN']!= "Spain")]
world.plot(color='lightgrey', edgecolor='white', ax=ax)
germany.plot(color="red", ax=ax)
spain.plot(color="red", ax=ax)
other.plot(color="skyblue", ax=ax)
plt.title("European Countries")
plt.show()
```

Check out our [blog post](https://huggingface.co/blog/starchat-alpha) for more details.

## Getting started

To run the `train.py` script, first create a Python virtual environment using e.g. Conda:

```shell
conda create -n chat python=3.10 && conda activate chat
```

Next, install PyTorch v1.13.1. Since this is hardware-dependent, we direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/previous-versions/#v1131) for this step. Then install the rest of the project dependencies:

```shell
pip install -r requirements.txt
```

You'll also need to be logged into your Hugging Face account. To do so, run:

```shell
huggingface-cli login
```

Finally, install Git LFS with:

```shell
sudo apt-get install git-lfs
```

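Before moving on, it can help to confirm that the environment matches the steps above. The following is a minimal sketch and not part of this repository: the helper names `installed_versions` and `environment_report` are hypothetical, and the default package list is an assumption, since the contents of `requirements.txt` are not shown here. It uses only the standard library to report which packages are installed and whether `git-lfs` is on the `PATH`.

```python
import shutil
import sys
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def environment_report(packages=("torch", "transformers", "deepspeed")):
    """Collect a small summary of the current environment."""
    # Assumption: this package list is a guess; adjust it to match requirements.txt.
    return {
        "python": "{}.{}".format(*sys.version_info[:2]),
        "packages": installed_versions(packages),
        "git_lfs_on_path": shutil.which("git-lfs") is not None,
    }


if __name__ == "__main__":
    print(environment_report())
```

A `None` version or `git_lfs_on_path: False` in the report points to a step that still needs to be completed.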
## Prepare your dataset

For training and inference, we use _dialogue templates_ to format each message in a conversation. For example, a typical dialogue between a human user and AI assistant takes the form:

```json
{
    "messages": [
        {
            "content": "Is it possible to imagine a society without law?",
            "role": "user"
        },
        {
            "content": "It is difficult to imagine a society that is able to be maintained without any semblance of Law.",
            "role": "assistant"
        },
        {
            "content": "It seems like you consider the absence of law equal to the absence of anything that could guide the behaviour of the individual.",
            "role": "user"
        },
        {
            "content": "You are correct that there are other factors that can guide behavior in a society and play a role in shaping individuals' behavior and interactions with each other. However, even in societies where these factors are present, laws still serve an important role in maintaining social order and resolving conflicts.",
            "role": "assistant"
        }
    ]
}
```

Make sure you convert your dataset to this schema; in particular, it needs a `messages` column like the one above. You can adjust the model, dataset, and hyperparameters in the `config.yaml` file.

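If your data starts out as plain prompt/response pairs, the conversion to the `messages` schema might look like the sketch below. This is illustrative only and is not taken from the training code: `turns_to_messages` and `validate_example` are hypothetical helpers, and the checks (exactly two keys per message, strict user/assistant alternation starting with the user) are assumptions based on the example above rather than documented requirements.

```python
import json

VALID_ROLES = ("user", "assistant")


def turns_to_messages(turns):
    """Convert (user_text, assistant_text) pairs into the `messages` format."""
    messages = []
    for user_text, assistant_text in turns:
        messages.append({"content": user_text, "role": "user"})
        messages.append({"content": assistant_text, "role": "assistant"})
    return {"messages": messages}


def validate_example(example):
    """Raise ValueError if an example does not look like the schema above."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        raise ValueError("example needs a non-empty `messages` list")
    for index, message in enumerate(messages):
        if set(message) != {"content", "role"}:
            raise ValueError(f"message {index} must have exactly `content` and `role`")
        if message["role"] not in VALID_ROLES:
            raise ValueError(f"message {index} has unknown role {message['role']!r}")
        # Assumption: turns strictly alternate, starting with the user.
        expected = VALID_ROLES[index % 2]
        if message["role"] != expected:
            raise ValueError(f"message {index} should come from {expected!r}")
    return example


if __name__ == "__main__":
    example = turns_to_messages([("How do I reverse a list in Python?", "Use `my_list[::-1]`.")])
    print(json.dumps(validate_example(example), indent=2))
```

Running a validation pass like this over the whole dataset before training is a cheap way to catch malformed rows early.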
## Launch training

We use DeepSpeed ZeRO-3 to shard the model and optimizer across 8 x A100 (80GB) GPUs. To fine-tune, run:

```
TRANSFORMERS_VERBOSITY=info torchrun --nproc_per_node=8 train.py config.yaml --deepspeed=deepspeed_z3_config_bf16.json
```

By default, this will save the model checkpoint in the `data/` directory and also push it to the Hugging Face Hub.

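To see why sharding matters at this scale, a back-of-the-envelope estimate can help. The sketch below rests on two assumptions that are not stated in this README: that StarCoder has roughly 15.5 billion parameters, and that mixed-precision Adam training keeps about 16 bytes of model state per parameter (2 for half-precision weights, 2 for gradients, 12 for full-precision optimizer states). It ignores activations and temporary buffers, so treat the result as a lower bound rather than a measurement.

```python
# Assumption: 16 bytes of model state per parameter for mixed-precision Adam
# (2 for half-precision weights, 2 for gradients, 12 for fp32 optimizer states).
BYTES_PER_PARAM = 16


def model_state_gb(num_params, num_gpus=1, bytes_per_param=BYTES_PER_PARAM):
    """Approximate per-GPU memory (GB) for weights, gradients and optimizer states."""
    return num_params * bytes_per_param / num_gpus / 1e9


if __name__ == "__main__":
    starcoder_params = 15.5e9  # assumption: approximate StarCoder parameter count
    print(f"unsharded: {model_state_gb(starcoder_params):.0f} GB")
    print(f"sharded over 8 GPUs: {model_state_gb(starcoder_params, 8):.0f} GB per GPU")
```

Under these assumptions the unsharded model state comes to about 248 GB, well beyond a single 80GB card, while an even split across 8 GPUs is about 31 GB each.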
## Generate samples

To generate a few coding examples from your model, run:

```shell
python generate.py --model_id path/to/your/model
```

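As noted under "Prepare your dataset", dialogue templates are used at inference time as well as for training. As a rough illustration of what that involves, the sketch below flattens a `messages` list into a single prompt string. The `DialogueTemplate` class here is hypothetical, and the token strings are placeholder assumptions; the template and special tokens actually used by `train.py` and `generate.py` may differ, so check the code in `chat/` before relying on them.

```python
from dataclasses import dataclass


@dataclass
class DialogueTemplate:
    # All token strings are placeholder assumptions, not confirmed by this README.
    user_token: str = "<|user|>"
    assistant_token: str = "<|assistant|>"
    end_token: str = "<|end|>"

    def render(self, messages, add_generation_prompt=True):
        """Flatten a `messages` list into one prompt string."""
        lines = []
        for message in messages:
            token = self.user_token if message["role"] == "user" else self.assistant_token
            lines.append(f"{token}\n{message['content']}{self.end_token}")
        if add_generation_prompt:
            # End with the assistant token so the model writes the next reply.
            lines.append(self.assistant_token)
        return "\n".join(lines)


if __name__ == "__main__":
    template = DialogueTemplate()
    print(template.render([{"content": "Write a function that adds two numbers.", "role": "user"}]))
```

Keeping the same template for training and generation matters because the model only learns to stop and hand over the turn at the markers it saw during fine-tuning.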
