Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roboflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from roboflow.models import CLIPModel, GazeModel # noqa: F401
from roboflow.util.general import write_line

__version__ = "1.2.10"
__version__ = "1.2.11"


def check_key(api_key, model, notebook, num_retries=0):
Expand Down
7 changes: 5 additions & 2 deletions roboflow/adapters/rfapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import urllib
from typing import List, Optional
from typing import Dict, List, Optional, Union

import requests
from requests.exceptions import RequestException
Expand Down Expand Up @@ -58,6 +58,7 @@ def start_version_training(
speed: Optional[str] = None,
checkpoint: Optional[str] = None,
model_type: Optional[str] = None,
epochs: Optional[int] = None,
):
"""
Start a training job for a specific version.
Expand All @@ -66,14 +67,16 @@ def start_version_training(
"""
url = f"{API_URL}/{workspace_url}/{project_url}/{version}/train?api_key={api_key}&nocache=true"

data = {}
data: Dict[str, Union[str, int]] = {}
if speed is not None:
data["speed"] = speed
if checkpoint is not None:
data["checkpoint"] = checkpoint
if model_type is not None:
# API expects camelCase
data["modelType"] = model_type
if epochs is not None:
data["epochs"] = epochs

response = requests.post(url, json=data)
if not response.ok:
Expand Down
30 changes: 20 additions & 10 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,18 @@ def export(self, model_format=None) -> bool | None:
else:
raise RuntimeError(f"Unexpected export {export_info}")

def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
def train(
self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False, epochs=None
) -> InferenceModel:
"""
Ask the Roboflow API to train a previously exported version's dataset.

Args:
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
checkpoint: A string representing the checkpoint to use while training
plot: Whether to plot the training results. Default is `False`.
epochs: Number of epochs to train the model
plot_in_notebook: Whether to plot the training results. Default is `False`.

Returns:
An instance of the trained model class
Expand Down Expand Up @@ -336,6 +339,7 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
speed=payload_speed,
checkpoint=payload_checkpoint,
model_type=payload_model_type,
epochs=epochs,
)

status = "training"
Expand Down Expand Up @@ -385,15 +389,15 @@ def live_plot(epochs, mAP, loss, title=""):
write_line(line="Training failed")
break

epochs: Union[np.ndarray, list]
epoch_ids: Union[np.ndarray, list]
mAP: Union[np.ndarray, list]
loss: Union[np.ndarray, list]

if "roboflow-train" in models.keys():
import numpy as np

# training has started
epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
epoch_ids = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
mAP = np.array([float(epoch["mAP"]) for epoch in models["roboflow-train"]["epochs"]])
loss = np.array(
[
Expand All @@ -410,23 +414,29 @@ def live_plot(epochs, mAP, loss, title=""):
num_machine_spin_dots = ["."]
title = "Training Machine Spinning Up" + "".join(num_machine_spin_dots)

epochs = []
epoch_ids = []
mAP = []
loss = []

if (len(epochs) > len(previous_epochs)) or (len(epochs) == 0):
if (len(epoch_ids) > len(previous_epochs)) or (len(epoch_ids) == 0):
if plot_in_notebook:
live_plot(epochs, mAP, loss, title)
live_plot(epoch_ids, mAP, loss, title)
else:
if len(epochs) > 0:
if len(epoch_ids) > 0:
title = (
title + ": Epoch: " + str(epochs[-1]) + " mAP: " + str(mAP[-1]) + " loss: " + str(loss[-1])
title
+ ": Epoch: "
+ str(epoch_ids[-1])
+ " mAP: "
+ str(mAP[-1])
+ " loss: "
+ str(loss[-1])
)
if not first_graph_write:
write_line(title)
first_graph_write = True

previous_epochs = copy.deepcopy(epochs)
previous_epochs = copy.deepcopy(epoch_ids)

time.sleep(5)

Expand Down