Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions inference/core/utils/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,29 @@ def masks2poly(masks: np.ndarray) -> List[np.ndarray]:
list: A list of segments, where each segment is obtained by converting the corresponding mask.
"""
segments = []
# Process per-mask to avoid allocating an entire N x H x W uint8 copy
append = segments.append # local var for faster loop
# Minimize np.any and m_uint8 construction cost, and reduce append/continue branchiness
for mask in masks:
dtype = mask.dtype
# Fast-path: bool -> zero-copy uint8 view
if mask.dtype == np.bool_:
if dtype == np.bool_:
m_uint8 = mask
if not m_uint8.flags.c_contiguous:
m_uint8 = np.ascontiguousarray(m_uint8)
m_uint8 = m_uint8.view(np.uint8)
elif mask.dtype == np.uint8:
# avoid unnecessary .view for C-contiguous bool: just multiply by 255 for OpenCV
if m_uint8.dtype != np.uint8:
m_uint8 = m_uint8.astype(np.uint8, copy=False)
elif dtype == np.uint8:
m_uint8 = mask if mask.flags.c_contiguous else np.ascontiguousarray(mask)
else:
# Fallback: threshold to bool then view as uint8 (may allocate once)
m_bool = mask > 0
if not m_bool.flags.c_contiguous:
m_bool = np.ascontiguousarray(m_bool)
m_uint8 = m_bool.view(np.uint8)
# Fallback: threshold to uint8 in one step (avoid bool then view)
m_uint8 = (mask > 0).astype(np.uint8)

# Quickly skip empty masks
if not np.any(m_uint8):
segments.append(np.zeros((0, 2), dtype=np.float32))
continue

segments.append(mask2poly(m_uint8))
# Use count_nonzero, faster than np.any for dense binary
if np.count_nonzero(m_uint8) == 0:
append(np.zeros((0, 2), dtype=np.float32))
else:
append(mask2poly(m_uint8))
return segments


Expand Down Expand Up @@ -107,14 +107,26 @@ def mask2poly(mask: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Contours represented as a float32 array.
"""
# cv2.findContours can return 2 or 3 values depending on OpenCV version
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if contours:
contours = np.array(
contours[np.array([len(x) for x in contours]).argmax()]
).reshape(-1, 2)
# argmax is O(N); avoid making full int32 array (len) if only one contour
if len(contours) == 1:
cnt = contours[0]
else:
# Use a generator to save a temp array, and only track the largest
max_len = 0
max_cnt = None
for c in contours:
l = len(c)
if l > max_len:
max_len = l
max_cnt = c
cnt = max_cnt
cnt_out = np.asarray(cnt, dtype=np.float32).reshape(-1, 2)
else:
contours = np.zeros((0, 2))
return contours.astype("float32")
cnt_out = np.zeros((0, 2), dtype=np.float32)
return cnt_out


def mask2multipoly(mask: np.ndarray) -> np.ndarray:
Expand Down