Skip to content

Commit af65c81

Browse files
committed
fix sam
1 parent 99b76ce commit af65c81

File tree

4 files changed

+71
-12
lines changed

4 files changed

+71
-12
lines changed

label_studio_ml/examples/segment_anything_model/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
5959

6060
COPY . .
6161

62-
# Add ONNX model
63-
RUN python3 onnxconverter.py
62+
# Add ONNX model (skip if it fails - not critical for basic functionality)
63+
RUN python3 onnxconverter.py || echo "Warning: ONNX conversion failed, but continuing build"
6464

6565
EXPOSE 9090
6666

label_studio_ml/examples/segment_anything_model/onnxconverter.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,24 @@ def convert(checkpoint_path):
5555
dynamic_axes=dynamic_axes,
5656
)
5757

58-
quantize_dynamic(
59-
model_input=onnx_model_path,
60-
model_output=onnx_model_quantized_path,
61-
optimize_model=True,
62-
per_channel=False,
63-
reduce_range=False,
64-
weight_type=QuantType.QUInt8,
65-
)
58+
# Newer versions of onnxruntime don't have optimize_model parameter
59+
try:
60+
quantize_dynamic(
61+
model_input=onnx_model_path,
62+
model_output=onnx_model_quantized_path,
63+
optimize_model=True,
64+
per_channel=False,
65+
reduce_range=False,
66+
weight_type=QuantType.QUInt8,
67+
)
68+
except TypeError:
69+
# Fallback for newer onnxruntime versions without optimize_model
70+
quantize_dynamic(
71+
model_input=onnx_model_path,
72+
model_output=onnx_model_quantized_path,
73+
per_channel=False,
74+
reduce_range=False,
75+
weight_type=QuantType.QUInt8,
76+
)
6677

6778
convert(VITH_CHECKPOINT)

label_studio_ml/examples/segment_anything_model/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
numpy>=2,<2.3.0
22
label_studio_converter
33
opencv-python-headless>=4.12.0,<5.0.0
4-
onnxruntime==1.15.1
5-
onnx==1.12.0
4+
onnxruntime>=1.18.0
5+
onnx>=1.15.0
66
torch==2.0.1
77
torchvision==0.15.2
88
gunicorn==22.0.0

label_studio_ml/examples/segment_anything_model/sam_predictor.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,52 @@
99
from label_studio_ml.utils import InMemoryLRUDictCache
1010
from label_studio_sdk._extensions.label_studio_tools.core.utils.io import get_local_path
1111

12+
# Monkey-patch torch.as_tensor to handle numpy 2.x compatibility
13+
_original_as_tensor = torch.as_tensor
14+
def _patched_as_tensor(data, dtype=None, device=None):
15+
"""Patched version of torch.as_tensor that handles numpy 2.x compatibility"""
16+
if isinstance(data, np.ndarray):
17+
# For numpy 2.x compatibility, ensure arrays are properly converted
18+
if dtype is None and data.dtype == np.uint8:
19+
# Explicitly convert uint8 arrays
20+
return _original_as_tensor(data.copy(), dtype=torch.uint8, device=device)
21+
elif dtype is not None:
22+
# If dtype is specified, ensure the array is compatible
23+
if data.dtype == np.float32 and dtype == torch.int:
24+
# Convert float32 to int properly
25+
return _original_as_tensor(data.astype(np.int32), dtype=dtype, device=device)
26+
return _original_as_tensor(data, dtype=dtype, device=device)
27+
torch.as_tensor = _patched_as_tensor
28+
29+
# Also patch tensor.numpy() to handle numpy 2.x compatibility
30+
_original_tensor_numpy = torch.Tensor.numpy
31+
def _patched_tensor_numpy(self, *args, **kwargs):
32+
"""Patched version of tensor.numpy() that handles numpy 2.x compatibility"""
33+
try:
34+
return _original_tensor_numpy(self, *args, **kwargs)
35+
except RuntimeError as e:
36+
if "Numpy is not available" in str(e):
37+
# Fallback: manually convert tensor to numpy array
38+
# This is a workaround for numpy 2.x compatibility issues
39+
arr = self.detach().cpu().contiguous()
40+
# Convert to list first, then to numpy array
41+
if arr.dim() == 0:
42+
return np.array(arr.item())
43+
else:
44+
# Map torch dtypes to numpy dtypes
45+
dtype_map = {
46+
torch.float32: np.float32,
47+
torch.float64: np.float64,
48+
torch.int32: np.int32,
49+
torch.int64: np.int64,
50+
torch.uint8: np.uint8,
51+
torch.bool: np.bool_,
52+
}
53+
np_dtype = dtype_map.get(arr.dtype, None)
54+
return np.array(arr.tolist(), dtype=np_dtype)
55+
raise
56+
torch.Tensor.numpy = _patched_tensor_numpy
57+
1258
logger = logging.getLogger(__name__)
1359
_MODELS_DIR = pathlib.Path(__file__).parent / "models"
1460

@@ -91,6 +137,8 @@ def set_image(self, img_path, calculate_embeddings=True, task=None):
91137
)
92138
image = cv2.imread(image_path)
93139
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
140+
# Ensure image is contiguous and properly typed for numpy 2.x compatibility
141+
image = np.ascontiguousarray(image, dtype=np.uint8)
94142
self.predictor.set_image(image)
95143
payload = {'image_shape': image.shape[:2]}
96144
logger.debug(f'Finished set_image({img_path}) in `IN_MEM_CACHE`: image shape {image.shape[:2]}')

0 commit comments

Comments
 (0)