diff --git a/pyproject.toml b/pyproject.toml index cae78492a..a7c46fcc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ "pillow>=9.4", "requests>=2.26.0", "tqdm>=4.62.3", - "opencv-python>=4.5.5.64" + "opencv-python>=4.5.5.64", ] [project.urls] @@ -58,6 +58,8 @@ Documentation = "https://supervision.roboflow.com/latest/" metrics = [ "pandas>=2.0.0", ] +ffmpeg = ["av (>=15.0.0)"] +rich_display = ["ipython (>=8.15,<9.0)"] [dependency-groups] dev = [ diff --git a/supervision/__init__.py b/supervision/__init__.py index ab45651ac..7ae921ee3 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -130,11 +130,12 @@ from supervision.utils.notebook import plot_image, plot_images_grid from supervision.utils.video import ( FPSMonitor, - VideoInfo, VideoSink, get_video_frames_generator, process_video, ) +from supervision.video import Video, VideoInfo +from supervision.video.backend import VideoBackendType __all__ = [ "LMM", @@ -192,6 +193,8 @@ "TriangleAnnotator", "VertexAnnotator", "VertexLabelAnnotator", + "Video", + "VideoBackendType", "VideoInfo", "VideoSink", "approximate_polygon", diff --git a/supervision/utils/video.py b/supervision/utils/video.py index 3b281b4e2..d3408b90e 100644 --- a/supervision/utils/video.py +++ b/supervision/utils/video.py @@ -9,7 +9,13 @@ import numpy as np from tqdm.auto import tqdm +from supervision.utils.internal import deprecated + +@deprecated( + "`process_video` is deprecated and will be removed in " + "`supervision-0.32.0`. Use `sv.VideoInfo` instead." +) @dataclass class VideoInfo: """ @@ -60,6 +66,10 @@ def resolution_wh(self) -> tuple[int, int]: return self.width, self.height +@deprecated( + "`process_video` is deprecated and will be removed in " + "`supervision-0.32.0`. Use `sv.Video().save` instead." +) class VideoSink: """ Context manager that saves video frames to a file using OpenCV. @@ -141,6 +151,10 @@ def _validate_and_setup_video( return video, start, end +@deprecated( + "`process_video` is deprecated and will be removed in " + "`supervision-0.32.0`. Use `sv.Video().frame()` or `sv.Video()` instead." +) def get_video_frames_generator( source_path: str, stride: int = 1, @@ -192,6 +206,10 @@ def get_video_frames_generator( video.release() +@deprecated( + "`process_video` is deprecated and will be removed in " + "`supervision-0.32.0`. Use `sv.Video().save` instead." +) def process_video( source_path: str, target_path: str, diff --git a/supervision/video/__init__.py b/supervision/video/__init__.py new file mode 100644 index 000000000..54393f5d0 --- /dev/null +++ b/supervision/video/__init__.py @@ -0,0 +1,4 @@ +from supervision.video.core import Video +from supervision.video.utils import SourceType, VideoInfo + +__all__ = ["SourceType", "Video", "VideoInfo"] diff --git a/supervision/video/backend/__init__.py b/supervision/video/backend/__init__.py new file mode 100644 index 000000000..5f5a83c04 --- /dev/null +++ b/supervision/video/backend/__init__.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from enum import Enum +from typing import Literal, Union + +from supervision.video.backend.opencv import OpenCVBackend, OpenCVWriter +from supervision.video.backend.pyav import pyAVBackend, pyAVWriter + +VideoBackendTypes = Union[OpenCVBackend, pyAVBackend] +VideoWriterTypes = Union[OpenCVWriter, pyAVWriter] + + +class VideoBackendType(Enum): + """ + Enumeration of Backends. + + Attributes: + PYAV (str): PyAV backend (powered by FFmpeg, supports audio rendering) + OPENCV (str): OpenCV backend + + """ + + PYAV = "pyav" + OPENCV = "opencv" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + @classmethod + def from_value(cls, value: VideoBackendType | str) -> VideoBackendType: + if isinstance(value, cls): + return value + if isinstance(value, str): + value = value.lower() + try: + return cls(value) + except ValueError: + raise ValueError(f"Invalid value: {value}. Must be one of {cls.list()}") + raise ValueError( + f"Invalid value type: {type(value)}. Must be an instance of " + f"{cls.__name__} or str." + ) + + +VideoBackendDict = { + VideoBackendType.PYAV: pyAVBackend, + VideoBackendType.OPENCV: OpenCVBackend, +} + +VideoWriterDict = { + VideoBackendType.PYAV: pyAVWriter, + VideoBackendType.OPENCV: OpenCVWriter, +} + +__all__ = [ + "VideoBackendDict", + "VideoBackendType", + "VideoBackendTypes", + "VideoWriterDict", + "VideoWriterTypes", +] diff --git a/supervision/video/backend/base.py b/supervision/video/backend/base.py new file mode 100644 index 000000000..5ba634708 --- /dev/null +++ b/supervision/video/backend/base.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +import numpy as np + +from supervision.video.utils import VideoInfo + + +class BaseBackend(ABC): + def __init__(self): + self.cap = None + self.video_info: VideoInfo | None = None + self.writer: BaseWriter | None = None + self.path = None + + @abstractmethod + def open(self, path: str | int) -> None: + pass + + @abstractmethod + def isOpened(self) -> bool: + pass + + @abstractmethod + def _set_video_info(self) -> VideoInfo: + pass + + @abstractmethod + def info(self) -> VideoInfo: + pass + + @abstractmethod + def read(self) -> tuple[bool, np.ndarray]: + pass + + @abstractmethod + def grab(self) -> bool: + pass + + @abstractmethod + def seek(self, frame_idx: int) -> None: + pass + + @abstractmethod + def release(self) -> None: + pass + + +class BaseWriter(ABC): + @abstractmethod + def __init__( + self, + filename: str, + fps: int, + frame_size: tuple[int, int], + codec: str | None = None, + backend: BaseBackend | None = None, + render_audio: bool | None = None, + ): + pass + + @abstractmethod + def __enter__(self): + pass + + @abstractmethod + def __exit__(self, exc_type, exc_value, traceback): + pass + + @abstractmethod + def write(self, frame: np.ndarray) -> None: + pass + + @abstractmethod + def close(self) -> None: + pass diff --git a/supervision/video/backend/opencv.py b/supervision/video/backend/opencv.py new file mode 100644 index 000000000..6d49b668a --- /dev/null +++ b/supervision/video/backend/opencv.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import cv2 +import numpy as np + +from supervision.video.backend.base import BaseBackend, BaseWriter +from supervision.video.utils import SourceType, VideoInfo + + +class OpenCVBackend(BaseBackend): + """ + OpenCV-based video backend implementation for video capture and processing. + + This backend provides video reading capabilities using OpenCV's VideoCapture. + It supports: + - Local video files + - Webcam streams + - RTSP network streams + + Attributes: + cap (cv2.VideoCapture): OpenCV video capture instance. + video_info (VideoInfo): Metadata about the video source. + writer (class): Reference to the OpenCVWriter class for video writing. + path (str | int): Path to the video source or webcam index. + + """ + + def __init__(self): + """Initialize the OpenCV backend with no active capture.""" + self.cap = None + self.video_info = None + self.writer = OpenCVWriter + self.path = None + + def open(self, path: str | int) -> None: + """ + Open a video source for reading. + + Args: + path (str | int): Path to video file, RTSP URL, or webcam index. + Webcam indices are typically 0 for default camera. + + Raises: + RuntimeError: If the source cannot be opened. + ValueError: If the source type is unsupported. + """ + self.cap = cv2.VideoCapture(path) + self.path = path + + if not self.cap.isOpened(): + raise RuntimeError(f"Cannot open video source: {path}") + + self.video_info = self._set_video_info() + + if isinstance(path, int): + self.video_info.SourceType = SourceType.WEBCAM + elif isinstance(path, str): + self.video_info.SourceType = ( + SourceType.RTSP + if path.lower().startswith("rtsp://") + else SourceType.VIDEO_FILE + ) + else: + raise ValueError("Unsupported source type") + + def isOpened(self) -> bool: + """ + Check if the video source is currently open and available. + + Returns: + bool: True if source is open and ready for reading, False otherwise. + """ + return self.cap.isOpened() + + def _set_video_info(self) -> VideoInfo: + """ + Extract and store video metadata from the opened source. + + Returns: + VideoInfo: Object containing: + - width (int): Frame width in pixels + - height (int): Frame height in pixels + - fps (int): Frames per second + - total_frames (int): Total frame count (0 for streams) + + Raises: + RuntimeError: If no source is open. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + + width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = round(self.cap.get(cv2.CAP_PROP_FPS)) + total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + return VideoInfo(width, height, fps, total_frames) + + def info(self) -> VideoInfo: + """ + Retrieve stored video metadata. + + Returns: + VideoInfo: Video properties including dimensions, FPS, and frame count. + + Raises: + RuntimeError: If no source is open. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + return self.video_info + + def read(self) -> tuple[bool, np.ndarray]: + """ + Read and decode the next frame from the video source. + + Returns: + tuple[bool, np.ndarray]: + - bool: True if frame was read successfully, False at end of stream + - np.ndarray: Frame data in BGR format (height, width, 3) + + Raises: + RuntimeError: If no source is open. + """ + if self.cap is None: + raise RuntimeError("Video not opened yet.") + return self.cap.read() + + def grab(self) -> bool: + """ + Advance to the next frame without decoding. + + Useful for quickly skipping frames when pixel data isn't needed. + + Returns: + bool: True if frame was advanced successfully, False otherwise + + Raises: + RuntimeError: If no source is open. + """ + if self.cap is None: + raise RuntimeError("Video not opened yet.") + return self.cap.grab() + + def seek(self, frame_idx: int) -> None: + """ + Seek to a specific frame index. + + Note: Seeking may be imprecise with compressed video formats. + + Args: + frame_idx (int): Zero-based index of target frame. + + Raises: + RuntimeError: If no source is open. + """ + if self.cap is None: + raise RuntimeError("Video not opened yet.") + self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + + def release(self) -> None: + """Release the video capture resources.""" + if self.cap is not None and self.cap.isOpened(): + self.cap.release() + self.cap = None + + +class OpenCVWriter(BaseWriter): + """ + OpenCV-based video writer for creating video files. + + This writer provides basic video encoding capabilities using OpenCV's VideoWriter. + Note: Does not support audio writing - use pyAVWriter for audio support. + """ + + def __init__( + self, + filename: str, + fps: int, + frame_size: tuple[int, int], + codec: str = "mp4v", + backend: OpenCVBackend | None = None, + render_audio: bool | None = None, + ): + """ + Initialize the video writer. + + Args: + filename (str): Output video file path (e.g., "output.mp4"). + fps (int): Target frames per second for output video. + frame_size (tuple[int, int]): (width, height) of output frames. + codec (str, optional): FourCC codec code (default "mp4v"). + backend (OpenCVBackend, optional): Unused (for API compatibility). + render_audio (bool, optional): Must be None (OpenCV doesn't support audio). + + Raises: + ValueError: If render_audio is specified (not supported). + RuntimeError: If writer cannot be initialized. + + Note: + Falls back to "mp4v" codec if specified codec fails. + """ + if render_audio or render_audio is False: + raise ValueError( + "OpenCV backend does not support audio. " + "Please use `pyav` backend instead or set `render_audio=None`" + ) + + self.backend = backend + try: + fourcc_int = cv2.VideoWriter_fourcc(*codec) + self.writer = cv2.VideoWriter(filename, fourcc_int, fps, frame_size) + except Exception: + fourcc_int = cv2.VideoWriter_fourcc(*"mp4v") + self.writer = cv2.VideoWriter(filename, fourcc_int, fps, frame_size) + + if not self.writer.isOpened(): + raise RuntimeError(f"Cannot open video writer for file: {filename}") + + def __enter__(self): + """Enable context manager support (with statement).""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Ensure proper cleanup when exiting context.""" + self.close() + + def write(self, frame: np.ndarray) -> None: + """ + Write a single frame to the output video. + + Args: + frame (np.ndarray): Frame data in BGR format (height, width, 3). + """ + self.writer.write(frame) + + def close(self) -> None: + """Finalize and close the output video file.""" + self.writer.release() diff --git a/supervision/video/backend/pyav.py b/supervision/video/backend/pyav.py new file mode 100644 index 000000000..4db1e0a9b --- /dev/null +++ b/supervision/video/backend/pyav.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import platform +import re +import sys +from fractions import Fraction + +import numpy as np + +from supervision.video.backend.base import BaseBackend, BaseWriter +from supervision.video.utils import SourceType, VideoInfo + +av = None + + +def get_av(): + if "av" in sys.modules and sys.modules["av"] is None: + del sys.modules["av"] + + try: + import av + + return av + except ImportError: + raise RuntimeError("PyAV (`av` module) is not installed. Run `pip install av`.") + + +class pyAVBackend(BaseBackend): + """ + PyAV-based implementation of the `BaseBackend` interface for video processing. + + This backend provides video capture and frame reading capabilities using the PyAV + library, which is a Pythonic binding for FFmpeg. It supports: + - Local video files + - Webcam streams (platform-specific) + - RTSP network streams + """ + + def __init__(self): + """ + Initialize the pyAVBackend instance. + + Raises: + RuntimeError: If PyAV (`av` module) is not installed. + """ + super().__init__() + + global av + av = get_av() + + self.container = None + self.stream = None + self.writer = pyAVWriter + self.frame_generator = None + self.video_info = None + self.current_frame_idx = 0 + + def open(self, path: str | int) -> None: + """ + Open and initialize a video source. + + This method opens a video file, RTSP stream, or webcam, and sets up + the necessary components for decoding and reading frames. + + Args: + path (str | int): Path to the video file, RTSP URL, or webcam path. + + Raises: + RuntimeError: If the video source cannot be opened. + ValueError: If the source type is unsupported. + """ + _source_type = None + _format = None + + def is_webcam_path(path: str) -> tuple[bool, str]: + """ + Determine if the path refers to a webcam and get platform-specific format. + + Args: + path (str): The path to check. + + Returns: + tuple[bool, str]: (True if webcam, FFmpeg format string) + """ + if not isinstance(path, str): + return False, None + + system = platform.system() + path_lower = path.lower() + + if system == "Windows": + return path_lower.startswith("video="), "dshow" + elif system == "Linux": + return bool(re.match(r"^/dev/video\d+$", path_lower)), "v4l2" + elif system == "Darwin": + return path_lower.isdigit(), "avfoundation" + else: + return False, None + + isWebcam, ffmpeg_os_format = is_webcam_path(path=path) + if isWebcam: + _source_type = SourceType.WEBCAM + _format = ffmpeg_os_format + elif isinstance(path, str): + _source_type = ( + SourceType.RTSP + if path.lower().startswith("rtsp://") + else SourceType.VIDEO_FILE + ) + else: + raise ValueError("Unsupported source type") + + try: + self.container = av.open(path, format=_format) + self.path = path + self.stream = self.container.streams.video[0] + self.stream.thread_type = "AUTO" + self.cap = self.container + + self.frame_generator = self.container.decode(video=0) + self.video_info = self._set_video_info() + self.current_frame_idx = 0 + + # If audio exists + if len(self.container.streams.audio) > 0: + self.audio_stream = self.container.streams.audio[0] + else: + self.audio_stream = None + + self.video_info.SourceType = _source_type + + except Exception as e: + raise RuntimeError(f"Cannot open video source: {path}") from e + + def isOpened(self) -> bool: + """ + Check if the video source has been successfully opened. + + Returns: + bool: True if video source is opened and ready, False otherwise. + """ + return self.container is not None and self.stream is not None + + def _set_video_info(self) -> VideoInfo: + """ + Extract and calculate video information from the opened source. + + Returns: + VideoInfo: Object containing: + - width (int): Frame width in pixels + - height (int): Frame height in pixels + - fps (int): Frames per second (estimated if not available) + - total_frames (int | None): Total frame count if available + + Raises: + RuntimeError: If the video source is not opened. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + + width = self.stream.width + height = self.stream.height + fps = float(self.stream.average_rate or self.stream.guessed_rate) + if fps <= 0: + fps = 30 # Default FPS if cannot be determined + + total_frames = self.stream.frames + if total_frames == 0: + total_frames = None # Unknown frame count + + return VideoInfo(width, height, round(fps), total_frames) + + def info(self) -> VideoInfo: + """ + Retrieve video information for the opened source. + + Returns: + VideoInfo: Video properties including dimensions, FPS, and frame count. + + Raises: + RuntimeError: If the video source is not opened. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + return self.video_info + + def read(self) -> tuple[bool, np.ndarray]: + """ + Read and decode the next frame from the video source. + + Returns: + tuple[bool, np.ndarray]: + - bool: True if frame was read successfully, False at end of stream + - np.ndarray: Frame data in BGR format with shape (height, width, 3) + + Raises: + RuntimeError: If the video source is not opened. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + + try: + frame = next(self.frame_generator) + self.current_frame_idx += 1 + frame_bgr = frame.to_ndarray(format="bgr24") + return True, frame_bgr + except (StopIteration, av.error.EOFError): + return False, np.array([]) + + def grab(self) -> bool: + """ + Advance to the next frame packet without decoding it. + + This is useful for quickly skipping frames when decoding isn't needed. + + Returns: + bool: True if frame packet was advanced, False at end of stream + + Raises: + RuntimeError: If the video source is not opened. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + + try: + for packet in self.container.demux(video=0): + if packet.stream.type == "video": + return True + return False + except (StopIteration, av.error.EOFError): + return False + + def seek(self, frame_idx: int) -> None: + """ + Seek to a specific frame index in the video. + + Uses keyframe-based seeking followed by sequential decoding to reach + the exact requested frame. This is more efficient than sequential seeking + but may be slower for very large jumps. + + Args: + frame_idx (int): Zero-based index of the target frame. + + Raises: + RuntimeError: If the video source is not opened. + """ + if not self.isOpened(): + raise RuntimeError("Video not opened yet.") + + framerate = float(self.stream.average_rate or self.stream.guessed_rate or 30.0) + if framerate <= 0: + framerate = 30.0 + + time_base = float(self.stream.time_base) + timestamp = int((frame_idx / framerate) / time_base) + + self.container.seek( + timestamp, stream=self.stream, any_frame=False, backward=True + ) + self.frame_generator = self.container.decode(video=0) + + self.current_frame_idx = 0 + while True: + try: + frame = next(self.frame_generator) + except (StopIteration, av.error.EOFError): + break + + if getattr(frame, "time", None) is not None: + self.current_frame_idx = round(frame.time * framerate) + elif getattr(frame, "pts", None) is not None: + self.current_frame_idx = round((frame.pts * time_base) * framerate) + else: + self.current_frame_idx += 1 + + if self.current_frame_idx >= frame_idx: + + def _prepend_frame(first_frame, gen): + yield first_frame + yield from gen + + self.frame_generator = _prepend_frame(frame, self.frame_generator) + break + + def release(self) -> None: + """ + Release the video source and free all associated resources. + + This closes the video container and resets all internal state. + Should be called when finished with the video source. + """ + if self.container: + self.container.close() + self.container = None + self.stream = None + self.frame_generator = None + + +class pyAVWriter(BaseWriter): + """ + PyAV-based video writer for creating video files with optional audio. + + This writer provides high-quality video encoding with precise frame timing + (millisecond accuracy) and supports audio muxing from a source video. + + Methods: + write(frame): Write a video frame. + close(): Finalize and close the video file. + """ + + def __init__( + self, + filename: str, + fps: int, + frame_size: tuple[int, int], + codec: str = "h264", + backend: pyAVBackend | None = None, + render_audio: bool | None = None, + ): + """ + Initialize the video writer. + + Args: + filename (str): Path to the output video file. + fps (int): Target frames per second for the output video. + frame_size (tuple[int, int]): (width, height) of output frames. + codec (str, optional): Video codec name (default "h264"). + backend (pyAVBackend, optional): Source backend for audio muxing. + render_audio (bool, optional): Include audio (default True if available). + + Raises: + RuntimeError: If the output file cannot be created. + """ + try: + self.container = av.open(filename, mode="w") + self.path = filename + self.backend = backend + + if render_audio is None: + render_audio = True + + if codec is None: + codec = "h264" + self.stream = self.container.add_stream(codec, rate=fps) + self.stream.width = frame_size[0] + self.stream.height = frame_size[1] + self.stream.pix_fmt = "yuv420p" + + # Use finer time_base (1/1000) for millisecond precision timestamps + self.stream.codec_context.time_base = Fraction(1, 1000) + + self.frame_idx = 0 + self.fps = fps # Store FPS for timestamp calculations + + self.audio_stream_out = None + self.audio_packets = [] + + if render_audio and backend and backend.audio_stream: + audio_codec_name = backend.audio_stream.codec_context.name + audio_rate = backend.audio_stream.codec_context.rate + self.audio_stream_out = self.container.add_stream( + audio_codec_name, rate=audio_rate + ) + + except Exception as e: + raise RuntimeError(f"Cannot open video writer for file: {filename}") from e + + def __enter__(self): + """Enable use as a context manager (with statement).""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Ensure proper cleanup when exiting context.""" + self.close() + + def write(self, frame: np.ndarray) -> None: + """ + Write a single video frame to the output file. + + Args: + frame (np.ndarray): Frame data in BGR format (height, width, 3). + """ + # Calculate PTS as milliseconds: frame_index * (1000 ms / fps) + pts = int(self.frame_idx * (1000 / self.fps)) + + frame_rgb = frame[..., ::-1] # Convert BGR to RGB + av_frame = av.VideoFrame.from_ndarray(frame_rgb, format="rgb24") + + av_frame.pts = pts + av_frame.time_base = self.stream.codec_context.time_base + self.frame_idx += 1 + + packets = self.stream.encode(av_frame) + for packet in packets: + self.container.mux(packet) + + def close(self) -> None: + """ + Finalize and close the video file, including audio processing if enabled. + + This method performs several critical operations: + 1. If audio is enabled, processes and muxes the audio stream from the source + 2. Applies tempo adjustment to match the output video FPS + 3. Flushes all remaining video frames from the encoder + 4. Properly closes the output container + + The audio processing uses FFmpeg filters to: + - Read audio from the original source + - Apply tempo scaling based on FPS differences between source and output + - Encode and mux the processed audio into the output file + + Note: + This method should always be called when finished writing frames. + It ensures proper file finalization and resource cleanup. + """ + if self.audio_stream_out is not None: + + def atempo_chain(speed: float) -> list[str]: + if speed <= 0: + raise ValueError("Speed factor must be > 0") + + chain = [] + + while speed > 2.0: + chain.append("2.0") + speed /= 2.0 + + while speed < 0.5: + chain.append("0.5") + speed /= 0.5 + + if abs(speed - 1.0) > 1e-6: + chain.append(f"{speed:.6f}") + + return chain + + src = av.open(self.backend.path) + src_fps = ( + src.streams.video[0].average_rate or src.streams.video[0].guessed_rate + ) + audio_stream = src.streams.audio[0] + + graph = av.filter.Graph() + filters = atempo_chain(self.fps / src_fps) + nodes = [graph.add_abuffer(template=audio_stream)] + for f in filters: + nodes.append(graph.add("atempo", f)) + + nodes.append(graph.add("abuffersink")) + graph.link_nodes(*nodes) + graph.configure() + + for packet in src.demux(audio_stream): + for frame in packet.decode(): + graph.push(frame) + + while True: + try: + f = graph.pull() + except Exception: + break + for pkt in self.audio_stream_out.encode(f): + self.container.mux(pkt) + + graph.push(None) + while True: + try: + f = graph.pull() + except Exception: + break + for pkt in self.audio_stream_out.encode(f): + self.container.mux(pkt) + + for pkt in self.audio_stream_out.encode(None): + self.container.mux(pkt) + + src.close() + + # flush video + for pkt in self.stream.encode(): + self.container.mux(pkt) + + self.container.close() diff --git a/supervision/video/core.py b/supervision/video/core.py new file mode 100644 index 000000000..c866ba275 --- /dev/null +++ b/supervision/video/core.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +import os +import sys +from collections.abc import Callable + +import cv2 +import numpy as np +from tqdm.auto import tqdm + +from supervision.video.backend import ( + VideoBackendDict, + VideoBackendType, + VideoBackendTypes, + VideoWriterTypes, +) +from supervision.video.utils import SourceType, VideoInfo + + +def get_iPython(): + if "IPython" in sys.modules and sys.modules["IPython"] is None: + del sys.modules["IPython"] + + try: + import IPython + + return IPython + except ImportError: + raise RuntimeError( + "IPython (`IPython` module) is not installed. Run `pip install IPython`." + ) + + +class Video: + """ + A high-level interface for reading, processing, and writing video files or streams. + + Attributes: + info (VideoInfo): Metadata about the opened video. + source (str | int): Path to the video file or index of the camera device. + backend (VideoBackendTypes): Video backend used for I/O operations. + """ + + info: VideoInfo + source: str | int + backend: VideoBackendTypes + + def __init__( + self, + source: str | int, + backend: VideoBackendType | str = VideoBackendType.OPENCV, + ) -> None: + """ + Initialize the Video object and open the source. + + Args: + source (str | int): Path to a video file or index of a camera device. + backend (Backend | str, optional): Backend type or name for video I/O. + Defaults to Backend.OPENCV. + + Raises: + ValueError: If the specified backend is not supported. + """ + self.backend = VideoBackendDict.get(VideoBackendType.from_value(backend)) + if self.backend is None: + raise ValueError(f"Unsupported backend: {backend}") + + # Instantiate the backend class once sanity check is done + self.backend = self.backend() + + self.backend.open(source) + self.info = self.backend.info() + self.source = source + + def __iter__(self): + """ + Make the Video object directly iterable over frames. + + Yields: + np.ndarray: The next frame in the video stream. + """ + return self.frames() + + def sink( + self, + target_path: str, + info: VideoInfo, + codec: str | None = None, + render_audio: bool | None = None, + ) -> VideoWriterTypes: + """ + Create a video writer for saving frames to a file. + + Args: + target_path (str): Output file path for the video. + info (VideoInfo): Video metadata including resolution and FPS. + codec (str, optional): FourCC video codec code. + If None, the backend's default codec is used. + render_audio (bool | None, optional): Whether to include audio if supported. + + Returns: + WriterTypes: Video writer instance for writing frames. + """ + return self.backend.writer( + target_path, info.fps, info.resolution_wh, codec, self.backend, render_audio + ) + + def frames( + self, + stride: int = 1, + start: int = 0, + end: int | None = None, + resolution_wh: tuple[int, int] | None = None, + ): + """ + Generate frames from the video with optional skipping, seeking, and resizing. + + Args: + stride (int, optional): Number of frames to skip between each yield. + Defaults to 1 (no skipping). + start (int, optional): Index of the first frame to read. Defaults to 0. + end (int | None, optional): Index after the last frame to read. + If None, reads until the end of the video. + resolution_wh (tuple[int, int] | None, optional): Target resolution + (width, height) for resizing frames. If None, keeps original size. + + Yields: + np.ndarray: The next frame in the video. + + Raises: + RuntimeError: If the video has not been opened. + """ + if self.backend.cap is None: + raise RuntimeError("Video not opened yet.") + + total_frames = ( + self.backend.video_info.total_frames if self.backend.video_info else 0 + ) + is_live_stream = total_frames is None or total_frames <= 0 + + if is_live_stream: + # Live stream handling + while True: + for _ in range(stride - 1): + if not self.backend.grab(): + return + ret, frame = self.backend.read() + if not ret: + return + if resolution_wh is not None: + frame = cv2.resize(frame, resolution_wh) + yield frame + else: + # Video file handling + if end is None or end > total_frames: + end = total_frames + + frame_idx = start + while frame_idx < end: + self.backend.seek(frame_idx) + ret, frame = self.backend.read() + if not ret: + break + if resolution_wh is not None: + frame = cv2.resize(frame, resolution_wh) + yield frame + frame_idx += stride + + def save( + self, + target_path: str, + callback: Callable[[np.ndarray, int], np.ndarray], + fps: int | None = None, + progress_message: str = "Processing video", + show_progress: bool = False, + codec: str | None = None, + render_audio: bool | None = None, + ): + """ + Process and save video frames to a file. + + Reads frames from the source, applies the given `callback` function to each + frame, and writes the processed frames to the specified output file. + + Args: + target_path (str): Output file path for the processed video. + callback (Callable[[np.ndarray, int], np.ndarray]): A function that takes in + a video frame (numpy array) and its frame index, and returns a frame. + fps (int | None, optional): Frames per second of the output video. + If None, uses the original FPS. + progress_message (str, optional): Message displayed in the progress bar. + Defaults to "Processing video". + show_progress (bool, optional): If True, displays a tqdm progress bar. + Defaults to False. + codec (str | None, optional): FourCC video codec code. + If None, uses the backend's default codec. + render_audio (bool | None, optional): Whether to include audio if supported. + + Raises: + RuntimeError: If the video has not been opened. + ValueError: If the video source is not a file. + + Returns: + None + """ + if self.backend.cap is None: + raise RuntimeError("Video not opened yet.") + + if self.backend.video_info.SourceType != SourceType.VIDEO_FILE: + raise ValueError("Only video files can be saved.") + + if fps is None: + fps = self.backend.video_info.fps + + writer = self.backend.writer( + target_path, + fps, + self.backend.video_info.resolution_wh, + codec, + self.backend, + render_audio, + ) + total_frames = self.backend.video_info.total_frames + frames_generator = self.frames() + for index, frame in enumerate( + tqdm( + frames_generator, + total=total_frames, + disable=not show_progress, + desc=progress_message, + ) + ): + result_frame = callback(frame, index) + writer.write(frame=result_frame) + + writer.close() + + def show( + self, + resolution_wh: tuple[int, int] | None = None, + callback: Callable[[np.ndarray, int], np.ndarray] = lambda f, i: f, + fps: int | None = None, + progress_message: str = "Processing video", + show_progress: bool = False, + render_audio: bool | None = None, + ): + """ + Display video frames in a window with interactive playback controls. + + This method streams video frames to an OpenCV window, allowing real-time + visualization. Press 'q' to quit playback. The method handles various + display-related exceptions gracefully. + + Args: + resolution_wh (tuple[int, int] | None): Optional target resolution as + (width, height) tuple. If None, uses native video resolution. + Note: Aspect ratio may not be preserved. + """ + + # On Jupyter Notebook + def in_notebook(): + argv = getattr(sys, "argv", []) + return any("jupyter" in arg or "ipykernel_launcher" in arg for arg in argv) + + def is_Headless(): + if sys.platform.startswith("linux"): + return not bool(os.environ.get("DISPLAY", "")) + if sys.platform == "darwin": + return not bool( + os.environ.get("TERM_PROGRAM") or os.environ.get("DISPLAY") + ) + if sys.platform.startswith("win"): + try: + import ctypes + + user32 = ctypes.windll.user32 + return user32.GetDesktopWindow() == 0 + except Exception: + return True + return True + + # On a notebook + if in_notebook(): + iPyDisplay = get_iPython().display + + self.save( + "temp.mp4", + callback=callback, + fps=fps, + progress_message=progress_message, + show_progress=show_progress, + render_audio=render_audio, + ) + + width = resolution_wh[0] if resolution_wh is not None else None + height = resolution_wh[1] if resolution_wh is not None else None + iPyDisplay.display( + iPyDisplay.Video("temp.mp4", embed=True, width=width, height=height) + ) + os.remove("temp.mp4") + # On a computer + elif not is_Headless(): + for frame in self.frames(resolution_wh=resolution_wh): + cv2.imshow(str(self.source), frame) + key = cv2.waitKey(1) & 0xFF + + if key == ord("q"): + break + + while True: + if cv2.getWindowProperty(str(self.source), cv2.WND_PROP_VISIBLE) < 1: + break + cv2.waitKey(100) + cv2.destroyAllWindows() + # On a headless system + else: + if iPyDisplay is None: + raise RuntimeError( + "IPython (`IPython` module) is not installed. " + "Run `pip install IPython`." + ) + + self.save( + "temp.mp4", + callback=callback, + fps=fps, + progress_message=progress_message, + show_progress=show_progress, + render_audio=render_audio, + ) + + width = resolution_wh[0] if resolution_wh is not None else None + height = resolution_wh[1] if resolution_wh is not None else None + + display_video = iPyDisplay.Video( + "temp.mp4", embed=True, width=width, height=height + ) + html_code = display_video._repr_html_() + export_path = "video_display.html" + + with open(export_path, "w") as f: + f.write(html_code) + print(f"Video exported as HTML to {export_path}") + + os.remove("temp.mp4") diff --git a/supervision/video/utils.py b/supervision/video/utils.py new file mode 100644 index 000000000..cb9cdbc10 --- /dev/null +++ b/supervision/video/utils.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +class SourceType(Enum): + """ + Enumeration of supported video source types. + + Attributes: + VIDEO_FILE: A standard video file on disk. + WEBCAM: A webcam or other direct camera device. + RTSP: A network RTSP video stream. + """ + + VIDEO_FILE = "video_file" + WEBCAM = "webcam" + RTSP = "rtsp" + + @classmethod + def list(cls) -> list[str]: + """ + Get a list of all supported source type values. + + Returns: + list[str]: List of enum values as lowercase strings. + + Example: + >>> SourceType.list() + ['video_file', 'webcam', 'rtsp'] + """ + return list(map(lambda c: c.value, cls)) + + @classmethod + def from_value(cls, value: SourceType | str) -> SourceType: + """ + Convert a string or SourceType instance to a SourceType enum member. + + Args: + value (SourceType | str): The value to convert. + + Returns: + SourceType: Corresponding SourceType enum member. + + Raises: + ValueError: If the value is invalid or not a supported type. + + Example: + >>> SourceType.from_value("webcam") + + """ + if isinstance(value, cls): + return value + if isinstance(value, str): + value = value.lower() + try: + return cls(value) + except ValueError: + raise ValueError(f"Invalid value: {value}. Must be one of {cls.list()}") + raise ValueError( + f"Invalid value type: {type(value)}. Must be an instance of " + f"{cls.__name__} or str." + ) + + +@dataclass +class VideoInfo: + """ + Stores metadata about a video, such as dimensions, frame rate, and source type. + + Attributes: + width (int): Width of the video in pixels. + height (int): Height of the video in pixels. + fps (int): Frames per second of the video. + total_frames (int | None): Total number of frames, or None if unknown. + SourceType (SourceType | None): Source type (VIDEO_FILE, WEBCAM, or RTSP). + + Properties: + resolution_wh (tuple[int, int]): The (width, height) tuple for the video. + + Example: + >>> import supervision as sv + >>> video_info = sv.VideoInfo.from_video_path("video.mp4") + >>> print(video_info) + VideoInfo(width=3840, height=2160, fps=25, total_frames=538) + >>> video_info.resolution_wh + (3840, 2160) + """ + + width: int + height: int + fps: int + total_frames: int | None = None + SourceType: SourceType | None = None + + @property + def resolution_wh(self) -> tuple[int, int]: + """ + Get the video resolution as a (width, height) tuple. + + Returns: + tuple[int, int]: The video dimensions in pixels. + + Example: + >>> VideoInfo(width=1920, height=1080, fps=30).resolution_wh + (1920, 1080) + """ + return self.width, self.height diff --git a/test/video/test_video.py b/test/video/test_video.py new file mode 100644 index 000000000..785b0c704 --- /dev/null +++ b/test/video/test_video.py @@ -0,0 +1,78 @@ +import os + +import cv2 +import numpy as np + +import supervision as sv + + +def _create_temp_video(path: str, width=320, height=240, fps=30, frames=10): + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(path, fourcc, fps, (width, height)) + for _ in range(frames): + frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) + writer.write(frame) + writer.release() + + +def test_video_info_and_iteration(tmp_path): + vid_path = tmp_path / "test.mp4" + _create_temp_video(str(vid_path)) + + video = sv.Video(str(vid_path)) + info = video.info + + assert info.width == 320 + assert info.height == 240 + assert info.total_frames == 10 + + frames = list(video.frames()) + assert len(frames) == 10 + + +def test_frames_stride(tmp_path): + vid_path = tmp_path / "test_stride.mp4" + _create_temp_video(str(vid_path), frames=9) + + video = sv.Video(str(vid_path)) + frames = list(video.frames(stride=2)) + assert len(frames) == 5 # ceil(9/2) + + +def test_save_with_callback(tmp_path): + src = tmp_path / "src.mp4" + dst = tmp_path / "dst.mp4" + _create_temp_video(str(src)) + + def identity(frame, i): + return frame + + sv.Video(str(src)).save(str(dst), callback=identity, show_progress=False) + + # confirm destination exists and metadata matches + dst_video = sv.Video(str(dst)) + assert dst_video.info.total_frames == 10 + + +def test_legacy_get_video_frames_generator(tmp_path): + vid_path = tmp_path / "legacy.mp4" + _create_temp_video(str(vid_path), frames=6) + + frames = list(sv.get_video_frames_generator(str(vid_path))) + assert len(frames) == 6 + + +def test_legacy_process_video(tmp_path): + src = tmp_path / "legacy_src.mp4" + dst = tmp_path / "legacy_dst.mp4" + _create_temp_video(str(src), frames=4) + + sv.process_video( + source_path=str(src), + target_path=str(dst), + callback=lambda f, i: f, + show_progress=False, + ) + + assert os.path.exists(dst) + assert sv.Video(str(dst)).info.total_frames == 4