Skip to content

Commit b901208

Browse files
Merge branch 'main' into size-measurement-docs-change
2 parents 7643f01 + ec8c138 commit b901208

File tree

10 files changed

+367
-12
lines changed

10 files changed

+367
-12
lines changed

inference/core/interfaces/http/http_api.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -250,6 +250,11 @@
250250
EXECUTION_ID_HEADER = None
251251

252252

253+
def get_content_type(request: Request) -> str:
    """Return the bare, lowercase media type of the request's Content-Type header.

    Header parameters (everything after the first ``;``, e.g. ``charset=utf-8``)
    and surrounding whitespace are dropped. The value is lowercased because
    media types are case-insensitive, while callers compare the result against
    lowercase literals such as ``"application/json"``.

    Args:
        request: Incoming request; only ``request.headers.get`` is used.

    Returns:
        The media type, or an empty string when the header is absent.
    """
    content_type = request.headers.get("content-type", "")
    return content_type.split(";")[0].strip().lower()
256+
257+
253258
class LambdaMiddleware(BaseHTTPMiddleware):
254259
async def dispatch(self, request, call_next):
255260
response = await call_next(request)
@@ -457,7 +462,7 @@ async def check_authorization_serverless(request: Request, call_next):
457462
skip_check = True
458463

459464
elif (
460-
request.headers.get("content-type", None) == "application/json"
465+
get_content_type(request) == "application/json"
461466
and int(request.headers.get("content-length", 0)) > 0
462467
):
463468
json_params = await request.json()
@@ -484,7 +489,7 @@ def _unauthorized_response(msg):
484489
api_key = req_params.get("api_key", None)
485490
if (
486491
api_key is None
487-
and request.headers.get("content-type", None) == "application/json"
492+
and get_content_type(request) == "application/json"
488493
and int(request.headers.get("content-length", 0)) > 0
489494
):
490495
# have to use try/except here, because some legacy endpoints abuse the Content-Type header but don't actually receive JSON
@@ -544,7 +549,7 @@ def _unauthorized_response(msg):
544549
api_key = req_params.get("api_key", None)
545550
if (
546551
api_key is None
547-
and request.headers.get("content-type", None) == "application/json"
552+
and get_content_type(request) == "application/json"
548553
and int(request.headers.get("content-length", 0)) > 0
549554
):
550555
# have to use try/except here, because some legacy endpoints abuse the Content-Type header but don't actually receive JSON

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,12 @@ class RegistryEntry:
5757
module_name="inference_exp.models.yolov7.yolov7_instance_segmentation_trt",
5858
class_name="YOLOv7ForInstanceSegmentationTRT",
5959
),
60+
("yolov8", CLASSIFICATION_TASK, BackendType.ONNX): RegistryEntry(
61+
model_class=LazyClass(
62+
module_name="inference_exp.models.yolov8.yolov8_classification_onnx",
63+
class_name="YOLOv8ForClassificationOnnx",
64+
),
65+
),
6066
("yolov8", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
6167
model_class=LazyClass(
6268
module_name="inference_exp.models.yolov8.yolov8_object_detection_onnx",
@@ -137,6 +143,12 @@ class RegistryEntry:
137143
module_name="inference_exp.models.yolov10.yolov10_object_detection_trt",
138144
class_name="YOLOv10ForObjectDetectionTRT",
139145
),
146+
("yolov11", CLASSIFICATION_TASK, BackendType.ONNX): RegistryEntry(
147+
model_class=LazyClass(
148+
module_name="inference_exp.models.yolov11.yolov11_onnx",
149+
class_name="YOLOv11ForClassificationOnnx",
150+
),
151+
),
140152
("yolov11", OBJECT_DETECTION_TASK, BackendType.ONNX): RegistryEntry(
141153
model_class=LazyClass(
142154
module_name="inference_exp.models.yolov11.yolov11_onnx",

inference_experimental/inference_exp/models/common/roboflow/post_processing.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -263,20 +263,18 @@ def rescale_key_points_detections(
263263
dtype=image_detections.dtype,
264264
device=image_detections.device,
265265
).repeat(key_points_slots_in_prediction)
266-
image_detections[:, 5 + num_classes :].sub_(key_points_offsets)
266+
image_detections[:, 6:].sub_(key_points_offsets)
267267
key_points_scale = torch.as_tensor(
268268
[metadata.scale_width, metadata.scale_height, 1.0],
269269
dtype=image_detections.dtype,
270270
device=image_detections.device,
271271
).repeat(key_points_slots_in_prediction)
272-
image_detections[:, 5 + num_classes :].div_(key_points_scale)
272+
image_detections[:, 6:].div_(key_points_scale)
273273
if (
274274
metadata.static_crop_offset.offset_x != 0
275275
or metadata.static_crop_offset.offset_y != 0
276276
):
277-
static_crop_offset_length = (
278-
image_detections.shape[1] - 5 - num_classes
279-
) // 3
277+
static_crop_offset_length = (image_detections.shape[1] - 6) // 3
280278
static_crop_offsets = torch.as_tensor(
281279
[
282280
metadata.static_crop_offset.offset_x,
@@ -287,7 +285,7 @@ def rescale_key_points_detections(
287285
dtype=image_detections.dtype,
288286
device=image_detections.device,
289287
)
290-
image_detections[:, 5 + num_classes :].add_(static_crop_offsets)
288+
image_detections[:, 6:].add_(static_crop_offsets)
291289
static_crop_offsets = torch.as_tensor(
292290
[
293291
metadata.static_crop_offset.offset_x,

inference_experimental/inference_exp/models/yolov11/yolov11_onnx.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
from inference_exp.models.yolov8.yolov8_classification_onnx import (
2+
YOLOv8ForClassificationOnnx,
3+
)
14
from inference_exp.models.yolov8.yolov8_instance_segmentation_onnx import (
25
YOLOv8ForInstanceSegmentationOnnx,
36
)
@@ -19,3 +22,7 @@ class YOLOv11ForInstanceSegmentationOnnx(YOLOv8ForInstanceSegmentationOnnx):
1922

2023
class YOLOv11ForForKeyPointsDetectionOnnx(YOLOv8ForKeyPointsDetectionOnnx):
    """YOLOv11 key-points detection on ONNX; adds nothing on top of the YOLOv8 class."""
25+
26+
27+
class YOLOv11ForClassificationOnnx(YOLOv8ForClassificationOnnx):
    """YOLOv11 classification on ONNX; adds nothing on top of the YOLOv8 class."""
inference_experimental/inference_exp/models/yolov8/yolov8_classification_onnx.py

Lines changed: 178 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,178 @@
1+
from threading import Lock
2+
from typing import List, Optional, Tuple, Union
3+
4+
import numpy as np
5+
import torch
6+
from inference_exp import ClassificationModel, ClassificationPrediction
7+
from inference_exp.configuration import DEFAULT_DEVICE
8+
from inference_exp.entities import ColorFormat
9+
from inference_exp.errors import (
10+
CorruptedModelPackageError,
11+
EnvironmentConfigurationError,
12+
MissingDependencyError,
13+
)
14+
from inference_exp.models.base.types import PreprocessedInputs
15+
from inference_exp.models.common.model_packages import get_model_package_contents
16+
from inference_exp.models.common.onnx import (
17+
run_session_with_batch_size_limit,
18+
set_execution_provider_defaults,
19+
)
20+
from inference_exp.models.common.roboflow.model_packages import (
21+
InferenceConfig,
22+
ResizeMode,
23+
parse_class_names_file,
24+
parse_inference_config,
25+
)
26+
from inference_exp.models.common.roboflow.pre_processing import (
27+
pre_process_network_input,
28+
)
29+
from inference_exp.utils.onnx_introspection import get_selected_onnx_execution_providers
30+
31+
# `onnxruntime` is an optional dependency: translate a failed import into the
# project's MissingDependencyError so the user is told which extras to install.
try:
    import onnxruntime
except ImportError as import_error:
    # The message names the model defined in this module (YOLOv8 classification);
    # it previously referred to a ResNet model.
    raise MissingDependencyError(
        message="Could not import YOLOv8 classification model with ONNX backend - this error means that some additional dependencies "
        "are not installed in the environment. If you run the `inference-exp` library directly in your Python "
        "program, make sure the following extras of the package are installed: \n"
        "\t* `onnx-cpu` - when you wish to use library with CPU support only\n"
        "\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
        "\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
        "\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
        "If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        "You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error
46+
47+
48+
class YOLOv8ForClassificationOnnx(ClassificationModel[torch.Tensor, torch.Tensor]):
    """YOLOv8 classification model executed through an ONNX Runtime session."""

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
        default_onnx_trt_options: bool = True,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "YOLOv8ForClassificationOnnx":
        """Create the model from a model package directory.

        The package must provide `class_names.txt`, `inference_config.json`
        and `weights.onnx`.

        Args:
            model_name_or_path: Passed on as the model package directory.
            onnx_execution_providers: Execution providers for the session; when
                `None`, `get_selected_onnx_execution_providers()` is consulted.
            default_onnx_trt_options: Forwarded to `set_execution_provider_defaults`.
            device: Device forwarded to the provider defaults and stored on the model.
            **kwargs: Accepted but not used here.

        Returns:
            A ready-to-use model instance.

        Raises:
            EnvironmentConfigurationError: When no execution provider is available.
            CorruptedModelPackageError: When the configured post-processing type
                is not "softmax".
        """
        providers = onnx_execution_providers
        if providers is None:
            providers = get_selected_onnx_execution_providers()
        if not providers:
            raise EnvironmentConfigurationError(
                message="Could not initialize model - selected backend is ONNX which requires execution provider to "
                "be specified - explicitly in `from_pretrained(...)` method or via env variable "
                "`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run model locally - adjust your setup, otherwise "
                "contact the platform support.",
                help_url="https://todo",
            )
        providers = set_execution_provider_defaults(
            providers=providers,
            model_package_path=model_name_or_path,
            device=device,
            default_onnx_trt_options=default_onnx_trt_options,
        )
        required_files = ["class_names.txt", "inference_config.json", "weights.onnx"]
        package_files = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=required_files,
        )
        class_names = parse_class_names_file(
            class_names_path=package_files["class_names.txt"]
        )
        supported_resize_modes = {
            ResizeMode.STRETCH_TO,
            ResizeMode.LETTERBOX,
            ResizeMode.CENTER_CROP,
            ResizeMode.LETTERBOX_REFLECT_EDGES,
        }
        inference_config = parse_inference_config(
            config_path=package_files["inference_config.json"],
            allowed_resize_modes=supported_resize_modes,
        )
        if inference_config.post_processing.type != "softmax":
            raise CorruptedModelPackageError(
                message="Expected Softmax to be the post-processing",
                help_url="https://todo",
            )
        session = onnxruntime.InferenceSession(
            path_or_bytes=package_files["weights.onnx"],
            providers=providers,
        )
        first_input = session.get_inputs()[0]
        declared_batch_size = first_input.shape[0]
        # A string in the batch dimension is treated as "no fixed batch size"
        # (presumably a symbolic / dynamic dimension - confirm against ONNX Runtime docs).
        input_batch_size = (
            None if isinstance(declared_batch_size, str) else declared_batch_size
        )
        return cls(
            session=session,
            input_name=first_input.name,
            inference_config=inference_config,
            class_names=class_names,
            device=device,
            input_batch_size=input_batch_size,
        )

    def __init__(
        self,
        session: onnxruntime.InferenceSession,
        input_name: str,
        inference_config: InferenceConfig,
        class_names: List[str],
        device: torch.device,
        input_batch_size: Optional[int],
    ):
        """Store the session, its input name and the package configuration.

        Args:
            session: ONNX Runtime session used by `forward`.
            input_name: Name of the session input the images are fed to.
            inference_config: Parsed inference configuration of the package.
            class_names: Class names exposed through `class_names`.
            device: Target device for pre-processed inputs.
            input_batch_size: Batch size used as both lower and upper limit in
                `forward`, or `None` when it is not fixed.
        """
        # Guards session execution in `forward`.
        self._session_thread_lock = Lock()
        self._session = session
        self._input_name = input_name
        self._input_batch_size = input_batch_size
        self._inference_config = inference_config
        self._class_names = class_names
        self._device = device

    @property
    def class_names(self) -> List[str]:
        """Class names loaded from the model package."""
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Turn input images into the network input.

        Delegates to `pre_process_network_input` using the package's
        pre-processing and network-input configuration.

        Args:
            images: One image or a list of images, as tensors or numpy arrays.
            input_color_format: Forwarded to the pre-processing helper.
            image_size: Forwarded to the helper as `image_size_wh`.
            **kwargs: Accepted but not used here.

        Returns:
            The first element of the helper's result.
        """
        config = self._inference_config
        pre_processing_result = pre_process_network_input(
            images=images,
            image_pre_processing=config.image_pre_processing,
            network_input=config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )
        # Only the first element is handed on; the remaining elements are not needed here.
        return pre_processing_result[0]

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        """Run the ONNX session on pre-processed images.

        The session call is made while holding the instance lock.

        Args:
            pre_processed_images: Output of `pre_process`.
            **kwargs: Accepted but not used here.

        Returns:
            The first output of the session run.
        """
        session_inputs = {self._input_name: pre_processed_images}
        batch_size = self._input_batch_size
        with self._session_thread_lock:
            session_outputs = run_session_with_batch_size_limit(
                session=self._session,
                inputs=session_inputs,
                min_batch_size=batch_size,
                max_batch_size=batch_size,
            )
            return session_outputs[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        **kwargs,
    ) -> ClassificationPrediction:
        """Convert raw model outputs into a classification prediction.

        When the configuration marks post-processing as fused, the outputs are
        used as confidences directly; otherwise softmax over the last dimension
        is applied first.

        Args:
            model_results: Output of `forward`.
            **kwargs: Accepted but not used here.

        Returns:
            Prediction holding the confidences and their argmax over the last
            dimension as `class_id`.
        """
        post_processing = self._inference_config.post_processing
        scores = (
            model_results
            if post_processing.fused
            else torch.nn.functional.softmax(model_results, dim=-1)
        )
        return ClassificationPrediction(
            class_id=scores.argmax(dim=-1),
            confidence=scores,
        )

inference_experimental/inference_exp/models/yolov8/yolov8_key_points_detection_onnx.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -233,7 +233,9 @@ def post_process(
233233
confidence=result[:, 4],
234234
)
235235
)
236-
key_points_reshaped = result[:, 6:].view(result.shape[0], -1, 3)
236+
key_points_reshaped = result[:, 6:].view(
237+
result.shape[0], self._key_points_slots_in_prediction, 3
238+
)
237239
xy = key_points_reshaped[:, :, :2]
238240
confidence = key_points_reshaped[:, :, 2]
239241
key_points_classes_for_instance_class = (

inference_experimental/inference_exp/models/yolov8/yolov8_key_points_detection_torch_script.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,9 @@ def post_process(
191191
confidence=result[:, 4],
192192
)
193193
)
194-
key_points_reshaped = result[:, 6:].view(result.shape[0], -1, 3)
194+
key_points_reshaped = result[:, 6:].view(
195+
result.shape[0], self._key_points_slots_in_prediction, 3
196+
)
195197
xy = key_points_reshaped[:, :, :2]
196198
confidence = key_points_reshaped[:, :, 2]
197199
key_points_classes_for_instance_class = (

inference_experimental/inference_exp/models/yolov8/yolov8_key_points_detection_trt.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,9 @@ def post_process(
257257
confidence=result[:, 4],
258258
)
259259
)
260-
key_points_reshaped = result[:, 6:].view(result.shape[0], -1, 3)
260+
key_points_reshaped = result[:, 6:].view(
261+
result.shape[0], self._key_points_slots_in_prediction, 3
262+
)
261263
xy = key_points_reshaped[:, :, :2]
262264
confidence = key_points_reshaped[:, :, 2]
263265
key_points_classes_for_instance_class = (

inference_experimental/tests/integration_tests/models/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,8 @@
118118
YOLOV8N_POSE_TORCHSCRIPT_STATIC_NMS_FUSED_CENTER_CROP_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8n-pose-torchscript-static-nms-fused-center-crop.zip"
119119
YOLOV8N_POSE_TORCHSCRIPT_STATIC_NMS_FUSED_STATIC_CROP_CENTER_CROP_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8n-pose-torchscript-static-nms-fused-static-crop-center-crop.zip"
120120

121+
YOLOV8_CLS_ONNX_PACKAGE_URL = "https://storage.googleapis.com/roboflow-tests-assets/rf-platform-models/yolov8-cls-onnx-static-bs.zip"
122+
121123

122124
@pytest.fixture(scope="module")
123125
def original_clip_download_dir() -> str:
@@ -942,3 +944,11 @@ def yolov8n_pose_torchscript_static_nms_fused_static_crop_center_crop_package()
942944
model_package_zip_url=YOLOV8N_POSE_TORCHSCRIPT_STATIC_NMS_FUSED_STATIC_CROP_CENTER_CROP_PACKAGE_URL,
943945
package_name="yolov8n-pose-torchscript-static-nms-fused-static-crop-center-crop",
944946
)
947+
948+
949+
@pytest.fixture(scope="module")
def yolov8_cls_static_bs_onnx_package() -> str:
    """Module-scoped fixture providing the YOLOv8 classification ONNX test package.

    Returns whatever `download_model_package` yields for the package zip
    (annotated as `str`; presumably a local path - confirm against the helper).
    """
    package_url = YOLOV8_CLS_ONNX_PACKAGE_URL
    package_name = "yolov8-cls-static-onnx"
    return download_model_package(
        model_package_zip_url=package_url,
        package_name=package_name,
    )

0 commit comments

Comments
 (0)