
Merge pull request #1788 from roboflow/feat/transformers-keypoints
feat: ✨ Add xyxy_to_xywh function and from_transformers method to KeyPoints class
SkalskiP authored Feb 17, 2025
2 parents d381f43 + c06d5f0 commit 63388fd
Showing 5 changed files with 158 additions and 2 deletions.
6 changes: 6 additions & 0 deletions docs/detection/utils.md
@@ -89,6 +89,12 @@ status: new

:::supervision.detection.utils.xywh_to_xyxy

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.xyxy_to_xywh">xyxy_to_xywh</a></h2>
</div>

:::supervision.detection.utils.xyxy_to_xywh

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.xcycwh_to_xyxy">xcycwh_to_xyxy</a></h2>
</div>
2 changes: 2 additions & 0 deletions supervision/__init__.py
@@ -78,6 +78,7 @@
xcycwh_to_xyxy,
xywh_to_xyxy,
xyxy_to_polygons,
xyxy_to_xywh,
)
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import (
@@ -226,4 +227,5 @@
"xcycwh_to_xyxy",
"xywh_to_xyxy",
"xyxy_to_polygons",
"xyxy_to_xywh",
]
37 changes: 37 additions & 0 deletions supervision/detection/utils.py
@@ -321,6 +321,43 @@ def xywh_to_xyxy(xywh: np.ndarray) -> np.ndarray:
    return xyxy


def xyxy_to_xywh(xyxy: np.ndarray) -> np.ndarray:
    """
    Converts bounding box coordinates from `(x_min, y_min, x_max, y_max)`
    format to `(x, y, width, height)` format.

    Args:
        xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each row
            corresponds to a bounding box in the format
            `(x_min, y_min, x_max, y_max)`.

    Returns:
        np.ndarray: A numpy array of shape `(N, 4)` where each row corresponds
            to a bounding box in the format `(x, y, width, height)`.

    Examples:
        ```python
        import numpy as np
        import supervision as sv

        xyxy = np.array([
            [10, 20, 40, 60],
            [15, 25, 50, 70]
        ])

        sv.xyxy_to_xywh(xyxy=xyxy)
        # array([
        #     [10, 20, 30, 40],
        #     [15, 25, 35, 45]
        # ])
        ```
    """
    xywh = xyxy.copy()
    xywh[:, 2] = xyxy[:, 2] - xyxy[:, 0]
    xywh[:, 3] = xyxy[:, 3] - xyxy[:, 1]
    return xywh


def xcycwh_to_xyxy(xcycwh: np.ndarray) -> np.ndarray:
"""
Converts bounding box coordinates from `(center_x, center_y, width, height)`
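Since the new `xyxy_to_xywh` is the exact inverse of the existing `xywh_to_xyxy`, a quick round trip makes a handy sanity check; a minimal sketch, assuming a `supervision` build that includes this commit:

```python
import numpy as np
import supervision as sv

# Boxes in (x_min, y_min, x_max, y_max) format, matching the docstring example.
xyxy = np.array([
    [10, 20, 40, 60],
    [15, 25, 50, 70],
])

# Convert to (x, y, width, height) and back; the round trip is lossless.
xywh = sv.xyxy_to_xywh(xyxy=xyxy)
assert np.array_equal(sv.xywh_to_xyxy(xywh=xywh), xyxy)
```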
91 changes: 89 additions & 2 deletions supervision/keypoint/core.py
@@ -61,6 +61,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
method, which accepts [MediaPipe](https://github.com/google-ai-edge/mediapipe)
pose result.
```python
import cv2
import mediapipe as mp
@@ -314,6 +315,7 @@ def from_mediapipe(
key_points = sv.KeyPoints.from_mediapipe(
face_landmarker_result, (image_width, image_height))
```
""" # noqa: E501 // docs
if hasattr(mediapipe_results, "pose_landmarks"):
results = mediapipe_results.pose_landmarks
@@ -473,7 +475,7 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints:
            A `sv.KeyPoints` object containing the keypoint coordinates,
            class IDs, class names, and confidences of each keypoint.
Example:
Examples:
```python
import cv2
import supervision as sv
@@ -510,6 +512,91 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints:
else:
return cls.empty()

    @classmethod
    def from_transformers(cls, transformers_results: Any) -> KeyPoints:
        """
        Create a `sv.KeyPoints` object from the
        [Transformers](https://github.com/huggingface/transformers) inference result.

        Args:
            transformers_results (Any): The pose estimation output of a
                Transformers model: a sequence of dictionaries holding
                `keypoints` and `scores` tensors for each detected instance.

        Returns:
            A `sv.KeyPoints` object containing the keypoint coordinates,
            class IDs, class names, and confidences of each keypoint.

        Examples:
            ```python
            from PIL import Image
            import supervision as sv
            import torch
            from transformers import (
                AutoProcessor,
                RTDetrForObjectDetection,
                VitPoseForPoseEstimation,
            )

            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
            image = Image.open(<SOURCE_IMAGE_PATH>)

            DETECTION_MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365"

            detection_processor = AutoProcessor.from_pretrained(
                DETECTION_MODEL_ID, use_fast=True)
            detection_model = RTDetrForObjectDetection.from_pretrained(
                DETECTION_MODEL_ID, device_map=DEVICE)

            inputs = detection_processor(images=image, return_tensors="pt").to(DEVICE)

            with torch.no_grad():
                outputs = detection_model(**inputs)

            target_size = torch.tensor([(image.height, image.width)])
            results = detection_processor.post_process_object_detection(
                outputs, target_sizes=target_size, threshold=0.3)
            detections = sv.Detections.from_transformers(results[0])
            boxes = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)

            POSE_ESTIMATION_MODEL_ID = "usyd-community/vitpose-base-simple"

            pose_estimation_processor = AutoProcessor.from_pretrained(
                POSE_ESTIMATION_MODEL_ID)
            pose_estimation_model = VitPoseForPoseEstimation.from_pretrained(
                POSE_ESTIMATION_MODEL_ID, device_map=DEVICE)

            inputs = pose_estimation_processor(
                image, boxes=[boxes], return_tensors="pt").to(DEVICE)

            with torch.no_grad():
                outputs = pose_estimation_model(**inputs)

            results = pose_estimation_processor.post_process_pose_estimation(
                outputs, boxes=[boxes])
            key_points = sv.KeyPoints.from_transformers(results[0])
            ```
        """ # noqa: E501 // docs

        if "keypoints" in transformers_results[0]:
            if transformers_results[0]["keypoints"].cpu().numpy().size == 0:
                return cls.empty()

            result_data = [
                (
                    result["keypoints"].cpu().numpy(),
                    result["scores"].cpu().numpy(),
                )
                for result in transformers_results
            ]

            xy, scores = zip(*result_data)

            return cls(
                xy=np.stack(xy).astype(np.float32),
                confidence=np.stack(scores).astype(np.float32),
                class_id=np.arange(len(xy)).astype(int),
            )
        else:
            return cls.empty()

def __getitem__(
self, index: Union[int, slice, List[int], np.ndarray, str]
) -> Union[KeyPoints, List, np.ndarray, None]:
@@ -639,7 +726,7 @@ def as_detections(
Returns:
detections (Detections): The converted detections object.
Example:
Examples:
```python
keypoints = sv.KeyPoints.from_inference(...)
detections = keypoints.as_detections()
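To eyeball the result of `from_transformers`, the returned keypoints can be drawn with supervision's keypoint annotators; a minimal sketch, assuming `image` and `key_points` from the docstring example above and the `sv.EdgeAnnotator` / `sv.VertexAnnotator` API:

```python
import numpy as np
import supervision as sv

# `image` and `key_points` come from the from_transformers docstring example.
scene = np.array(image)  # annotators operate on numpy array scenes

# Draw skeleton edges first, then the keypoint vertices on top.
edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=4)

annotated = edge_annotator.annotate(scene=scene.copy(), key_points=key_points)
annotated = vertex_annotator.annotate(scene=annotated, key_points=key_points)
```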
24 changes: 24 additions & 0 deletions test/detection/test_utils.py
@@ -21,6 +21,7 @@
scale_boxes,
xcycwh_to_xyxy,
xywh_to_xyxy,
xyxy_to_xywh,
)

TEST_MASK = np.zeros((1, 1000, 1000), dtype=bool)
@@ -1381,6 +1382,29 @@ def test_xywh_to_xyxy(xywh: np.ndarray, expected_result: np.ndarray) -> None:
np.testing.assert_array_equal(result, expected_result)


@pytest.mark.parametrize(
    "xyxy, expected_result",
    [
        (np.array([[10, 20, 40, 60]]), np.array([[10, 20, 30, 40]])),  # standard case
        (np.array([[0, 0, 0, 0]]), np.array([[0, 0, 0, 0]])),  # zero-size bounding box
        (
            np.array([[50, 50, 150, 150]]),
            np.array([[50, 50, 100, 100]]),
        ),  # large bounding box
        (
            np.array([[-10, -20, 20, 20]]),
            np.array([[-10, -20, 30, 40]]),
        ),  # negative coordinates
        (np.array([[50, 50, 50, 80]]), np.array([[50, 50, 0, 30]])),  # zero width
        (np.array([[50, 50, 70, 50]]), np.array([[50, 50, 20, 0]])),  # zero height
        (np.array([]).reshape(0, 4), np.array([]).reshape(0, 4)),  # empty array
    ],
)
def test_xyxy_to_xywh(xyxy: np.ndarray, expected_result: np.ndarray) -> None:
    result = xyxy_to_xywh(xyxy)
    np.testing.assert_array_equal(result, expected_result)
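Beyond the fixed cases above, the inverse relationship with `xywh_to_xyxy` could be pinned down directly; a hypothetical round-trip test in the same style, not part of this PR:

```python
@pytest.mark.parametrize(
    "xyxy",
    [
        np.array([[10, 20, 40, 60]]),
        np.array([[-10, -20, 20, 20], [50, 50, 150, 150]]),
        np.array([]).reshape(0, 4),
    ],
)
def test_xyxy_to_xywh_round_trip(xyxy: np.ndarray) -> None:
    # Converting to (x, y, w, h) and back should reproduce the input exactly.
    result = xywh_to_xyxy(xyxy_to_xywh(xyxy))
    np.testing.assert_array_equal(result, xyxy)
```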


@pytest.mark.parametrize(
"xcycwh, expected_result",
[
