diff --git a/README.md b/README.md
index 2e58656be..e6a1d91f7 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,26 @@ See the examples notebooks on [using SAM with prompts](/notebooks/predictor_exam

+## Video Segmentation
+
+To use the new video segmentation feature, follow these steps:
+
+1. Import the necessary modules and initialize the SAM model and predictor:
+
+```
+from segment_anything import SamPredictor, sam_model_registry, segment_video
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+predictor = SamPredictor(sam)
+```
+
+2. Call the `segment_video` function with the path to your video file and the predictor:
+
+```
+segment_video("<path/to/video>", predictor)
+```
+
+This will read the video frames, segment objects using SAM, and display the segmented frames.
+
 ## ONNX Export
 
 SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
 
 ```
diff --git a/demo/src/App.tsx b/demo/src/App.tsx
index a42655356..7df74c7c5 100644
--- a/demo/src/App.tsx
+++ b/demo/src/App.tsx
@@ -1,9 +1,3 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-
-// This source code is licensed under the license found in the
-// LICENSE file in the root directory of this source tree.
-
 import { InferenceSession, Tensor } from "onnxruntime-web";
 import React, { useContext, useEffect, useState } from "react";
 import "./assets/scss/App.scss";
@@ -13,6 +7,7 @@ import { onnxMaskToImage } from "./components/helpers/maskUtils";
 import { modelData } from "./components/helpers/onnxModelAPI";
 import Stage from "./components/Stage";
 import AppContext from "./components/hooks/createContext";
+import { segment_video } from "segment_anything/predictor";
 const ort = require("onnxruntime-web");
 /* @ts-ignore */
 import npyjs from "npyjs";
@@ -30,6 +25,7 @@ const App = () => {
   } = useContext(AppContext)!;
   const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
   const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
+  const [isVideo, setIsVideo] = useState(false); // State variable to handle video input
 
   // The ONNX model expects the input to be rescaled to 1024.
   // The modelScale state variable keeps track of the scale values.
@@ -124,6 +120,16 @@ const App = () => {
     }
   };
 
+  const handleVideoUpload = async (videoFile: File) => {
+    setIsVideo(true);
+    const videoURL = URL.createObjectURL(videoFile);
+    const videoElement = document.createElement("video");
+    videoElement.src = videoURL;
+    videoElement.onloadeddata = () => {
+      segment_video(videoElement, model);
+    };
+  };
+
   return <Stage />;
 };
 
diff --git a/demo/src/components/Tool.tsx b/demo/src/components/Tool.tsx
index 31afbe5c6..9f58d081c 100644
--- a/demo/src/components/Tool.tsx
+++ b/demo/src/components/Tool.tsx
@@ -1,23 +1,18 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-
-// This source code is licensed under the license found in the
-// LICENSE file in the root directory of this source tree.
-
 import React, { useContext, useEffect, useState } from "react";
 import AppContext from "./hooks/createContext";
-import { ToolProps } from "./helpers/Interfaces";
 import * as _ from "underscore";
 
-const Tool = ({ handleMouseMove }: ToolProps) => {
+interface ToolProps {
+  handleMouseMove: (e: any) => void;
+  isVideo: boolean;
+}
+
+const Tool = ({ handleMouseMove, isVideo }: ToolProps) => {
   const {
     image: [image],
     maskImg: [maskImg, setMaskImg],
   } = useContext(AppContext)!;
 
-  // Determine if we should shrink or grow the images to match the
-  // width or the height of the page and setup a ResizeObserver to
-  // monitor changes in the size of the page
   const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
   const bodyEl = document.body;
   const fitToPage = () => {
@@ -44,27 +39,34 @@ const Tool = ({ handleMouseMove }: ToolProps) => {
   const imageClasses = "";
   const maskImageClasses = `absolute opacity-40 pointer-events-none`;
 
-  // Render the image and the predicted mask image on top
   return (
     <>
-      {image && (
-        <img
-          onMouseMove={handleMouseMove}
-          onMouseOut={() => _.defer(() => setMaskImg(null))}
-          onTouchStart={handleMouseMove}
-          src={image.src}
-          className={`${
-            shouldFitToWidth ? "w-full" : "h-full"
-          } ${imageClasses}`}
-        ></img>
-      )}
-      {maskImg && (
-        <img
-          src={maskImg.src}
-          className={`${
-            shouldFitToWidth ? "w-full" : "h-full"
-          } ${maskImageClasses}`}
-        ></img>
-      )}
+      {isVideo ? (
+        <video
+          src={image?.src}
+          className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
+          autoPlay
+          loop
+          muted
+        ></video>
+      ) : (
+        <>
+          {image && (
+            <img
+              onMouseMove={handleMouseMove}
+              onMouseOut={() => _.defer(() => setMaskImg(null))}
+              onTouchStart={handleMouseMove}
+              src={image.src}
+              className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
+            ></img>
+          )}
+          {maskImg && (
+            <img
+              src={maskImg.src}
+              className={`${shouldFitToWidth ? "w-full" : "h-full"} ${maskImageClasses}`}
+            ></img>
+          )}
+        </>
+      )}
     </>
   );
diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py
index 8a6e6d816..8910b7c3d 100644
--- a/segment_anything/predictor.py
+++ b/segment_anything/predictor.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import torch
+import cv2
 
 from segment_anything.modeling import Sam
 
@@ -267,3 +268,34 @@ def reset_image(self) -> None:
         self.orig_w = None
         self.input_h = None
         self.input_w = None
+
+
+def segment_video(video_path: str, sam_predictor: SamPredictor) -> None:
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        print(f"Error: Could not open video {video_path}")
+        return
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        sam_predictor.set_image(frame_rgb)
+
+        # Example of using the predictor with some dummy points
+        point_coords = np.array([[100, 100], [200, 200]])
+        point_labels = np.array([1, 0])
+        masks, _, _ = sam_predictor.predict(point_coords=point_coords, point_labels=point_labels)
+
+        # Process masks as needed, e.g., overlay on the frame
+        for mask in masks:
+            frame[mask > 0] = [0, 255, 0]  # Example: color mask area green
+
+        cv2.imshow('Segmented Video', frame)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()
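
For reference, here is a minimal end-to-end sketch of driving the new `segment_video` helper from Python. It is only a usage illustration, not part of the patch: the checkpoint path and video filename are placeholders, the `"vit_h"` registry key refers to the standard SAM ViT-H model, and the import pulls `segment_video` directly from `segment_anything.predictor`, where this diff defines it.

```
# Usage sketch for the segment_video helper added in this diff.
# Paths below are placeholders for files you have downloaded locally.
import torch

from segment_anything import SamPredictor, sam_model_registry
from segment_anything.predictor import segment_video  # defined in this diff

checkpoint = "sam_vit_h_4b8939.pth"  # placeholder: path to a downloaded ViT-H checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry["vit_h"](checkpoint=checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Reads frames with OpenCV, prompts SAM on each frame, and shows the overlay;
# press "q" in the display window to stop early.
segment_video("input.mp4", predictor)  # placeholder video path
```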
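The `segment_video` function above shows results with `cv2.imshow`, which requires a desktop session. As a purely illustrative variation (not part of this diff), the same per-frame loop can write a blended overlay to an output file instead; the function name `segment_video_to_file`, the output path, and the fixed point prompts are all hypothetical, mirroring the dummy prompts used above.

```
# Headless sketch: write the mask overlay to a video file instead of a window.
import cv2
import numpy as np

from segment_anything import SamPredictor


def segment_video_to_file(video_path: str, out_path: str, predictor: SamPredictor) -> None:
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        masks, _, _ = predictor.predict(
            point_coords=np.array([[100, 100], [200, 200]]),  # placeholder prompts
            point_labels=np.array([1, 0]),
        )

        # Blend a green overlay wherever any predicted mask is active.
        overlay = frame.copy()
        overlay[np.any(masks, axis=0)] = [0, 255, 0]
        writer.write(cv2.addWeighted(frame, 0.6, overlay, 0.4, 0.0))

    cap.release()
    writer.release()
```

Note that both variants call `set_image` on every frame, so the image encoder dominates runtime; interactive or real-time use would need a GPU and, ideally, prompts tracked across frames rather than fixed coordinates.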