Skip to content

Commit c83c455

Browse files
authored
Merge pull request #369 from roboflow/uploadrfdetr
upload rfdetr
2 parents 7476e72 + 8cd79a5 commit c83c455

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

roboflow/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.1.60"
18+
__version__ = "1.1.61"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/util/model_processor.py

+77
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _get_processor_function(model_type: str) -> Callable:
2727
"paligemma",
2828
"paligemma2",
2929
"florence-2",
30+
"rfdetr",
3031
]
3132

3233
if not any(supported_model in model_type for supported_model in supported_models):
@@ -57,6 +58,9 @@ def _get_processor_function(model_type: str) -> Callable:
5758
if "yolonas" in model_type:
5859
return _process_yolonas
5960

61+
if "rfdetr" in model_type:
62+
return _process_rfdetr
63+
6064
return _process_yolo
6165

6266

@@ -220,6 +224,79 @@ def _process_yolo(model_type: str, model_path: str, filename: str) -> str:
220224
return zip_file_name
221225

222226

227+
def _process_rfdetr(model_type: str, model_path: str, filename: str) -> str:
228+
_supported_types = ["rfdetr-base", "rfdetr-large"]
229+
if model_type not in _supported_types:
230+
raise ValueError(f"Model type {model_type} not supported. Supported types are {_supported_types}")
231+
232+
if not os.path.exists(model_path):
233+
raise FileNotFoundError(f"Model path {model_path} does not exist.")
234+
235+
model_files = os.listdir(model_path)
236+
pt_file = next((f for f in model_files if f.endswith(".pt") or f.endswith(".pth")), None)
237+
238+
if pt_file is None:
239+
raise RuntimeError("No .pt or .pth model file found in the provided path")
240+
241+
get_classnames_txt_for_rfdetr(model_path, pt_file)
242+
243+
# Copy the .pt file to weights.pt if not already named weights.pt
244+
if pt_file != "weights.pt":
245+
shutil.copy(os.path.join(model_path, pt_file), os.path.join(model_path, "weights.pt"))
246+
247+
required_files = ["weights.pt"]
248+
249+
optional_files = ["results.csv", "results.png", "model_artifacts.json", "class_names.txt"]
250+
251+
zip_file_name = "roboflow_deploy.zip"
252+
with zipfile.ZipFile(os.path.join(model_path, zip_file_name), "w") as zipMe:
253+
for file in required_files:
254+
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
255+
256+
for file in optional_files:
257+
if os.path.exists(os.path.join(model_path, file)):
258+
zipMe.write(os.path.join(model_path, file), arcname=file, compress_type=zipfile.ZIP_DEFLATED)
259+
260+
return zip_file_name
261+
262+
263+
def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str):
264+
class_names_path = os.path.join(model_path, "class_names.txt")
265+
if os.path.exists(class_names_path):
266+
maybe_prepend_dummy_class(class_names_path)
267+
return class_names_path
268+
269+
import torch
270+
271+
model = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False)
272+
args = vars(model["args"])
273+
if "class_names" in args:
274+
with open(class_names_path, "w") as f:
275+
for class_name in args["class_names"]:
276+
f.write(class_name + "\n")
277+
maybe_prepend_dummy_class(class_names_path)
278+
return class_names_path
279+
280+
raise FileNotFoundError(
281+
f"No class_names.txt file found in model path {model_path}.\n"
282+
f"This should only happen on rfdetr models trained before version 1.1.0.\n"
283+
f"Please re-train your model with the latest version of the rfdetr library, or\n"
284+
f"please create a class_names.txt file in the model path with the class names\n"
285+
f"in new lines in the order of the classes in the model.\n"
286+
)
287+
288+
289+
def maybe_prepend_dummy_class(class_name_file: str):
290+
with open(class_name_file) as f:
291+
class_names = f.readlines()
292+
293+
dummy_class = "background_class83422\n"
294+
if dummy_class not in class_names:
295+
class_names.insert(0, dummy_class)
296+
with open(class_name_file, "w") as f:
297+
f.writelines(class_names)
298+
299+
223300
def _process_huggingface(
224301
model_type: str, model_path: str, filename: str = "fine-tuned-paligemma-3b-pt-224.f16.npz"
225302
) -> str:

0 commit comments

Comments
 (0)