Introduce Segment Anything 2 #8243

Open · wants to merge 12 commits into base: develop
1 change: 1 addition & 0 deletions README.md
@@ -187,6 +187,7 @@ up to 10x. Here is a list of the algorithms we support, and the platforms they c

| Name | Type | Framework | CPU | GPU |
| ------------------------------------------------------------------------------------------------------- | ---------- | ---------- | --- | --- |
| [Segment Anything 2.0](/serverless/pytorch/facebookresearch/sam2/nuclio/) | interactor | PyTorch | ✔️ | ✔️ |
| [Segment Anything](/serverless/pytorch/facebookresearch/sam/nuclio/) | interactor | PyTorch | ✔️ | ✔️ |
| [Deep Extreme Cut](/serverless/openvino/dextr/nuclio) | interactor | OpenVINO | ✔️ | |
| [Faster RCNN](/serverless/openvino/omz/public/faster_rcnn_inception_resnet_v2_atrous_coco/nuclio) | detector | OpenVINO | ✔️ | |
4 changes: 4 additions & 0 deletions changelog.d/20240731_000641_ruelj2.md
@@ -0,0 +1,4 @@
### Added

- Added support for Segment Anything 2.0 as a Nuclio serverless function, with full SAM2 support on both GPU and CPU.
(<https://github.com/cvat-ai/cvat/pull/8243>)
67 changes: 67 additions & 0 deletions serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml
@@ -0,0 +1,67 @@
# Copyright (C) 2023-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

metadata:
  name: pth-facebookresearch-sam2-vit-h
  namespace: cvat
  annotations:
    name: Segment Anything 2.0
    version: 2
    type: interactor
    spec:
    min_pos_points: 1
    min_neg_points: 0
    animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/hrnet_example.gif
    help_message: The interactor allows getting a mask of an object using at least one positive point and any number of negative points inside it

spec:
  description: Interactive object segmentation with Segment-Anything 2.0
  runtime: 'python:3.8'
  handler: main:handler
  eventTimeout: 30s

  build:
    image: cvat.pth.facebookresearch.sam2.vit_h:latest-gpu
    baseImage: pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
    directives:
      preCopy:
        # set NVIDIA container runtime settings
        - kind: ENV
          value: NVIDIA_VISIBLE_DEVICES=all
        - kind: ENV
          value: NVIDIA_DRIVER_CAPABILITIES=compute,utility
        # disable interactive frontend
        - kind: ENV
          value: DEBIAN_FRONTEND=noninteractive
        - kind: ENV
          value: TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9+PTX"
        # set workdir
        - kind: WORKDIR
          value: /opt/nuclio/sam2
        # install basic deps
        - kind: RUN
          value: apt-get update && apt-get -y install build-essential curl git
        # install sam2 code
        - kind: RUN
          value: pip install git+https://github.com/facebookresearch/segment-anything-2.git@main
        # download sam2 weights
        - kind: RUN
          value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
  triggers:
    myHttpTrigger:
      maxWorkers: 1
      kind: 'http'
      workerAvailabilityTimeoutMilliseconds: 10000
      attributes:
        maxRequestBodySize: 33554432 # 32MB
  resources:
    limits:
      nvidia.com/gpu: 1

  platform:
    attributes:
      restartPolicy:
        name: always
        maximumRetryCount: 3
      mountMode: volume
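
The GPU image depends on the NVIDIA container runtime exposing a device and on the SAM2 checkpoint landing in the working directory set above. A minimal sanity-check sketch (not part of this PR) that could be run inside the built image to verify both assumptions:

```python
# Sketch: run inside the built GPU image to confirm the environment matches
# function-gpu.yaml (CUDA device visible, checkpoint downloaded into WORKDIR).
import os
import torch

assert torch.cuda.is_available(), "no CUDA device visible; check NVIDIA_VISIBLE_DEVICES / the GPU resource limit"
print("GPU:", torch.cuda.get_device_name(0))

checkpoint = "/opt/nuclio/sam2/sam2_hiera_large.pt"
assert os.path.isfile(checkpoint), f"missing SAM2 checkpoint at {checkpoint}"
print("checkpoint size (MB):", os.path.getsize(checkpoint) // 2**20)
```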
60 changes: 60 additions & 0 deletions serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml
@@ -0,0 +1,60 @@
# Copyright (C) 2023-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

metadata:
  name: pth-facebookresearch-sam2-vit-h
  namespace: cvat
  annotations:
    name: Segment Anything 2.0
    version: 2
    type: interactor
    spec:
    min_pos_points: 1
    min_neg_points: 0
    animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/hrnet_example.gif
    help_message: The interactor allows getting a mask of an object using at least one positive point and any number of negative points inside it

spec:
  description: Interactive object segmentation with Segment-Anything 2.0
  runtime: 'python:3.8'
  handler: main:handler
  eventTimeout: 30s

  build:
    image: cvat.pth.facebookresearch.sam2.vit_h
    baseImage: ubuntu:22.04
    directives:
      preCopy:
        # disable interactive frontend
        - kind: ENV
          value: DEBIAN_FRONTEND=noninteractive
        # set workdir
        - kind: WORKDIR
          value: /opt/nuclio/sam2
        # install basic deps
        - kind: RUN
          value: apt-get update && apt-get -y install build-essential curl git python3 python3-pip ffmpeg libsm6 libxext6
        # install sam2 code
        - kind: RUN
          value: SAM2_BUILD_CUDA=0 pip install git+https://github.com/facebookresearch/segment-anything-2.git@main
        # download sam2 weights
        - kind: RUN
          value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
        # map pip3 and python3 to pip and python
        - kind: RUN
          value: ln -s /usr/bin/pip3 /usr/local/bin/pip && ln -s /usr/bin/python3 /usr/bin/python
  triggers:
    myHttpTrigger:
      maxWorkers: 2
      kind: 'http'
      workerAvailabilityTimeoutMilliseconds: 10000
      attributes:
        maxRequestBodySize: 33554432 # 32MB

  platform:
    attributes:
      restartPolicy:
        name: always
        maximumRetryCount: 3
      mountMode: volume
41 changes: 41 additions & 0 deletions serverless/pytorch/facebookresearch/sam2/nuclio/main.py
@@ -0,0 +1,41 @@
# Copyright (C) 2023-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import json
import base64
import io

from PIL import Image

from model_handler import ModelHandler

def init_context(context):
    model = ModelHandler()
    context.user_data.model = model
    context.logger.info("Init context...100%")

def handler(context, event):
    try:
        context.logger.info("call handler")
        data = event.body
        buf = io.BytesIO(base64.b64decode(data["image"]))
        image = Image.open(buf)
        image = image.convert("RGB")  # to make sure image comes in RGB
        pos_points = data["pos_points"]
        neg_points = data["neg_points"]

        mask = context.user_data.model.handle(image, pos_points, neg_points)

        return context.Response(
            body=json.dumps({'mask': mask.tolist()}),
            headers={},
            content_type='application/json',
            status_code=200
        )
    except Exception as e:
        context.logger.error(f"Error in handler: {str(e)}")
        return context.Response(
            body=json.dumps({'error': str(e)}),
            headers={},
            content_type='application/json',
            status_code=500
        )
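
The handler expects a JSON body with a base64-encoded image plus positive and negative click points, and returns the chosen mask as a nested list; the payload must fit within the 32 MB `maxRequestBodySize` set on the HTTP trigger. A hedged client sketch follows; the endpoint URL is a placeholder (Nuclio assigns the port), and in a regular CVAT deployment the CVAT server calls the function on the user's behalf rather than a client doing so directly:

```python
# Sketch: calling the deployed function directly over HTTP.
# The URL is a placeholder; CVAT normally invokes the function itself.
import base64
import requests  # assumed to be available on the client side

with open("example.jpg", "rb") as f:
    payload = {
        "image": base64.b64encode(f.read()).decode("ascii"),
        "pos_points": [[200, 150]],  # (x, y) clicks inside the object
        "neg_points": [],            # optional background clicks
    }

resp = requests.post("http://localhost:32768", json=payload, timeout=30)
resp.raise_for_status()
mask = resp.json()["mask"]  # nested list, one row per image row
print(len(mask), "x", len(mask[0]))
```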
35 changes: 35 additions & 0 deletions serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py
@@ -0,0 +1,35 @@
# Copyright (C) 2023-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

class ModelHandler:
    def __init__(self):
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            if torch.cuda.get_device_properties(0).major >= 8:
                # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        self.sam_checkpoint = "./sam2_hiera_large.pt"
        self.model_cfg = "sam2_hiera_l.yaml"
        self.predictor = SAM2ImagePredictor(build_sam2(self.model_cfg, self.sam_checkpoint, device=self.device))

    def handle(self, image, pos_points, neg_points):
        pos_points, neg_points = list(pos_points), list(neg_points)
        with torch.inference_mode():
            self.predictor.set_image(np.array(image))
            masks, scores, _ = self.predictor.predict(
                point_coords=np.array(pos_points + neg_points),
                point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
                multimask_output=True,
            )
        sorted_ind = np.argsort(scores)[::-1]
        best_mask = masks[sorted_ind][0]
        return best_mask
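
For local debugging outside Nuclio, the handler can be exercised directly. A minimal sketch, assuming the checkpoint and config referenced above are resolvable from the working directory and that `example.jpg` is any test image:

```python
# Sketch: exercising ModelHandler locally, outside Nuclio.
from PIL import Image
from model_handler import ModelHandler

handler = ModelHandler()
image = Image.open("example.jpg").convert("RGB")
pos_points = [[200, 150]]  # (x, y) clicks inside the object of interest
neg_points = []            # optional clicks on background regions to exclude
mask = handler.handle(image, pos_points, neg_points)
print(mask.shape)          # (H, W) mask for the highest-scoring prediction
```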