diff --git a/README.md b/README.md
index 2e58656be..e6a1d91f7 100644
--- a/README.md
+++ b/README.md
@@ -84,6 +84,26 @@ See the examples notebooks on [using SAM with prompts](/notebooks/predictor_exam
+## Video Segmentation
+
+To use the new video segmentation feature, follow these steps:
+
+1. Import the necessary modules and initialize the SAM model and predictor:
+
+```python
+from segment_anything import SamPredictor, sam_model_registry
+from segment_anything.predictor import segment_video
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+predictor = SamPredictor(sam)
+```
+
+2. Call the `segment_video` function with the path to your video file and the predictor:
+
+```python
+segment_video("<path/to/video>", predictor)
+```
+
+This will read the video frame by frame, run SAM on each frame with a pair of example point prompts, overlay the predicted masks, and display the result in an OpenCV window (press `q` to quit).
+
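+If you would rather save the segmented frames than display them, you can drive the same loop yourself with `SamPredictor` and OpenCV's `VideoWriter`. The sketch below is illustrative only: the model type, the input and output paths, and the prompt point are placeholders you would replace with your own values. It mirrors what `segment_video` does internally, but writes the frames to a file instead of showing them.
+
+```python
+import cv2
+import numpy as np
+from segment_anything import SamPredictor, sam_model_registry
+
+sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
+predictor = SamPredictor(sam)
+
+cap = cv2.VideoCapture("<path/to/video>")
+fps = cap.get(cv2.CAP_PROP_FPS)
+size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
+writer = cv2.VideoWriter("<path/to/output>", cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
+
+while cap.isOpened():
+    ret, frame = cap.read()  # frame is BGR, HxWx3
+    if not ret:
+        break
+    predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    masks, _, _ = predictor.predict(
+        point_coords=np.array([[100, 100]]),  # placeholder prompt point
+        point_labels=np.array([1]),
+    )
+    frame[masks[0]] = [0, 255, 0]  # color the first candidate mask green
+    writer.write(frame)
+
+cap.release()
+writer.release()
+```
+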
## ONNX Export
SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
diff --git a/demo/src/App.tsx b/demo/src/App.tsx
index a42655356..7df74c7c5 100644
--- a/demo/src/App.tsx
+++ b/demo/src/App.tsx
@@ -1,9 +1,3 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-
-// This source code is licensed under the license found in the
-// LICENSE file in the root directory of this source tree.
-
import { InferenceSession, Tensor } from "onnxruntime-web";
import React, { useContext, useEffect, useState } from "react";
import "./assets/scss/App.scss";
@@ -13,6 +7,7 @@ import { onnxMaskToImage } from "./components/helpers/maskUtils";
import { modelData } from "./components/helpers/onnxModelAPI";
import Stage from "./components/Stage";
import AppContext from "./components/hooks/createContext";
+import { segment_video } from "segment_anything/predictor";
const ort = require("onnxruntime-web");
/* @ts-ignore */
import npyjs from "npyjs";
@@ -30,6 +25,7 @@ const App = () => {
} = useContext(AppContext)!;
const [model, setModel] = useState(null); // ONNX model
const [tensor, setTensor] = useState(null); // Image embedding tensor
+  const [isVideo, setIsVideo] = useState(false); // Whether the current input is a video
// The ONNX model expects the input to be rescaled to 1024.
// The modelScale state variable keeps track of the scale values.
@@ -124,6 +120,16 @@ const App = () => {
}
};
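+  // Create an object URL for the uploaded file, load it into an off-screen
+  // <video> element, and hand the element to segment_video once the first
+  // frame is available.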
+  const handleVideoUpload = async (videoFile: File) => {
+    setIsVideo(true);
+    const videoURL = URL.createObjectURL(videoFile);
+    const videoElement = document.createElement("video");
+    videoElement.src = videoURL;
+    videoElement.onloadeddata = () => {
+      segment_video(videoElement, model);
+    };
+  };
+
  return <Stage />;
};
diff --git a/demo/src/components/Tool.tsx b/demo/src/components/Tool.tsx
index 31afbe5c6..9f58d081c 100644
--- a/demo/src/components/Tool.tsx
+++ b/demo/src/components/Tool.tsx
@@ -1,23 +1,19 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-
-// This source code is licensed under the license found in the
-// LICENSE file in the root directory of this source tree.
-
import React, { useContext, useEffect, useState } from "react";
import AppContext from "./hooks/createContext";
-import { ToolProps } from "./helpers/Interfaces";
import * as _ from "underscore";
-const Tool = ({ handleMouseMove }: ToolProps) => {
+interface ToolProps {
+  handleMouseMove: (e: any) => void;
+  isVideo: boolean;
+}
+
+const Tool = ({ handleMouseMove, isVideo }: ToolProps) => {
const {
image: [image],
maskImg: [maskImg, setMaskImg],
} = useContext(AppContext)!;
- // Determine if we should shrink or grow the images to match the
- // width or the height of the page and setup a ResizeObserver to
- // monitor changes in the size of the page
const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
const bodyEl = document.body;
const fitToPage = () => {
@@ -44,27 +40,37 @@ const Tool = ({ handleMouseMove }: ToolProps) => {
const imageClasses = "";
const maskImageClasses = `absolute opacity-40 pointer-events-none`;
- // Render the image and the predicted mask image on top
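+  // Render a looping <video> element when a video was uploaded; otherwise
+  // render the image with the predicted mask overlaid on top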
  return (
    <>
-      {image && (
-        <img
-          onMouseMove={handleMouseMove}
-          onMouseOut={() => _.defer(() => setMaskImg(null))}
-          onTouchStart={handleMouseMove}
-          src={image.src}
-          className={`${
-            shouldFitToWidth ? "w-full" : "h-full"
-          } ${imageClasses}`}
-        ></img>
-      )}
-      {maskImg && (
-        <img
-          src={maskImg.src}
-          className={`${maskImageClasses} ${
-            shouldFitToWidth ? "w-full" : "h-full"
-          }`}
-        ></img>
+      {isVideo ? (
+        <video
+          onMouseMove={handleMouseMove}
+          onMouseOut={() => _.defer(() => setMaskImg(null))}
+          onTouchStart={handleMouseMove}
+          src={image?.src}
+          className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
+          autoPlay
+          loop
+          muted
+        ></video>
+      ) : (
+        <>
+          {image && (
+            <img
+              onMouseMove={handleMouseMove}
+              onMouseOut={() => _.defer(() => setMaskImg(null))}
+              onTouchStart={handleMouseMove}
+              src={image.src}
+              className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
+            ></img>
+          )}
+          {maskImg && (
+            <img
+              src={maskImg.src}
+              className={`${maskImageClasses} ${shouldFitToWidth ? "w-full" : "h-full"}`}
+            ></img>
+          )}
+        </>
      )}
    </>
);
diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py
index 8a6e6d816..8910b7c3d 100644
--- a/segment_anything/predictor.py
+++ b/segment_anything/predictor.py
@@ -6,6 +6,7 @@
import numpy as np
import torch
+import cv2
from segment_anything.modeling import Sam
@@ -267,3 +268,34 @@ def reset_image(self) -> None:
self.orig_w = None
self.input_h = None
self.input_w = None
+
+
+def segment_video(video_path: str, sam_predictor: SamPredictor) -> None:
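+    """
+    Segments every frame of a video with SAM and displays the result.
+
+    Arguments:
+      video_path (str): The path to the input video file.
+      sam_predictor (SamPredictor): A SamPredictor wrapping a loaded SAM model,
+        used to compute masks for each frame.
+    """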
+    cap = cv2.VideoCapture(video_path)
+    if not cap.isOpened():
+        print(f"Error: Could not open video {video_path}")
+        return
+
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        sam_predictor.set_image(frame_rgb)
+
+        # Example prompt: two fixed dummy points (foreground and background); replace with real prompts
+        point_coords = np.array([[100, 100], [200, 200]])
+        point_labels = np.array([1, 0])
+        masks, _, _ = sam_predictor.predict(point_coords=point_coords, point_labels=point_labels)
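+        # predict() returns several candidate masks by default (multimask_output=True)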
+
+        # Process masks as needed, e.g., overlay on the frame
+        for mask in masks:
+            frame[mask > 0] = [0, 255, 0]  # Example: color mask area green
+
+        cv2.imshow('Segmented Video', frame)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()