-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy path7_fine_tune.py
102 lines (81 loc) · 2.58 KB
/
7_fine_tune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import tempfile
import torch
from ultralytics import YOLO
import fiftyone as fo
"""
This script shows how to fine-tune a YOLO classification model on the data we worked with. We do not run it during
the workshops because training would take too long; it is included so you can study it later.
Training results are written to the fine-tuning-yolo/train directory, relative to the current working directory.
"""
# FiftyOne dataset that main() loads and splits for training.
DATASET_NAME = "training_data"

# Training defaults used by train_classifier.
DEFAULT_MODEL_SIZE = "m"  # YOLO11 checkpoint size: one of n, s, m, l, x
DEFAULT_IMAGE_SIZE = 640  # passed to ultralytics as ``imgsz``
DEFAULT_EPOCHS = 10

# Ultralytics ``project`` directory that main() writes training runs into.
PROJECT_NAME = "fine-tuning-yolo"
def get_torch_device():
    """Return the preferred torch device: CUDA if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def _split_dataset(dataset_name, test_fraction, seed):
    """Load ``dataset_name`` and tag a fresh random train/test split on it.

    Args:
        dataset_name: name of an existing FiftyOne dataset.
        test_fraction: fraction of samples (0 < f < 1) to tag "test".
        seed: optional random seed for ``take`` so the split is reproducible.

    Returns:
        ``(train_view, test_view)``: views matching the "train"/"test" tags.
    """
    dataset = fo.load_dataset(dataset_name)
    # Tags are stored on the samples themselves, so clear the ones left by an
    # earlier run. Otherwise a rerun tags a *new* random subset as "test"
    # while the old "train" tags remain, and samples end up in both splits.
    dataset.untag_samples(["train", "test"])
    # take() wants a sample count; truncate the fractional size to an int.
    num_test = int(test_fraction * len(dataset))
    dataset.take(num_test, seed=seed).tag_samples("test")
    dataset.match_tags("test", bool=False).tag_samples("train")
    return dataset.match_tags("train"), dataset.match_tags("test")


def _export_splits(splits, gt_field):
    """Export each split as an image-classification directory tree.

    Args:
        splits: mapping of split name (used as the subdirectory name) to the
            FiftyOne collection to export.
        gt_field: label field holding the classification labels.

    Returns:
        Path of a new temporary directory containing one subdirectory per
        split; media are symlinked rather than copied.
    """
    data_dir = tempfile.mkdtemp()
    for split_name, view in splits.items():
        split_dir = os.path.join(data_dir, split_name)
        os.makedirs(split_dir)
        view.export(
            export_dir=split_dir,
            dataset_type=fo.types.ImageClassificationDirectoryTree,
            label_field=gt_field,
            export_media="symlink",
        )
    return data_dir


def train_classifier(
    dataset_name=None,
    model_size=DEFAULT_MODEL_SIZE,
    image_size=DEFAULT_IMAGE_SIZE,
    epochs=DEFAULT_EPOCHS,
    project_name="mislabel_confidence_noise",
    gt_field="ground_truth",
    train_split=None,
    test_split=None,
    test_fraction=0.2,
    seed=None,
    **kwargs
):
    """Fine-tune a pretrained YOLO11 classification model on FiftyOne data.

    Either pass ``dataset_name`` (the dataset is loaded and randomly split,
    tagging its samples "train"/"test"), or pass ``train_split`` and
    ``test_split`` collections directly. The test split is also used as the
    validation split.

    Args:
        dataset_name: FiftyOne dataset to load and split. Takes precedence
            over ``train_split``/``test_split`` when given.
        model_size: YOLO11 size, one of "n", "s", "m", "l", "x"; ``None``
            selects "s".
        image_size: training image size (ultralytics ``imgsz``).
        epochs: number of training epochs.
        project_name: ultralytics ``project`` directory for run outputs.
        gt_field: label field holding the ground-truth classifications.
        train_split: collection to train on when ``dataset_name`` is not given.
        test_split: collection used for val/test when ``dataset_name`` is not
            given.
        test_fraction: fraction of ``dataset_name`` tagged "test" (0 < f < 1).
        seed: optional seed for the random split of ``dataset_name``.
        **kwargs: extra keyword arguments forwarded to ``model.train``; they
            override the defaults set here (e.g. ``batch``).

    Returns:
        The trained ultralytics ``YOLO`` model.

    Raises:
        ValueError: for an unknown ``model_size``, an out-of-range
            ``test_fraction``, or when neither ``dataset_name`` nor both
            ``train_split`` and ``test_split`` are provided.
    """
    # Validate cheap arguments first, before the dataset gets re-tagged.
    if model_size is None:
        model_size = "s"
    elif model_size not in ["n", "s", "m", "l", "x"]:
        raise ValueError("model_size must be one of ['n', 's', 'm', 'l', 'x']")

    if dataset_name:
        if not 0 < test_fraction < 1:
            raise ValueError("test_fraction must be strictly between 0 and 1")
        train, test = _split_dataset(dataset_name, test_fraction, seed)
    elif train_split is None or test_split is None:
        raise ValueError(
            "provide dataset_name, or both train_split and test_split"
        )
    else:
        train, test = train_split, test_split

    # One exported subfolder per split; the held-out split is reused for both
    # "val" and "test".
    data_dir = _export_splits(
        {"train": train, "val": test, "test": test}, gt_field
    )

    # Load a pretrained YOLO11 classification checkpoint.
    model = YOLO(f"yolo11{model_size}-cls.pt")

    train_args = {
        "data": data_dir,
        "epochs": epochs,
        "imgsz": image_size,
        "device": get_torch_device(),
        "batch": 16,
        "project": project_name,
        "exist_ok": True,  # allow output to overwrite previous runs
    }
    train_args.update(kwargs)  # caller-supplied options take precedence
    model.train(**train_args)
    # NOTE: data_dir is not deleted here; remove it once the exported splits
    # are no longer needed.
    return model
def main():
    """Fine-tune on the workshop dataset, writing runs under PROJECT_NAME."""
    settings = dict(dataset_name=DATASET_NAME, project_name=PROJECT_NAME)
    train_classifier(**settings)


if __name__ == "__main__":
    # Start training only when this file is executed as a script.
    main()