Introduce Segment Anything 2 #8243
base: develop
Conversation
Segment Anything 2.0 requires compiling a .cu file with nvcc at build time. Hence, a CUDA devel base image is required to build the Nuclio container.
Walkthrough

The recent update enhances the documentation and functionality of a serverless image segmentation service using the Segment Anything 2.0 model. Key changes include the addition of a new entry in the README.md, the introduction of configuration and processing scripts for serverless deployment, and improvements for GPU optimization. Overall, these changes streamline the integration of advanced segmentation capabilities, making it more accessible for developers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant HTTPTrigger
    participant MainHandler
    participant ModelHandler
    User->>HTTPTrigger: Send image and points
    HTTPTrigger->>MainHandler: Forward request
    MainHandler->>MainHandler: Initialize context
    MainHandler->>ModelHandler: Process image with points
    ModelHandler->>ModelHandler: Generate mask
    ModelHandler-->>MainHandler: Return mask
    MainHandler-->>HTTPTrigger: Send response with mask
    HTTPTrigger-->>User: Display result
```
Actionable comments posted: 3
Outside diff range, codebase verification and nitpick comments (2)
changelog.d/20240731_000641_ruelj2.md (1)
Line 3: Improve conciseness by removing redundant adverbs. The phrase "currently fully" can be simplified for clarity.

```diff
- This fork currently fully supports Segment Anything 2.0 on GPU with the large model.
+ This fork supports Segment Anything 2.0 on GPU with the large model.
```

Tools: LanguageTool

[style] ~3-~3: Using two adverbs in a row like this may make your sentence unnecessarily wordy. For clarity and conciseness, see if you can omit one of them.
Context: ... the integration with Nuclio. This fork currently fully supports Segment Anything 2.0 on GPU wi... (LY_LY)
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1)
Lines 10-15: Ensure GPU availability. The initialization of `SAM2ImagePredictor` assumes GPU availability without checking if CUDA is available. This can lead to issues if CUDA is not available. Consider modifying the initialization to handle cases where CUDA is not available.

serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py, line 15:

```python
self.predictor = SAM2ImagePredictor(build_sam2(self.model_cfg, self.sam_checkpoint, device="cuda"))
```
Analysis chain
Ensure GPU availability.
The initialization assumes GPU availability. Consider handling cases where CUDA is not available.
Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Description: Verify if `torch.cuda.is_available()` is always true in the codebase.
# Test: Search for the usage of `torch.cuda.is_available()`.
# Expect: Proper handling of cases where CUDA is not available.
rg --type python -A 5 $'torch.cuda.is_available()'
```

Length of output: 82

Script:

```bash
#!/bin/bash
# Description: Verify if `torch.cuda.is_available()` is always true in the codebase.
# Test: Search for the usage of `torch.cuda.is_available()`.
# Expect: Proper handling of cases where CUDA is not available.
rg -A 5 'torch.cuda.is_available()'
```

Length of output: 4749
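A minimal sketch of such a fallback, using a hypothetical `build_predictor` helper (not code from this PR) and the `sam2` package from the facebookresearch repository:

```python
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

def build_predictor(model_cfg: str, sam_checkpoint: str) -> SAM2ImagePredictor:
    # Prefer CUDA when available; otherwise fall back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return SAM2ImagePredictor(build_sam2(model_cfg, sam_checkpoint, device=device))
```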
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Files selected for processing (5)
- README.md (1 hunks)
- changelog.d/20240731_000641_ruelj2.md (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1 hunks)
Files skipped from review due to trivial changes (1)
- README.md
Additional context used
LanguageTool
changelog.d/20240731_000641_ruelj2.md
[style] ~3-~3: Using two adverbs in a row like this may make your sentence unnecessarily wordy. For clarity and conciseness, see if you can omit one of them.
Context: ... the integration with Nuclio. This fork currently fully supports Segment Anything 2.0 on GPU wi...(LY_LY)
Additional comments not posted (7)
serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1)
Lines 16-19: Verify compatibility with Ampere GPUs. Ensure that the device properties check for Ampere GPUs is correct and covers all necessary conditions.
Verification successful
The compatibility check for Ampere GPUs is correctly implemented.
The code correctly checks if the major version of the CUDA device properties is 8 or higher to identify Ampere GPUs and enables TensorFloat-32 (TF32) support accordingly.
serverless/pytorch/facebookresearch/sam2/nuclio/main.py:

- Line 16:

```python
if torch.cuda.get_device_properties(0).major >= 8:
```
Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Verify the compatibility check for Ampere GPUs.
# Test: Search for the usage of `torch.cuda.get_device_properties`.
# Expect: Proper handling of Ampere GPUs.
rg --type python -A 5 $'torch.cuda.get_device_properties'
```

Length of output: 87

Script:

```bash
#!/bin/bash
# Description: Verify the compatibility check for Ampere GPUs.
# Test: Search for the usage of `torch.cuda.get_device_properties`.
# Expect: Proper handling of Ampere GPUs.
rg -A 5 'torch.cuda.get_device_properties'
```

Length of output: 692
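For reference, a standalone sketch of the pattern being verified; the guard on `torch.cuda.is_available()` is an added assumption for hosts without a GPU:

```python
import torch

# Compute capability major >= 8 identifies Ampere (and newer) GPUs, which support TF32.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 in matmul kernels
    torch.backends.cudnn.allow_tf32 = True        # allow TF32 in cuDNN convolutions
```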
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (6)
Lines 5-17: Ensure the `min_pos_points` and `min_neg_points` values are correct. The metadata section defines the minimum number of positive and negative points required. Verify that these values align with the requirements of the Segment Anything 2.0 model.

Lines 18-22: Ensure the runtime and handler are correctly defined. The runtime is set to Python 3.8, and the handler is defined as `main:handler`. Verify that these values are correct and compatible with the Segment Anything 2.0 model.

Lines 24-50: Ensure the base image and build directives are correct. The base image is set to `pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel`. Verify that this image is compatible with the Segment Anything 2.0 model and the required CUDA version. Also check the installation of dependencies and weights: the build directives install dependencies and download weights, so ensure these steps are correctly defined and necessary for the Segment Anything 2.0 model.

Lines 51-57: Ensure the HTTP trigger attributes are correct. The HTTP trigger includes attributes like `maxWorkers` and `maxRequestBodySize`. Verify that these values are appropriate for the expected workload and data size.

Lines 58-60: Ensure the GPU resource limits are correct. The resource limits specify using one GPU. Verify that this is sufficient for the Segment Anything 2.0 model's requirements.

Lines 62-67: Ensure the platform attributes are correct. The platform section includes a restart policy and mount mode. Verify that these values are appropriate for the Nuclio function's deployment environment.
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py:

```python
def handle(self, image, pos_points, neg_points):
    pos_points, neg_points = list(pos_points), list(neg_points)
    with torch.inference_mode():
        self.predictor.set_image(np.array(image))
        masks, scores, logits = self.predictor.predict(
            point_coords=np.array(pos_points + neg_points),
            point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
            multimask_output=True,
        )
        sorted_ind = np.argsort(scores)[::-1]
        best_mask = masks[sorted_ind][0]
        return best_mask
```
Ensure proper error handling.
The `handle` method does not include error handling. Consider adding try-except blocks to handle potential errors during prediction.
```diff
-    def handle(self, image, pos_points, neg_points):
-        pos_points, neg_points = list(pos_points), list(neg_points)
-        with torch.inference_mode():
-            self.predictor.set_image(np.array(image))
-            masks, scores, logits = self.predictor.predict(
-                point_coords=np.array(pos_points + neg_points),
-                point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
-                multimask_output=True,
-            )
-            sorted_ind = np.argsort(scores)[::-1]
-            best_mask = masks[sorted_ind][0]
-            return best_mask
+    def handle(self, image, pos_points, neg_points):
+        try:
+            pos_points, neg_points = list(pos_points), list(neg_points)
+            with torch.inference_mode():
+                self.predictor.set_image(np.array(image))
+                masks, scores, logits = self.predictor.predict(
+                    point_coords=np.array(pos_points + neg_points),
+                    point_labels=np.array([1]*len(pos_points) + [0]*len(neg_points)),
+                    multimask_output=True,
+                )
+                sorted_ind = np.argsort(scores)[::-1]
+                best_mask = masks[sorted_ind][0]
+                return best_mask
+        except Exception as e:
+            # Handle or log the error as needed
+            raise RuntimeError("Error during prediction") from e
```
serverless/pytorch/facebookresearch/sam2/nuclio/main.py:

```python
def init_context(context):
    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
```
Ensure proper cleanup of autocast context.
The `torch.autocast` context is entered but never exited. Ensure proper cleanup to avoid potential issues.
```diff
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    context.user_data.autocast = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+    context.user_data.autocast.__enter__()
```
serverless/pytorch/facebookresearch/sam2/nuclio/main.py:

```python
def handler(context, event):
    context.logger.info("call handler")
    data = event.body
    buf = io.BytesIO(base64.b64decode(data["image"]))
    context.logger.info(f"data: {data}")
    image = Image.open(buf)
    image = image.convert("RGB")  # to make sure image comes in RGB
    pos_points = data["pos_points"]
    neg_points = data["neg_points"]

    mask = context.user_data.model.handle(image, pos_points, neg_points)

    return context.Response(
        body=json.dumps({ 'mask': mask.tolist() }),
        headers={},
        content_type='application/json',
        status_code=200
    )
```
Ensure proper error handling in handler function.
The `handler` function does not include error handling. Consider adding try-except blocks to handle potential errors during processing.
```diff
-    context.logger.info("call handler")
-    data = event.body
-    buf = io.BytesIO(base64.b64decode(data["image"]))
-    context.logger.info(f"data: {data}")
-    image = Image.open(buf)
-    image = image.convert("RGB")  # to make sure image comes in RGB
-    pos_points = data["pos_points"]
-    neg_points = data["neg_points"]
-
-    mask = context.user_data.model.handle(image, pos_points, neg_points)
-
-    return context.Response(
-        body=json.dumps({ 'mask': mask.tolist() }),
-        headers={},
-        content_type='application/json',
-        status_code=200
-    )
+    try:
+        context.logger.info("call handler")
+        data = event.body
+        buf = io.BytesIO(base64.b64decode(data["image"]))
+        context.logger.info(f"data: {data}")
+        image = Image.open(buf)
+        image = image.convert("RGB")  # to make sure image comes in RGB
+        pos_points = data["pos_points"]
+        neg_points = data["neg_points"]
+
+        mask = context.user_data.model.handle(image, pos_points, neg_points)
+
+        return context.Response(
+            body=json.dumps({ 'mask': mask.tolist() }),
+            headers={},
+            content_type='application/json',
+            status_code=200
+        )
+    except Exception as e:
+        context.logger.error(f"Error processing request: {e}")
+        return context.Response(
+            body=json.dumps({ 'error': str(e) }),
+            headers={},
+            content_type='application/json',
+            status_code=500
+        )
```
This looks really good! What would it take to also integrate the tracking capabilities of SAM 2? |
@HanClinto I think the best approach would be to return the SAM2 memory bank queue to the user or a DB. This way, we could ensure the SAM2 service is stateless. At the moment I don't know the overhead of doing so, but the article states that the memory bank is composed of "spatial feature maps" and "lightweight vectors for high-level semantic information". Transferring the spatial feature maps GPU -> CPU -> network might be a bottleneck here, depending on their size. I can help achieve this. The article is here
@KTXKIKI you're right. There is currently a big overhead related to a request being sent for each click. I underestimated the request bottleneck, especially for large images. I thought it could be viable, given that SAM2 inference is faster than SAM1's. I'll suggest an improvement tonight.
@KTXKIKI I won't be able to produce the solution tonight. It would require writing a new cvat_ui plugin to decode the SAM2 embeddings client-side using onnxruntime-web, just as was done for SAM1 (cvat-ui/plugins/sam/src/ts/index.tsx). It would also require exporting the SAM2 decoder in ONNX format. This thread is an excellent starting point: facebookresearch/sam2#3
I think we need the help of official CVAT personnel |
Hi @jeanchristopheruel, thanks for your great work. As far as I understand, you added SAM2 as an interactor tool, which works the same way SAM does. However, the biggest improvement of SAM2 is video tracking. Even if we somehow implement SAM2 as a tracker, the CVAT UI would require us to manually go to the next frame one by one, whereas the SAM2 video tracker can track the object over the whole video once the first frame and points are selected. Do you know if it's possible to merge that functionality with CVAT? Is that supported somewhere in the UI at all?
@ozangungor12 It is possible to integrate SAM2 for video tracking with its featured memory embeddings. It would require writing a new cvat_ui plugin to decode the SAM2 embeddings (encoded feature map & memory bank) client-side using onnxruntime-web. This would allow full serverless compatibility with Nuclio and ensure scalability (stateless) for cloud applications. Alternatively, for your own interest, you can modify the SAM2
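To make the flow under discussion concrete, here is a hedged sketch of the video-tracking API published in the facebookresearch/sam2 repository; the config name, checkpoint path, frame directory, and click coordinates are placeholders, and exact function names should be checked against that repo:

```python
import numpy as np
from sam2.build_sam import build_sam2_video_predictor

# Placeholder config/checkpoint names -- check the sam2 repo for the exact files.
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")

# init_state reads a directory of video frames and allocates the memory bank.
state = predictor.init_state(video_path="./video_frames")

# Seed the tracker with one positive click (x, y) on the first frame.
predictor.add_new_points(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=np.array([[210, 350]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

# Propagate masks through the remaining frames via the memory bank.
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
    masks = (mask_logits > 0.0).cpu().numpy()  # one boolean mask per tracked object
```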
Hi, this looks great. I noticed that there isn't a function.yaml for serverless without gpu. Any reason for that? |
I don't think SAM2 can work without a GPU. @jeanchristopheruel also said in the PR:

> Segment Anything 2.0 requires compiling a .cu file with nvcc at build time. Hence, a CUDA devel base image is required to build the Nuclio container.
See the thread at facebookresearch/sam2#155 |
Great, thanks for sending it! |
@ozangungor12, @bhack and @realtimshady1, I added support for cpu based on this. Thanks for the info. |
@jeanchristopheruel , thank you for the PR. Could you please look at linters? |
@jeanchristopheruel, we will be happy to merge the version of SAM2 into the CVAT open-source repository. It should be said that our team implemented an optimized version of SAM2: https://www.cvat.ai/post/meta-segment-anything-model-v2-is-now-available-in-cvat-ai. It will be available on SaaS for our paid customers and for Enterprise customers.
@nmanovic Thanks for your response. However, I’m disappointed to see that key advancements like SAM2 are becoming restricted to paid users. CVAT has always been a strong open-source tool, and limiting such features seems to move away from that spirit. I hope you will reconsider and keep these innovations accessible to the broader open-source community. |
@jeanchristopheruel , I would make all features open-source if it were possible. However, delivering new and innovative features to the open-source repository, such as the YOLOv8 format support (#8240), and addressing security issues and bugs, requires financial backing. To sustain this level of development, we rely on the support of paying customers. The best way to help CVAT continue thriving is by purchasing a SaaS subscription (https://www.cvat.ai/pricing/cloud) or becoming an Enterprise customer (https://www.cvat.ai/pricing/on-prem). It's worth noting that around 80% of our contributions go directly into the open-source repository. |
@nmanovic, I understand the need for financial support to sustain development, and I appreciate all the work your team does. However, history has shown that moving key features behind paywalls can sometimes alienate open-source communities. For example, when Elasticsearch restricted features, it led to the community forking it into OpenSearch. I hope CVAT can find a balance that supports both its financial needs and keeps innovation accessible to the open-source community, as that's what has made CVAT so valuable to so many. 😌 |
For those with stronger frontend skills, I recommend checking out this repository, which contains a complete frontend implementation of SAM2. I also attempted a frontend implementation, and you can find my initial trial here. It's still a work in progress, but feel free to take a look.
Another possible and very useful feature with models like SAM and SAM2 would be precision annotation in bounding boxes. The idea is to draw an imprecise bounding box around the object to be annotated. The bounding box is sent to the SAM or SAM2 model, which segments the main object from the bounding box it receives. Finally, the precise bounding box is recreated by taking the extremum coordinates at the top, left, bottom, and right, as sketched below. This would allow very quick and precise annotation without having to zoom in on the image (very useful for precise annotation of small objects, for example). In my free time, I made a Python script using this logic with SAM for precision annotation; it takes as input an annotation JSON (COCO format, I think) and outputs a JSON in the same format with the precise bounding boxes recalculated. I could make it available to you if necessary.
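As a rough illustration of the idea (not the script mentioned above), a hedged sketch of the refinement step, assuming a `SAM2ImagePredictor` named `predictor` that already has the image set and a hypothetical rough box:

```python
import numpy as np

# Rough, imprecise box drawn by the annotator: x_min, y_min, x_max, y_max (hypothetical values).
rough_box = np.array([50, 40, 220, 180])

# Segment the main object inside the imprecise box.
masks, scores, _ = predictor.predict(box=rough_box, multimask_output=False)
mask = masks[0].astype(bool)

# Recreate a tight box from the extremum coordinates of the mask.
ys, xs = np.where(mask)
tight_box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
```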
@Youho99 Very cool indeed! I suggest you create a separate issue to express your feature idea.🙂 |
Great! Do you have an estimate when the tracking aspect / video annotation aspect will be implemented? |
@tpsvoets The current PR adds support for an encoder-decoder SAM2 backend, which makes it slower than the SAM1 plugin due to the request overhead (the SAM1 plugin has the decoder running in the frontend). I can't give a timeline for SAM2 encoder-decoder frontend support, since I am not currently working on it. Maybe next year.
Motivation and context
Regarding #8230 and #8231, I added support for the Segment Anything 2.0 as a Nuclio serverless function. The original Facebook Research repository required some modifications (see pull request) to ease the integration with Nuclio.
Note [EDITED]: This supports both GPU and CPU.
EDIT: Additional efforts are required to enhance the annotation experience, making it faster by decoding the embeddings client-side with onnxruntime-web. See this comment.
How has this been tested?
The changes were tested on a machine with a GPU and CUDA installed. I verified that the Nuclio function deployed correctly and was able to perform segmentation tasks using Segment Anything 2.0. The integration was tested by running various segmentation tasks and ensuring the expected output was generated. Additionally, the function's performance was monitored to ensure it operated efficiently within the Nuclio environment.
Checklist
- I submit my changes into the `develop` branch
- I have increased versions of npm packages if it is necessary (cvat-canvas, cvat-core, cvat-data and cvat-ui)

License
- I submit my code changes under the same MIT License that covers the project. Feel free to contact the maintainers if that's a concern.