# Fine-tuning StarCoder for chat-based applications

This is a fully working example of fine-tuning `StarCoder` on a corpus of multi-turn dialogues to create a coding assistant that is chatty and helpful. The resulting model is quite good at generating code for plots and other programming tasks. For example, given the prompt:

```
Draw me a map of the world using geopandas. Make it so that only Germany and Spain are colored red.
```

the model will produce:

```python
import pandas as pd
from matplotlib import pyplot as plt
import geopandas as gpd
from shapely.geometry import Point, Polygon
%matplotlib inline
# Load data from URL or local file
url = 'https://raw.githubusercontent.com/datasets/geo-countries/master/data/countries.geojson'
df = gpd.read_file(url)

# Create plot
fig, ax = plt.subplots()
ax.set_aspect('equal')
world = gpd.GeoDataFrame({'name': ['World'],
                          'geometry': [Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])]})
germany = df[df['ADMIN'] == "Germany"]
spain = df[df['ADMIN'] == "Spain"]
other = df[(df['ADMIN']!= "Germany") & (df['ADMIN']!= "Spain")]
world.plot(color='lightgrey', edgecolor='white', ax=ax)
germany.plot(color="red", ax=ax)
spain.plot(color="red", ax=ax)
other.plot(color="skyblue", ax=ax)
plt.title("European Countries")
plt.show()
```

Check out our [blog post](https://huggingface.co/blog/starchat-alpha) for more details.

## Getting started

To run the `train.py` script, first create a Python virtual environment using e.g. Conda:

```shell
conda create -n chat python=3.10 && conda activate chat
```

Next, install PyTorch v1.13.1. Since this is hardware-dependent, we direct you to the [PyTorch Installation Page](https://pytorch.org/get-started/previous-versions/#v1131) for this step. Then install the rest of the project dependencies:

```shell
pip install -r requirements.txt
```

You'll also need to be logged into your Hugging Face account. To do so, run:

```shell
huggingface-cli login
```

Finally, install Git LFS with:

```shell
sudo apt-get install git-lfs
```

## Prepare your dataset

For training and inference, we use _dialogue templates_ to format each message in a conversation. For example, a typical dialogue between a human user and AI assistant takes the form:

```json
{
  "messages": [
    {
      "content": "Is it possible to imagine a society without law?",
      "role": "user"
    },
    {
      "content": "It is difficult to imagine a society that is able to be maintained without any semblance of Law.",
      "role": "assistant"
    },
    {
      "content": "It seems like you consider the absence of law equal to the absence of anything that could guide the behaviour of the individual.",
      "role": "user"
    },
    {
      "content": "You are correct that there are other factors that can guide behavior in a society and play a role in shaping individuals' behavior and interactions with each other. However, even in societies where these factors are present, laws still serve an important role in maintaining social order and resolving conflicts.",
      "role": "assistant"
    }
  ]
}
```

Make sure you convert your dataset according to this schema; in particular, you need to include a `messages` column like the one above. You can adjust the model, dataset, and hyperparameters in the `config.yaml` file.
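
If your data is not already in this shape, a small script with the `datasets` library can reshape it. The sketch below is only an illustration: it assumes a hypothetical source dataset with `prompt` and `response` columns, so adapt the dataset id and column names to your own data.

```python
from datasets import load_dataset

# Hypothetical source dataset with "prompt" and "response" columns (placeholder id)
raw = load_dataset("your-username/your-dataset", split="train")

def to_messages(example):
    # Wrap each prompt/response pair as one user turn and one assistant turn
    return {
        "messages": [
            {"content": example["prompt"], "role": "user"},
            {"content": example["response"], "role": "assistant"},
        ]
    }

# Keep only the new `messages` column, as expected by the dialogue templates
formatted = raw.map(to_messages, remove_columns=raw.column_names)
formatted.push_to_hub("your-username/your-dataset-chat-format")  # optional
```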

## Launch training

We use DeepSpeed ZeRO-3 to shard the model and optimizer states across 8 x A100 (80GB) GPUs. To fine-tune, run:

```shell
TRANSFORMERS_VERBOSITY=info torchrun --nproc_per_node=8 train.py config.yaml --deepspeed=deepspeed_z3_config_bf16.json
```

By default, this will save the model checkpoint in the `data/` directory and also push it to the Hugging Face Hub.

## Generate samples

To generate a few coding examples from your model, run:

```shell
python generate.py --model_id path/to/your/model
```

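If you prefer to prototype outside the helper script, here is a minimal sketch using 🤗 Transformers. The checkpoint path is a placeholder, and the special tokens in the prompt are an assumption: use the same dialogue template your model was actually trained with.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: local path or Hub id of your fine-tuned checkpoint
checkpoint = "path/to/your/model"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

# The prompt must follow the dialogue template used during training;
# the tokens below are illustrative only.
prompt = "<|user|>\nDraw me a scatter plot with matplotlib.<|end|>\n<|assistant|>\n"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.2, top_p=0.95)

# Print only the newly generated tokens
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```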