From f72812919e62deac2c31628254f9c1c04b77a222 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Apr 2025 17:09:48 +0100 Subject: [PATCH] Crazy stuff --- src/torchcodec/_core/AVIOBytesContext.cpp | 41 +- src/torchcodec/_core/AVIOBytesContext.h | 4 +- src/torchcodec/_core/AVIOContextHolder.cpp | 26 +- src/torchcodec/_core/AVIOContextHolder.h | 54 +- src/torchcodec/_core/AVIOFileLikeContext.cpp | 39 +- src/torchcodec/_core/AVIOFileLikeContext.h | 6 +- src/torchcodec/_core/CPUOnlyDevice.cpp | 32 +- src/torchcodec/_core/CudaDevice.cpp | 188 +- src/torchcodec/_core/DeviceInterface.h | 24 +- src/torchcodec/_core/FFMPEGCommon.cpp | 96 +- src/torchcodec/_core/FFMPEGCommon.h | 50 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 1734 +++++++++--------- src/torchcodec/_core/SingleStreamDecoder.h | 403 ++-- src/torchcodec/_core/custom_ops.cpp | 374 ++-- src/torchcodec/_core/pybind_ops.cpp | 12 +- 15 files changed, 1568 insertions(+), 1515 deletions(-) diff --git a/src/torchcodec/_core/AVIOBytesContext.cpp b/src/torchcodec/_core/AVIOBytesContext.cpp index 3e1481be..ffaf331b 100644 --- a/src/torchcodec/_core/AVIOBytesContext.cpp +++ b/src/torchcodec/_core/AVIOBytesContext.cpp @@ -9,55 +9,56 @@ namespace facebook::torchcodec { -AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize) - : dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} { +AVIOBytesContext::AVIOBytesContext(const void* data, int64_t data_size) + : data_context_{static_cast<const uint8_t*>(data), data_size, 0} { TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!"); TORCH_CHECK(data_size > 0, "Video data size must be positive"); - createAVIOContext(&read, &seek, &dataContext_); + create_avio_context(&read, &seek, &data_context_); } // The signature of this function is defined by FFMPEG. int AVIOBytesContext::read(void* opaque, uint8_t* buf, int buf_size) { - auto dataContext = static_cast<DataContext*>(opaque); + auto data_context = static_cast<DataContext*>(opaque); TORCH_CHECK( - dataContext->current <= dataContext->size, "Tried to read outside of the buffer: current=", - dataContext->current, ", size=", - dataContext->size); + data_context->current <= data_context->size, "Tried to read outside of the buffer: current=", + data_context->current, ", size=", + data_context->size); - int64_t numBytesRead = std::min( - static_cast<int64_t>(buf_size), dataContext->size - dataContext->current); + int64_t num_bytes_read = std::min( + static_cast<int64_t>(buf_size), + data_context->size - data_context->current); TORCH_CHECK( - numBytesRead >= 0, - "Tried to read negative bytes: numBytesRead=", - numBytesRead, + num_bytes_read >= 0, + "Tried to read negative bytes: num_bytes_read=", + num_bytes_read, ", size=", - dataContext->size, + data_context->size, ", current=", - dataContext->current); + data_context->current); if (num_bytes_read == 0) { return AVERROR_EOF; } - std::memcpy(buf, dataContext->data + dataContext->current, numBytesRead); - dataContext->current += numBytesRead; - return numBytesRead; + std::memcpy(buf, data_context->data + data_context->current, num_bytes_read); + data_context->current += num_bytes_read; + return num_bytes_read; } // The signature of this function is defined by FFMPEG.
int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) { - auto dataContext = static_cast(opaque); + auto data_context = static_cast<_data_context*>(opaque); int64_t ret = -1; switch (whence) { case AVSEEK_SIZE: - ret = dataContext->size; + ret = data_context->size; break; case SEEK_SET: - dataContext->current = offset; + data_context->current = offset; ret = offset; break; default: diff --git a/src/torchcodec/_core/AVIOBytesContext.h b/src/torchcodec/_core/AVIOBytesContext.h index c4fb7185..8cea346a 100644 --- a/src/torchcodec/_core/AVIOBytesContext.h +++ b/src/torchcodec/_core/AVIOBytesContext.h @@ -14,7 +14,7 @@ namespace facebook::torchcodec { // functions then traverse the bytes in memory. class AVIOBytesContext : public AVIOContextHolder { public: - explicit AVIOBytesContext(const void* data, int64_t dataSize); + explicit AVIOBytesContext(const void* data, int64_t data_size); private: struct DataContext { @@ -26,7 +26,7 @@ class AVIOBytesContext : public AVIOContextHolder { static int read(void* opaque, uint8_t* buf, int buf_size); static int64_t seek(void* opaque, int64_t offset, int whence); - DataContext dataContext_; + DataContext data_context_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp index f0ef095f..26ba4206 100644 --- a/src/torchcodec/_core/AVIOContextHolder.cpp +++ b/src/torchcodec/_core/AVIOContextHolder.cpp @@ -9,24 +9,24 @@ namespace facebook::torchcodec { -void AVIOContextHolder::createAVIOContext( +void AVIOContextHolder::create_avio_context( AVIOReadFunction read, AVIOSeekFunction seek, - void* heldData, - int bufferSize) { + void* held_data, + int buffer_size) { TORCH_CHECK( - bufferSize > 0, - "Buffer size must be greater than 0; is " + std::to_string(bufferSize)); - auto buffer = static_cast(av_malloc(bufferSize)); + buffer_size > 0, + "Buffer size must be greater than 0; is " + std::to_string(buffer_size)); + auto buffer = static_cast(av_malloc(buffer_size)); TORCH_CHECK( buffer != nullptr, - "Failed to allocate buffer of size " + std::to_string(bufferSize)); + "Failed to allocate buffer of size " + std::to_string(buffer_size)); - avioContext_.reset(avio_alloc_context( + avio_context_.reset(avio_alloc_context( buffer, - bufferSize, + buffer_size, 0, - heldData, + held_data, read, nullptr, // write function; not supported yet seek)); @@ -39,12 +39,12 @@ void AVIOContextHolder::createAVIOContext( AVIOContextHolder::~AVIOContextHolder() { if (avioContext_) { - av_freep(&avioContext_->buffer); + av_freep(&avio_context_->buffer); } } -AVIOContext* AVIOContextHolder::getAVIOContext() { - return avioContext_.get(); +AVIOContext* AVIOContextHolder::get_avio_context() { + return avio_context_.get(); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOContextHolder.h b/src/torchcodec/_core/AVIOContextHolder.h index 3b094c26..ed175bd2 100644 --- a/src/torchcodec/_core/AVIOContextHolder.h +++ b/src/torchcodec/_core/AVIOContextHolder.h @@ -12,31 +12,31 @@ namespace facebook::torchcodec { // The AVIOContextHolder serves several purposes: // -// 1. It is a smart pointer for the AVIOContext. It has the logic to create -// a new AVIOContext and will appropriately free the AVIOContext when it -// goes out of scope. Note that this requires more than just having a -// UniqueAVIOContext, as the AVIOContext points to a buffer which must be -// freed. -// 2. It is a base class for AVIOContext specializations. 
When specializing a -// AVIOContext, we need to provide four things: -// 1. A read callback function. -// 2. A seek callback function. -// 3. A write callback function. (Not supported yet; it's for encoding.) -// 4. A pointer to some context object that has the same lifetime as the -// AVIOContext itself. This context object holds the custom state that -// tracks the custom behavior of reading, seeking and writing. It is -// provided upon AVIOContext creation and to the read, seek and -// write callback functions. -// While it's not required, it is natural for the derived classes to make -// all of the above members. Base classes need to call -// createAVIOContext(), ideally in their constructor. -// 3. A generic handle for those that just need to manage having access to an -// AVIOContext, but aren't necessarily concerned with how it was customized: -// typically, the SingleStreamDecoder. +// 1. It is a smart pointer for the AVIOContext. It has the logic to create +// a new AVIOContext and will appropriately free the AVIOContext when it +// goes out of scope. Note that this requires more than just having a +// UniqueAVIOContext, as the AVIOContext points to a buffer which must be +// freed. +// 2. It is a base class for AVIOContext specializations. When specializing a +// AVIOContext, we need to provide four things: +// 1. A read callback function. +// 2. A seek callback function. +// 3. A write callback function. (Not supported yet; it's for encoding.) +// 4. A pointer to some context object that has the same lifetime as the +// AVIOContext itself. This context object holds the custom state that +// tracks the custom behavior of reading, seeking and writing. It is +// provided upon AVIOContext creation and to the read, seek and +// write callback functions. +// While it's not required, it is natural for the derived classes to make +// all of the above members. Base classes need to call +// createAVIOContext(), ideally in their constructor. +// 3. A generic handle for those that just need to manage having access to an +// AVIOContext, but aren't necessarily concerned with how it was customized: +// typically, the SingleStreamDecoder. class AVIOContextHolder { public: virtual ~AVIOContextHolder(); - AVIOContext* getAVIOContext(); + AVIOContext* get_avio_context(); protected: // Make constructor protected to prevent anyone from constructing @@ -49,17 +49,17 @@ class AVIOContextHolder { using AVIOSeekFunction = int64_t (*)(void*, int64_t, int); // Deriving classes should call this function in their constructor. 
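As a concrete illustration of the four requirements listed above, here is a minimal sketch of a derived holder (not part of this patch; the class, struct, and member names are illustrative, and the real in-tree example is AVIOBytesContext):

class MinimalBytesContext : public AVIOContextHolder {
 public:
  MinimalBytesContext(const uint8_t* data, int64_t size) : ctx_{data, size, 0} {
    // The read/seek callbacks and the context object (which lives as long as
    // the AVIOContext) are registered from the derived class's constructor.
    create_avio_context(&read, &seek, &ctx_);
  }

 private:
  struct Ctx {
    const uint8_t* data;
    int64_t size;
    int64_t current;
  };

  // Signatures are dictated by FFmpeg's AVIOContext callbacks.
  static int read(void* opaque, uint8_t* buf, int buf_size);
  static int64_t seek(void* opaque, int64_t offset, int whence);

  Ctx ctx_;
};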
- void createAVIOContext( + void create_avio_context( AVIOReadFunction read, AVIOSeekFunction seek, - void* heldData, - int bufferSize = defaultBufferSize); + void* held_data, + int buffer_size = default_buffer_size); private: - UniqueAVIOContext avioContext_; + UniqueAVIOContext avio_context_; // Defaults to 64 KB - static const int defaultBufferSize = 64 * 1024; + static const int default_buffer_size = 64 * 1024; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp index 4a905b93..25ccb7d3 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.cpp +++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp @@ -9,61 +9,62 @@ namespace facebook::torchcodec { -AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike) - : fileLike_{UniquePyObject(new py::object(fileLike))} { +AVIOFileLikeContext::AVIOFileLikeContext(py::object file_like) + : file_like_{UniquePyObject(new py::object(file_like))} { { // TODO: Is it necessary to acquire the GIL here? Is it maybe even // harmful? At the moment, this is only called from within a pybind // function, and pybind guarantees we have the GIL. py::gil_scoped_acquire gil; TORCH_CHECK( - py::hasattr(fileLike, "read"), + py::hasattr(file_like, "read"), "File like object must implement a read method."); TORCH_CHECK( - py::hasattr(fileLike, "seek"), + py::hasattr(file_like, "seek"), "File like object must implement a seek method."); } - createAVIOContext(&read, &seek, &fileLike_); + create_avio_context(&read, &seek, &file_like_); } int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) { - auto fileLike = static_cast<UniquePyObject*>(opaque); + auto file_like = static_cast<UniquePyObject*>(opaque); // Note that we acquire the GIL outside of the loop. This is likely more // efficient than releasing and acquiring it each loop iteration. py::gil_scoped_acquire gil; - int totalNumRead = 0; + int total_num_read = 0; while (total_num_read < buf_size) { - int request = buf_size - totalNumRead; + int request = buf_size - total_num_read; // The Python method returns the actual bytes, which we access through the // py::bytes wrapper. That wrapper, however, does not provide us access to // the underlying data pointer, which we need for the memcpy below. So we // convert the bytes to a string_view to get access to the data pointer. // Because it's a view and not a copy, it should be cheap. - auto bytesRead = static_cast<py::bytes>((*fileLike)->attr("read")(request)); - auto bytesView = static_cast<std::string_view>(bytesRead); + auto bytes_read = + static_cast<py::bytes>((*file_like)->attr("read")(request)); + auto bytes_view = static_cast<std::string_view>(bytes_read); - int numBytesRead = static_cast<int>(bytesView.size()); + int num_bytes_read = static_cast<int>(bytes_view.size()); if (num_bytes_read == 0) { break; } TORCH_CHECK( - numBytesRead <= request, + num_bytes_read <= request, "Requested up to ", request, " bytes but received ", - numBytesRead, + num_bytes_read, " bytes. The given object does not conform to read protocol of file object."); - std::memcpy(buf, bytesView.data(), numBytesRead); - buf += numBytesRead; - totalNumRead += numBytesRead; + std::memcpy(buf, bytes_view.data(), num_bytes_read); + buf += num_bytes_read; + total_num_read += num_bytes_read; } - return totalNumRead == 0 ? AVERROR_EOF : totalNumRead; + return total_num_read == 0 ?
AVERROR_EOF : total_num_read; } int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { @@ -72,9 +73,9 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) { return AVERROR(EIO); } - auto fileLike = static_cast(opaque); + auto file_like = static_cast(opaque); py::gil_scoped_acquire gil; - return py::cast((*fileLike)->attr("seek")(offset, whence)); + return py::cast((*file_like)->attr("seek")(offset, whence)); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/AVIOFileLikeContext.h b/src/torchcodec/_core/AVIOFileLikeContext.h index 3e80f1c6..34fc8528 100644 --- a/src/torchcodec/_core/AVIOFileLikeContext.h +++ b/src/torchcodec/_core/AVIOFileLikeContext.h @@ -19,7 +19,7 @@ namespace facebook::torchcodec { // and seek calls back up to the methods on the Python object. class AVIOFileLikeContext : public AVIOContextHolder { public: - explicit AVIOFileLikeContext(py::object fileLike); + explicit AVIOFileLikeContext(py::object file_like); private: static int read(void* opaque, uint8_t* buf, int buf_size); @@ -32,7 +32,7 @@ class AVIOFileLikeContext : public AVIOContextHolder { // we'd have to ensure whatever enclosing scope holds the object has the GIL, // and that's, at least, hard. For all of the common pitfalls, see: // - // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors + // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#common-sources-of-global-interpreter-lock-errors // // We maintain a reference to the file-like object because the file-like // object that was created on the Python side must live as long as our @@ -48,7 +48,7 @@ class AVIOFileLikeContext : public AVIOContextHolder { }; using UniquePyObject = std::unique_ptr; - UniquePyObject fileLike_; + UniquePyObject file_like_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CPUOnlyDevice.cpp b/src/torchcodec/_core/CPUOnlyDevice.cpp index 1d5b477d..b2d8f063 100644 --- a/src/torchcodec/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/_core/CPUOnlyDevice.cpp @@ -7,39 +7,39 @@ namespace facebook::torchcodec { // So all functions will throw an error because they should only be called if // the device is not CPU. 
-[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) { +[[noreturn]] void throw_unsupported_device_error(const torch::Device& device) { TORCH_CHECK( device.type() != torch::kCPU, "Device functions should only be called if the device is not CPU.") TORCH_CHECK(false, "Unsupported device: " + device.str()); } -void convertAVFrameToFrameOutputOnCuda( +void convert_avframe_to_frame_output_on_cuda( const torch::Device& device, [[maybe_unused]] const SingleStreamDecoder::VideoStreamOptions& - videoStreamOptions, - [[maybe_unused]] UniqueAVFrame& avFrame, - [[maybe_unused]] SingleStreamDecoder::FrameOutput& frameOutput, - [[maybe_unused]] std::optional preAllocatedOutputTensor) { - throwUnsupportedDeviceError(device); + video_stream_options, + [[maybe_unused]] UniqueAVFrame& avframe, + [[maybe_unused]] SingleStreamDecoder::FrameOutput& frame_output, + [[maybe_unused]] std::optional pre_allocated_output_tensor) { + throw_unsupported_device_error(device); } -void initializeContextOnCuda( +void initialize_context_on_cuda( const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); + [[maybe_unused]] AVCodecContext* codec_context) { + throw_unsupported_device_error(device); } -void releaseContextOnCuda( +void release_context_on_cuda( const torch::Device& device, - [[maybe_unused]] AVCodecContext* codecContext) { - throwUnsupportedDeviceError(device); + [[maybe_unused]] AVCodecContext* codec_context) { + throw_unsupported_device_error(device); } -std::optional findCudaCodec( +std::optional find_cuda_codec( const torch::Device& device, - [[maybe_unused]] const AVCodecID& codecId) { - throwUnsupportedDeviceError(device); + [[maybe_unused]] const AVCodecID& codec_id) { + throw_unsupported_device_error(device); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDevice.cpp b/src/torchcodec/_core/CudaDevice.cpp index fd8be9de..c7695fd1 100644 --- a/src/torchcodec/_core/CudaDevice.cpp +++ b/src/torchcodec/_core/CudaDevice.cpp @@ -19,13 +19,13 @@ namespace { // We reuse cuda contexts across VideoDeoder instances. This is because // creating a cuda context is expensive. The cache mechanism is as follows: // 1. There is a cache of size MAX_CONTEXTS_PER_GPU_IN_CACHE cuda contexts for -// each GPU. +// each GPU. // 2. When we destroy a SingleStreamDecoder instance we release the cuda context // to -// the cache if the cache is not full. +// the cache if the cache is not full. // 3. When we create a SingleStreamDecoder instance we try to get a cuda context // from -// the cache. If the cache is empty we create a new cuda context. +// the cache. If the cache is empty we create a new cuda context. // Pytorch can only handle up to 128 GPUs. // https://github.com/pytorch/pytorch/blob/e30c55ee527b40d67555464b9e402b4b7ce03737/c10/cuda/CUDAMacros.h#L44 @@ -33,46 +33,47 @@ const int MAX_CUDA_GPUS = 128; // Set to -1 to have an infinitely sized cache. Set it to 0 to disable caching. // Set to a positive number to have a cache of that size. 
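For orientation, the intended flow around this cache is roughly the following (a sketch, not part of this patch; it relies on the renamed helpers defined below):

  // Decoder setup: get_cuda_context() first tries the per-GPU cache and only
  // creates a fresh CUDA hw device context when the cache is empty.
  codec_context->hw_device_ctx = get_cuda_context(device);

  // Decoder teardown: release_context_on_cuda() hands the context back to the
  // cache (if it has capacity) instead of freeing it.
  release_context_on_cuda(device, codec_context);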
const int MAX_CONTEXTS_PER_GPU_IN_CACHE = -1; std::vector<AVBufferRef*> g_cached_hw_device_ctxs[MAX_CUDA_GPUS]; std::mutex g_cached_hw_device_mutexes[MAX_CUDA_GPUS]; -torch::DeviceIndex getFFMPEGCompatibleDeviceIndex(const torch::Device& device) { - torch::DeviceIndex deviceIndex = device.index(); - deviceIndex = std::max(deviceIndex, 0); +torch::DeviceIndex get_ffmpeg_compatible_device_index( + const torch::Device& device) { + torch::DeviceIndex device_index = device.index(); + device_index = std::max(device_index, 0); TORCH_CHECK(device_index >= 0, "Device index out of range"); // FFMPEG cannot handle negative device indices. // For single-GPU machines libtorch returns -1 for the device index. So for // that case we set the device index to 0. // TODO: Double check if this works for multi-GPU machines correctly. - return deviceIndex; + return device_index; } -void addToCacheIfCacheHasCapacity( +void add_to_cache_if_cache_has_capacity( const torch::Device& device, - AVCodecContext* codecContext) { - torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); + AVCodecContext* codec_context) { + torch::DeviceIndex device_index = get_ffmpeg_compatible_device_index(device); if (static_cast<size_t>(device_index) >= MAX_CUDA_GPUS) { return; } - std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + std::scoped_lock lock(g_cached_hw_device_mutexes[device_index]); if (MAX_CONTEXTS_PER_GPU_IN_CACHE >= 0 && - g_cached_hw_device_ctxs[deviceIndex].size() >= + g_cached_hw_device_ctxs[device_index].size() >= MAX_CONTEXTS_PER_GPU_IN_CACHE) { return; } - g_cached_hw_device_ctxs[deviceIndex].push_back(codecContext->hw_device_ctx); - codecContext->hw_device_ctx = nullptr; + g_cached_hw_device_ctxs[device_index].push_back(codec_context->hw_device_ctx); + codec_context->hw_device_ctx = nullptr; } -AVBufferRef* getFromCache(const torch::Device& device) { - torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device); +AVBufferRef* get_from_cache(const torch::Device& device) { + torch::DeviceIndex device_index = get_ffmpeg_compatible_device_index(device); if (static_cast<size_t>(device_index) >= MAX_CUDA_GPUS) { return nullptr; } - std::scoped_lock lock(g_cached_hw_device_mutexes[deviceIndex]); + std::scoped_lock lock(g_cached_hw_device_mutexes[device_index]); if (g_cached_hw_device_ctxs[device_index].size() > 0) { - AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[deviceIndex].back(); - g_cached_hw_device_ctxs[deviceIndex].pop_back(); + AVBufferRef* hw_device_ctx = g_cached_hw_device_ctxs[device_index].back(); + g_cached_hw_device_ctxs[device_index].pop_back(); return hw_device_ctx; } return nullptr; @@ -80,33 +81,33 @@ AVBufferRef* getFromCache(const torch::Device& device) { #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) -AVBufferRef* getFFMPEGContextFromExistingCudaContext( +AVBufferRef* get_ffmpeg_context_from_existing_cuda_context( const torch::Device& device, - torch::DeviceIndex nonNegativeDeviceIndex, + torch::DeviceIndex non_negative_device_index, enum AVHWDeviceType type) { - c10::cuda::CUDAGuard deviceGuard(device); + c10::cuda::CUDAGuard device_guard(device); // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1: // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb // So we ensure the deviceIndex is not negative.
// We set the device because we may be called from a different thread than // the one that initialized the cuda context. - cudaSetDevice(nonNegativeDeviceIndex); + cuda_set_device(non_negative_device_index); AVBufferRef* hw_device_ctx = nullptr; - std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); + std::string device_ordinal = std::to_string(non_negative_device_index); int err = av_hwdevice_ctx_create( &hw_device_ctx, type, - deviceOrdinal.c_str(), + device_ordinal.c_str(), nullptr, AV_CUDA_USE_CURRENT_CONTEXT); if (err < 0) { /* clang-format off */ - TORCH_CHECK( - false, - "Failed to create specified HW device. This typically happens when ", - "your installed FFmpeg doesn't support CUDA (see ", - "https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec", - "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err)); +TORCH_CHECK( +false, +"Failed to create specified HW device. This typically happens when ", +"your installed FFmpeg doesn't support CUDA (see ", +"https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec", +"). FFmpeg error: ", get_ffmpeg_error_string_from_error_code(err)); /* clang-format on */ } return hw_device_ctx; @@ -114,98 +115,98 @@ AVBufferRef* getFFMPEGContextFromExistingCudaContext( #else -AVBufferRef* getFFMPEGContextFromNewCudaContext( +AVBufferRef* get_ffmpeg_context_from_new_cuda_context( [[maybe_unused]] const torch::Device& device, - torch::DeviceIndex nonNegativeDeviceIndex, + torch::DeviceIndex non_negative_device_index, enum AVHWDeviceType type) { AVBufferRef* hw_device_ctx = nullptr; - std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex); + std::string device_ordinal = std::to_string(non_negative_device_index); int err = av_hwdevice_ctx_create( - &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0); + &hw_device_ctx, type, device_ordinal.c_str(), nullptr, 0); if (err < 0) { TORCH_CHECK( false, "Failed to create specified HW device", - getFFMPEGErrorStringFromErrorCode(err)); + get_ffmpeg_error_string_from_error_code(err)); } return hw_device_ctx; } #endif -AVBufferRef* getCudaContext(const torch::Device& device) { +AVBufferRef* get_cuda_context(const torch::Device& device) { enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - torch::DeviceIndex nonNegativeDeviceIndex = - getFFMPEGCompatibleDeviceIndex(device); + torch::DeviceIndex non_negative_device_index = + get_ffmpeg_compatible_device_index(device); - AVBufferRef* hw_device_ctx = getFromCache(device); + AVBufferRef* hw_device_ctx = get_from_cache(device); if (hw_device_ctx != nullptr) { return hw_device_ctx; } - // 58.26.100 introduced the concept of reusing the existing cuda context - // which is much faster and lower memory than creating a new cuda context. - // So we try to use that if it is available. - // FFMPEG 6.1.2 appears to be the earliest release that contains version - // 58.26.100 of avutil. - // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 +// 58.26.100 introduced the concept of reusing the existing cuda context +// which is much faster and lower memory than creating a new cuda context. +// So we try to use that if it is available. +// FFMPEG 6.1.2 appears to be the earliest release that contains version +// 58.26.100 of avutil. 
+// https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) - return getFFMPEGContextFromExistingCudaContext( - device, nonNegativeDeviceIndex, type); + return get_ffmpeg_context_from_existing_cuda_context( + device, non_negative_device_index, type); #else - return getFFMPEGContextFromNewCudaContext( - device, nonNegativeDeviceIndex, type); + return get_ffmpeg_context_from_new_cuda_context( + device, non_negative_device_index, type); #endif } -void throwErrorIfNonCudaDevice(const torch::Device& device) { +void throw_error_if_non_cuda_device(const torch::Device& device) { TORCH_CHECK( device.type() != torch::kCPU, "Device functions should only be called if the device is not CPU.") if (device.type() != torch::kCUDA) { - throw std::runtime_error("Unsupported device: " + device.str()); + throw std::runtime_error("_unsupported device: " + device.str()); } } } // namespace -void releaseContextOnCuda( +void release_context_on_cuda( const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); - addToCacheIfCacheHasCapacity(device, codecContext); + AVCodecContext* codec_context) { + throw_error_if_non_cuda_device(device); + add_to_cache_if_cache_has_capacity(device, codec_context); } -void initializeContextOnCuda( +void initialize_context_on_cuda( const torch::Device& device, - AVCodecContext* codecContext) { - throwErrorIfNonCudaDevice(device); + AVCodecContext* codec_context) { + throw_error_if_non_cuda_device(device); // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. // This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::empty( + torch::Tensor dummy_tensor_for_cuda_initialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device)); - codecContext->hw_device_ctx = getCudaContext(device); + codec_context->hw_device_ctx = get_cuda_context(device); return; } -void convertAVFrameToFrameOutputOnCuda( +void convert_avframe_to_frame_output_on_cuda( const torch::Device& device, - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor) { + const SingleStreamDecoder::VideoStreamOptions& video_stream_options, + UniqueAVFrame& avframe, + SingleStreamDecoder::FrameOutput& frame_output, + std::optional pre_allocated_output_tensor) { TORCH_CHECK( - avFrame->format == AV_PIX_FMT_CUDA, + avframe->format == AV_PIX_FMT_CUDA, "Expected format to be AV_PIX_FMT_CUDA, got " + - std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format))); - auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); - int height = frameDims.height; - int width = frameDims.width; - torch::Tensor& dst = frameOutput.data; + std::string(av_get_pix_fmt_name((AVPixelFormat)avframe->format))); + auto frame_dims = get_height_and_width_from_options_or_avframe( + video_stream_options, avframe); + int height = frame_dims.height; + int width = frame_dims.width; + torch::Tensor& dst = frame_output.data; if (preAllocatedOutputTensor.has_value()) { - dst = preAllocatedOutputTensor.value(); + dst = pre_allocated_output_tensor.value(); auto shape = dst.sizes(); TORCH_CHECK( (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && @@ -217,40 +218,41 @@ void 
convertAVFrameToFrameOutputOnCuda( "x3, got ", shape); } else { - dst = allocateEmptyHWCTensor(height, width, videoStreamOptions.device); + dst = + allocate_empty_hwc_tensor(height, width, video_stream_options.device); } // Use the user-requested GPU for running the NPP kernel. - c10::cuda::CUDAGuard deviceGuard(device); + c10::cuda::CUDAGuard device_guard(device); - NppiSize oSizeROI = {width, height}; - Npp8u* input[2] = {avFrame->data[0], avFrame->data[1]}; + NppiSize o_sizeroi = {width, height}; + Npp8u* input[2] = {avframe->data[0], avframe->data[1]}; auto start = std::chrono::high_resolution_clock::now(); NppStatus status; if (avframe->colorspace == AVColorSpace::AVCOL_SPC_BT709) { status = nppiNV12ToRGB_709CSC_8u_P2C3R( input, - avFrame->linesize[0], - static_cast<Npp8u*>(dst.data_ptr()), + avframe->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), dst.stride(0), - oSizeROI); + o_sizeroi); } else { status = nppiNV12ToRGB_8u_P2C3R( input, - avFrame->linesize[0], - static_cast<Npp8u*>(dst.data_ptr()), + avframe->linesize[0], + static_cast<Npp8u*>(dst.data_ptr()), dst.stride(0), - oSizeROI); + o_sizeroi); } TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); // Make the pytorch stream wait for the npp kernel to finish before using the // output. - at::cuda::CUDAEvent nppDoneEvent; - at::cuda::CUDAStream nppStreamWrapper = + at::cuda::CUDAEvent npp_done_event; + at::cuda::CUDAStream npp_stream_wrapper = c10::cuda::getStreamFromExternal(nppGetStream(), device.index()); - nppDoneEvent.record(nppStreamWrapper); + npp_done_event.record(npp_stream_wrapper); npp_done_event.block(at::cuda::getCurrentCUDAStream()); auto end = std::chrono::high_resolution_clock::now(); @@ -264,15 +266,15 @@ void convertAVFrameToFrameOutputOnCuda( // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional<const AVCodec*> findCudaCodec( +std::optional<const AVCodec*> find_cuda_codec( const torch::Device& device, - const AVCodecID& codecId) { + const AVCodecID& codec_id) { throw_error_if_non_cuda_device(device); void* i = nullptr; const AVCodec* codec = nullptr; while ((codec = av_codec_iterate(&i)) != nullptr) { - if (codec->id != codecId || !av_codec_is_decoder(codec)) { + if (codec->id != codec_id || !av_codec_is_decoder(codec)) { continue; } diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 352b83d3..5c61bf87 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -20,28 +20,28 @@ namespace facebook::torchcodec { // SingleStreamDecoder implementation. // These functions should only be called from within an if block like this: // if (device.type() != torch::kCPU) { -// deviceFunction(device, ...); +// deviceFunction(device, ...); // } // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU.
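As a usage sketch (not part of this patch), a caller is expected to guard these functions exactly as the comment above describes:

  // Hypothetical call site: only reach the CUDA-specific path when the
  // requested device is not the CPU.
  if (device.type() != torch::kCPU) {
    initialize_context_on_cuda(device, codec_context);
  }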
-void initializeContextOnCuda( +void initialize_context_on_cuda( const torch::Device& device, - AVCodecContext* codecContext); + AVCodecContext* codec_context); -void convertAVFrameToFrameOutputOnCuda( +void convert_avframe_to_frame_output_on_cuda( const torch::Device& device, - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - UniqueAVFrame& avFrame, - SingleStreamDecoder::FrameOutput& frameOutput, - std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + const SingleStreamDecoder::VideoStreamOptions& video_stream_options, + UniqueAVFrame& avframe, + SingleStreamDecoder::FrameOutput& frame_output, + std::optional<torch::Tensor> pre_allocated_output_tensor = std::nullopt); -void releaseContextOnCuda( +void release_context_on_cuda( const torch::Device& device, - AVCodecContext* codecContext); + AVCodecContext* codec_context); -std::optional<const AVCodec*> findCudaCodec( +std::optional<const AVCodec*> find_cuda_codec( const torch::Device& device, - const AVCodecID& codecId); + const AVCodecID& codec_id); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 33c8b484..1ddb9211 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -10,120 +10,120 @@ namespace facebook::torchcodec { -AutoAVPacket::AutoAVPacket() : avPacket_(av_packet_alloc()) { - TORCH_CHECK(avPacket_ != nullptr, "Couldn't allocate avPacket."); +AutoAVPacket::AutoAVPacket() : avpacket_(av_packet_alloc()) { + TORCH_CHECK(avpacket_ != nullptr, "Couldn't allocate avpacket."); } AutoAVPacket::~AutoAVPacket() { - av_packet_free(&avPacket_); + av_packet_free(&avpacket_); } ReferenceAVPacket::ReferenceAVPacket(AutoAVPacket& shared) - : avPacket_(shared.avPacket_) {} + : avpacket_(shared.avpacket_) {} ReferenceAVPacket::~ReferenceAVPacket() { - av_packet_unref(avPacket_); + av_packet_unref(avpacket_); } AVPacket* ReferenceAVPacket::get() { - return avPacket_; + return avpacket_; } AVPacket* ReferenceAVPacket::operator->() { - return avPacket_; + return avpacket_; } AVCodecOnlyUseForCallingAVFindBestStream -makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) { +make_avcodec_only_use_for_calling_avfind_best_stream(const AVCodec* codec) { #if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) return const_cast<AVCodec*>(codec); #else return codec; #endif } -std::string getFFMPEGErrorStringFromErrorCode(int errorCode) { - char errorBuffer[AV_ERROR_MAX_STRING_SIZE] = {0}; - av_strerror(errorCode, errorBuffer, AV_ERROR_MAX_STRING_SIZE); - return std::string(errorBuffer); +std::string get_ffmpeg_error_string_from_error_code(int error_code) { + char error_buffer[AV_ERROR_MAX_STRING_SIZE] = {0}; + av_strerror(error_code, error_buffer, AV_ERROR_MAX_STRING_SIZE); + return std::string(error_buffer); } -int64_t getDuration(const UniqueAVFrame& avFrame) { +int64_t get_duration(const UniqueAVFrame& avframe) { #if LIBAVUTIL_VERSION_MAJOR < 58 - return avFrame->pkt_duration; + return avframe->pkt_duration; #else - return avFrame->duration; + return avframe->duration; #endif } -int getNumChannels(const UniqueAVFrame& avFrame) { +int get_num_channels(const UniqueAVFrame& avframe) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) - return avFrame->ch_layout.nb_channels; + return avframe->ch_layout.nb_channels; #else - return av_get_channel_layout_nb_channels(avFrame->channel_layout); + return av_get_channel_layout_nb_channels(avframe->channel_layout);
#endif } -int getNumChannels(const UniqueAVCodecContext& avCodecContext) { +int get_num_channels(const UniqueAVCodecContext& av_codec_context) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) - return avCodecContext->ch_layout.nb_channels; + return av_codec_context->ch_layout.nb_channels; #else - return avCodecContext->channels; + return av_codec_context->channels; #endif } -void setChannelLayout( - UniqueAVFrame& dstAVFrame, - const UniqueAVFrame& srcAVFrame) { +void set_channel_layout( + UniqueAVFrame& dst_avframe, + const UniqueAVFrame& src_avframe) { #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 - dstAVFrame->ch_layout = srcAVFrame->ch_layout; + dst_avframe->ch_layout = src_avframe->ch_layout; #else - dstAVFrame->channel_layout = srcAVFrame->channel_layout; + dst_avframe->channel_layout = src_avframe->channel_layout; #endif } -SwrContext* allocateSwrContext( - UniqueAVCodecContext& avCodecContext, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { - SwrContext* swrContext = nullptr; +SwrContext* allocate_swr_context( + UniqueAVCodecContext& av_codec_context, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate) { + SwrContext* swr_context = nullptr; #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 - AVChannelLayout layout = avCodecContext->ch_layout; + AVChannelLayout layout = av_codec_context->ch_layout; auto status = swr_alloc_set_opts2( &swrContext, &layout, - desiredSampleFormat, - desiredSampleRate, + desired_sample_format, + desired_sample_rate, &layout, - sourceSampleFormat, - sourceSampleRate, + source_sample_format, + source_sample_rate, 0, nullptr); TORCH_CHECK( status == AVSUCCESS, "Couldn't create SwrContext: ", - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); #else - int64_t layout = static_cast(avCodecContext->channel_layout); - swrContext = swr_alloc_set_opts( + int64_t layout = static_cast(av_codec_context->channel_layout); + swr_context = swr_alloc_set_opts( nullptr, layout, - desiredSampleFormat, - desiredSampleRate, + desired_sample_format, + desired_sample_rate, layout, - sourceSampleFormat, - sourceSampleRate, + source_sample_format, + source_sample_rate, 0, nullptr); #endif - TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext"); - return swrContext; + TORCH_CHECK(swrContext != nullptr, "Couldn't create swr_context"); + return swr_context; } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 0309bf93..83f848d0 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -76,22 +76,22 @@ using UniqueSwrContext = // // AutoAVPacket autoAVPacket; // <-- malloc for AVPacket happens here // while(...){ -// ReferenceAVPacket packet(autoAVPacket); -// av_read_frame(..., packet.get()); <-- av_packet_ref() called by FFmpeg -// } <-- av_packet_unref() called here +// ReferenceAVPacket packet(autoAVPacket); +// av_read_frame(..., packet.get()); <-- avpacket_ref() called by FFmpeg +// } <-- avpacket_unref() called here // // This achieves a few desirable things: // - Memory allocation of the underlying AVPacket happens only once, when -// autoAVPacket is created. -// - av_packet_free() is called when autoAVPacket gets out of scope -// - av_packet_unref() is automatically called when needed, i.e. 
at the end of -// each loop iteration (or when hitting break / continue). This prevents the -// risk of us forgetting to call it. +// autoAVPacket is created. +// - avpacket_free() is called when autoAVPacket gets out of scope +// - avpacket_unref() is automatically called when needed, i.e. at the end of +// each loop iteration (or when hitting break / continue). This prevents the +// risk of us forgetting to call it. class AutoAVPacket { friend class ReferenceAVPacket; private: - AVPacket* avPacket_; + AVPacket* avpacket_; public: AutoAVPacket(); @@ -102,7 +102,7 @@ class AutoAVPacket { class ReferenceAVPacket { private: - AVPacket* avPacket_; + AVPacket* avpacket_; public: explicit ReferenceAVPacket(AutoAVPacket& shared); @@ -127,34 +127,34 @@ using AVCodecOnlyUseForCallingAVFindBestStream = const AVCodec*; #endif AVCodecOnlyUseForCallingAVFindBestStream -makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec); +make_avcodec_only_use_for_calling_avfind_best_stream(const AVCodec* codec); // Success code from FFMPEG is just a 0. We define it to make the code more // readable. const int AVSUCCESS = 0; // Returns the FFMPEG error as a string using the provided `errorCode`. -std::string getFFMPEGErrorStringFromErrorCode(int errorCode); +std::string get_ffmpeg_error_string_from_error_code(int error_code); // Returns duration from the frame. Abstracted into a function because the // struct member representing duration has changed across the versions we // support. -int64_t getDuration(const UniqueAVFrame& frame); +int64_t get_duration(const UniqueAVFrame& frame); -int getNumChannels(const UniqueAVFrame& avFrame); -int getNumChannels(const UniqueAVCodecContext& avCodecContext); +int get_num_channels(const UniqueAVFrame& avframe); +int get_num_channels(const UniqueAVCodecContext& av_codec_context); -void setChannelLayout( - UniqueAVFrame& dstAVFrame, - const UniqueAVFrame& srcAVFrame); -SwrContext* allocateSwrContext( - UniqueAVCodecContext& avCodecContext, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); +void set_channel_layout( + UniqueAVFrame& dst_avframe, + const UniqueAVFrame& src_avframe); +SwrContext* allocate_swr_context( + UniqueAVCodecContext& av_codec_context, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate); // Returns true if sws_scale can handle unaligned data. 
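To make the AutoAVPacket/ReferenceAVPacket contract described above concrete, a typical demuxing loop looks like this (a sketch, not part of this patch; format_context stands in for an opened AVFormatContext*, and the real usage is in SingleStreamDecoder's scan loop elsewhere in this patch):

  AutoAVPacket auto_avpacket;                 // one allocation for the whole loop
  while (true) {
    ReferenceAVPacket packet(auto_avpacket);  // av_packet_unref() runs when this goes out of scope
    int status = av_read_frame(format_context, packet.get());
    if (status == AVERROR_EOF) {
      break;
    }
    // ... inspect packet->pts, packet->stream_index, etc.
  }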
-bool canSwsScaleHandleUnalignedData(); +bool can_sws_scale_handle_unaligned_data(); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index efd93498..b711e490 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -26,16 +26,16 @@ extern "C" { namespace facebook::torchcodec { namespace { -double ptsToSeconds(int64_t pts, int den) { +double pts_to_seconds(int64_t pts, int den) { return static_cast(pts) / den; } -double ptsToSeconds(int64_t pts, const AVRational& timeBase) { - return ptsToSeconds(pts, timeBase.den); +double pts_to_seconds(int64_t pts, const AVRational& time_base) { + return pts_to_seconds(pts, time_base.den); } -int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { - return static_cast(std::round(seconds * timeBase.den)); +int64_t seconds_to_closest_pts(double seconds, const AVRational& time_base) { + return static_cast(std::round(seconds * time_base.den)); } } // namespace @@ -45,59 +45,59 @@ int64_t secondsToClosestPts(double seconds, const AVRational& timeBase) { // -------------------------------------------------------------------------- SingleStreamDecoder::SingleStreamDecoder( - const std::string& videoFilePath, - SeekMode seekMode) - : seekMode_(seekMode) { - setFFmpegLogLevel(); - - AVFormatContext* rawContext = nullptr; - int status = - avformat_open_input(&rawContext, videoFilePath.c_str(), nullptr, nullptr); + const std::string& video_file_path, + SeekMode seek_mode) + : seek_mode_(seek_mode) { + set_ffmpeg_log_level(); + + AVFormatContext* raw_context = nullptr; + int status = avformat_open_input( + &raw_context, video_file_path.c_str(), nullptr, nullptr); TORCH_CHECK( status == 0, - "Could not open input file: " + videoFilePath + " " + - getFFMPEGErrorStringFromErrorCode(status)); + "Could not open input file: " + video_file_path + " " + + get_ffmpeg_error_string_from_error_code(status)); TORCH_CHECK(rawContext != nullptr); - formatContext_.reset(rawContext); + format_context_.reset(raw_context); - initializeDecoder(); + initialize_decoder(); } SingleStreamDecoder::SingleStreamDecoder( - std::unique_ptr context, - SeekMode seekMode) - : seekMode_(seekMode), avioContextHolder_(std::move(context)) { - setFFmpegLogLevel(); + std::unique_ptr context, + SeekMode seek_mode) + : seek_mode_(seek_mode), avio_context_holder_(std::move(context)) { + set_ffmpeg_log_level(); TORCH_CHECK(avioContextHolder_, "Context holder cannot be null"); // Because FFmpeg requires a reference to a pointer in the call to open, we // can't use a unique pointer here. Note that means we must call free if open // fails. 
- AVFormatContext* rawContext = avformat_alloc_context(); + AVFormatContext* raw_context = avformat_alloc_context(); TORCH_CHECK(rawContext != nullptr, "Unable to alloc avformat context"); - rawContext->pb = avioContextHolder_->getAVIOContext(); - int status = avformat_open_input(&rawContext, nullptr, nullptr, nullptr); + raw_context->pb = avio_context_holder_->get_avio_context(); + int status = avformat_open_input(&raw_context, nullptr, nullptr, nullptr); if (status != 0) { - avformat_free_context(rawContext); + avformat_free_context(raw_context); TORCH_CHECK( false, "Failed to open input buffer: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - formatContext_.reset(rawContext); + format_context_.reset(raw_context); - initializeDecoder(); + initialize_decoder(); } SingleStreamDecoder::~SingleStreamDecoder() { - for (auto& [streamIndex, streamInfo] : streamInfos_) { - auto& device = streamInfo.videoStreamOptions.device; + for (auto& [streamIndex, stream_info] : stream_infos_) { + auto& device = stream_info.video_stream_options.device; if (device.type() == torch::kCPU) { } else if (device.type() == torch::kCUDA) { - releaseContextOnCuda(device, streamInfo.codecContext.get()); + release_context_on_cuda(device, stream_info.codec_context.get()); } else { TORCH_CHECK(false, "Invalid device type: " + device.str()); } @@ -111,130 +111,133 @@ void SingleStreamDecoder::initializeDecoder() { // avformat_open_input() which reads the header. However, some formats do not // store enough info in the header, so we call avformat_find_stream_info() // which decodes a few frames to get missing info. For more, see: - // https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html - int status = avformat_find_stream_info(formatContext_.get(), nullptr); + // https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html + int status = avformat_find_stream_info(format_context_.get(), nullptr); if (status < 0) { throw std::runtime_error( "Failed to find stream info: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - for (unsigned int i = 0; i < formatContext_->nb_streams; i++) { - AVStream* avStream = formatContext_->streams[i]; - StreamMetadata streamMetadata; + for (unsigned int i = 0; i < format_context_->nb_streams; i++) { + AVStream* av_stream = format_context_->streams[i]; + StreamMetadata stream_metadata; TORCH_CHECK( - static_cast(i) == avStream->index, + static_cast(i) == av_stream->index, "Our stream index, " + std::to_string(i) + ", does not match AVStream's index, " + - std::to_string(avStream->index) + "."); - streamMetadata.streamIndex = i; - streamMetadata.mediaType = avStream->codecpar->codec_type; - streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id); - streamMetadata.bitRate = avStream->codecpar->bit_rate; - - int64_t frameCount = avStream->nb_frames; + std::to_string(av_stream->index) + "."); + stream_metadata.stream_index = i; + stream_metadata.media_type = av_stream->codecpar->codec_type; + stream_metadata.codec_name = + avcodec_get_name(av_stream->codecpar->codec_id); + stream_metadata.bit_rate = av_stream->codecpar->bit_rate; + + int64_t frame_count = av_stream->nb_frames; if (frameCount > 0) { - streamMetadata.numFrames = frameCount; + stream_metadata.num_frames = frame_count; } - if (avStream->duration > 0 && avStream->time_base.den > 0) { - streamMetadata.durationSeconds = - av_q2d(avStream->time_base) * avStream->duration; + if (avStream->duration > 0 && 
av_stream->time_base.den > 0) { - streamMetadata.durationSeconds = - av_q2d(avStream->time_base) * avStream->duration; + stream_metadata.duration_seconds = + av_q2d(av_stream->time_base) * av_stream->duration; } if (av_stream->start_time != AV_NOPTS_VALUE) { - streamMetadata.beginStreamFromHeader = - av_q2d(avStream->time_base) * avStream->start_time; + stream_metadata.begin_stream_from_header = + av_q2d(av_stream->time_base) * av_stream->start_time; } if (av_stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { - double fps = av_q2d(avStream->r_frame_rate); + double fps = av_q2d(av_stream->r_frame_rate); if (fps > 0) { - streamMetadata.averageFps = fps; + stream_metadata.average_fps = fps; } - containerMetadata_.numVideoStreams++; + container_metadata_.num_video_streams++; } else if (av_stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { AVSampleFormat format = - static_cast<AVSampleFormat>(avStream->codecpar->format); + static_cast<AVSampleFormat>(av_stream->codecpar->format); // If the AVSampleFormat is not recognized, we get back nullptr. We have // to make sure we don't initialize a std::string with nullptr. There's // nothing to do on the else branch because we're already using an // optional; it'll just remain empty. - const char* rawSampleFormat = av_get_sample_fmt_name(format); + const char* raw_sample_format = av_get_sample_fmt_name(format); if (raw_sample_format != nullptr) { - streamMetadata.sampleFormat = std::string(rawSampleFormat); + stream_metadata.sample_format = std::string(raw_sample_format); } - containerMetadata_.numAudioStreams++; + container_metadata_.num_audio_streams++; } - containerMetadata_.allStreamMetadata.push_back(streamMetadata); + container_metadata_.all_stream_metadata.push_back(stream_metadata); } if (format_context_->duration > 0) { - containerMetadata_.durationSeconds = - ptsToSeconds(formatContext_->duration, AV_TIME_BASE); + container_metadata_.duration_seconds = + pts_to_seconds(format_context_->duration, AV_TIME_BASE); } if (format_context_->bit_rate > 0) { - containerMetadata_.bitRate = formatContext_->bit_rate; + container_metadata_.bit_rate = format_context_->bit_rate; } - int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); + int best_video_stream = + get_best_stream_index(AVMEDIA_TYPE_VIDEO); if (best_video_stream >= 0) { - containerMetadata_.bestVideoStreamIndex = bestVideoStream; + container_metadata_.best_video_stream_index = best_video_stream; } - int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); + int best_audio_stream = + get_best_stream_index(AVMEDIA_TYPE_AUDIO); if (best_audio_stream >= 0) { - containerMetadata_.bestAudioStreamIndex = bestAudioStream; + container_metadata_.best_audio_stream_index = best_audio_stream; } if (seekMode_ == SeekMode::exact) { - scanFileAndUpdateMetadataAndIndex(); + scan_file_and_update_metadata_and_index(); } initialized_ = true; } void SingleStreamDecoder::setFFmpegLogLevel() { - auto logLevel = AV_LOG_QUIET; - const char* logLevelEnv = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); + auto log_level = AV_LOG_QUIET; + const char* log_level_env = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL"); if (logLevelEnv != nullptr) { if (std::strcmp(logLevelEnv, "QUIET") == 0) { - logLevel = AV_LOG_QUIET; + log_level = AV_LOG_QUIET; } else if (std::strcmp(logLevelEnv, "PANIC") == 0) { - logLevel = AV_LOG_PANIC; + log_level = AV_LOG_PANIC; } else if (std::strcmp(logLevelEnv, "FATAL") == 0) { - logLevel = AV_LOG_FATAL; + log_level = AV_LOG_FATAL; } else if (std::strcmp(logLevelEnv, "ERROR") == 0) { - logLevel = AV_LOG_ERROR; + log_level = AV_LOG_ERROR; } else if
(std::strcmp(logLevelEnv, "WARNING") == 0) { - logLevel = AV_LOG_WARNING; + log_level = AV_LOG_WARNING; } else if (std::strcmp(logLevelEnv, "INFO") == 0) { - logLevel = AV_LOG_INFO; + log_level = AV_LOG_INFO; } else if (std::strcmp(logLevelEnv, "VERBOSE") == 0) { - logLevel = AV_LOG_VERBOSE; + log_level = AV_LOG_VERBOSE; } else if (std::strcmp(logLevelEnv, "DEBUG") == 0) { - logLevel = AV_LOG_DEBUG; + log_level = AV_LOG_DEBUG; } else if (std::strcmp(logLevelEnv, "TRACE") == 0) { - logLevel = AV_LOG_TRACE; + log_level = AV_LOG_TRACE; } else { TORCH_CHECK( false, "Invalid TORCHCODEC_FFMPEG_LOG_LEVEL: ", - logLevelEnv, + log_level_env, ". Use e.g. 'QUIET', 'PANIC', 'VERBOSE', etc."); } } - av_log_set_level(logLevel); + av_log_set_level(log_level); } -int SingleStreamDecoder::getBestStreamIndex(AVMediaType mediaType) { - AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; - int streamIndex = - av_find_best_stream(formatContext_.get(), mediaType, -1, -1, &avCodec, 0); - return streamIndex; +int SingleStreamDecoder::getBestStreamIndex(AVMediaType media_type) { + AVCodecOnlyUseForCallingAVFindBestStream av_codec = nullptr; + int stream_index = av_find_best_stream( + format_context_.get(), media_type, -1, -1, &avCodec, 0); + return stream_index; } // -------------------------------------------------------------------------- @@ -246,19 +249,19 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { return; } - for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) { + for (unsigned int i = 0; i < format_context_->nb_streams; ++i) { // We want to scan and update the metadata of all streams. TORCH_CHECK( - formatContext_->streams[i]->discard != AVDISCARD_ALL, + format_context_->streams[i]->discard != AVDISCARD_ALL, "Did you add a stream before you called for a scan?"); } - AutoAVPacket autoAVPacket; + AutoAVPacket auto_avpacket; while (true) { - ReferenceAVPacket packet(autoAVPacket); + ReferenceAVPacket packet(auto_avpacket); // av_read_frame is a misleading name: it gets the next **packet**. - int status = av_read_frame(formatContext_.get(), packet.get()); + int status = av_read_frame(format_context_.get(), packet.get()); if (status == AVERROR_EOF) { break; @@ -267,7 +270,7 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { if (status != AVSUCCESS) { throw std::runtime_error( "Failed to read frame from input file: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } if (packet->flags & AV_PKT_FLAG_DISCARD) { @@ -276,110 +279,113 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() { // We got a valid packet. Let's figure out what stream it belongs to and // record its relevant metadata. 
- int streamIndex = packet->stream_index; - auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; - streamMetadata.minPtsFromScan = std::min( - streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts); - streamMetadata.maxPtsFromScan = std::max( - streamMetadata.maxPtsFromScan.value_or(INT64_MIN), + int stream_index = packet->stream_index; + auto& stream_metadata = + container_metadata_.all_stream_metadata[stream_index]; + stream_metadata.min_pts_from_scan = std::min( + stream_metadata.min_pts_from_scan.value_or(INT64_MAX), + packet->pts); + stream_metadata.max_pts_from_scan = std::max( + stream_metadata.max_pts_from_scan.value_or(INT64_MIN), packet->pts + packet->duration); - streamMetadata.numFramesFromScan = - streamMetadata.numFramesFromScan.value_or(0) + 1; + stream_metadata.num_frames_from_scan = + stream_metadata.num_frames_from_scan.value_or(0) + 1; // Note that we set the other value in this struct, nextPts, only after // we have scanned all packets and sorted by pts. - FrameInfo frameInfo = {packet->pts}; + FrameInfo frame_info = {packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { - frameInfo.isKeyFrame = true; - streamInfos_[streamIndex].keyFrames.push_back(frameInfo); + frame_info.is_key_frame = true; + stream_infos_[stream_index].key_frames.push_back(frame_info); } - streamInfos_[streamIndex].allFrames.push_back(frameInfo); + stream_infos_[stream_index].all_frames.push_back(frame_info); } // Set all per-stream metadata that requires knowing the content of all // packets. - for (size_t streamIndex = 0; - streamIndex < containerMetadata_.allStreamMetadata.size(); + for (size_t stream_index = 0; + stream_index < container_metadata_.all_stream_metadata.size(); ++stream_index) { - auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; - auto avStream = formatContext_->streams[streamIndex]; + auto& stream_metadata = + container_metadata_.all_stream_metadata[stream_index]; + auto av_stream = format_context_->streams[stream_index]; - streamMetadata.numFramesFromScan = - streamInfos_[streamIndex].allFrames.size(); + stream_metadata.num_frames_from_scan = + stream_infos_[stream_index].all_frames.size(); if (stream_metadata.min_pts_from_scan.has_value()) { - streamMetadata.minPtsSecondsFromScan = - *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base); + stream_metadata.min_pts_seconds_from_scan = + *stream_metadata.min_pts_from_scan * av_q2d(av_stream->time_base); } if (stream_metadata.max_pts_from_scan.has_value()) { - streamMetadata.maxPtsSecondsFromScan = - *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base); + stream_metadata.max_pts_seconds_from_scan = + *stream_metadata.max_pts_from_scan * av_q2d(av_stream->time_base); } } // Reset the seek-cursor back to the beginning. - int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0); + int status = avformat_seek_file(format_context_.get(), 0, INT64_MIN, 0, 0, 0); if (status < 0) { throw std::runtime_error( "Could not seek file to pts=0: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } // Sort all frames by their pts.
- for (auto& [streamIndex, streamInfo] : streamInfos_) { + for (auto& [streamIndex, stream_info] : stream_infos_) { std::sort( - streamInfo.keyFrames.begin(), - streamInfo.keyFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; + stream_info.key_frames.begin(), + stream_info.key_frames.end(), + [](const FrameInfo& frame_info1, const FrameInfo& frame_info2) { + return frame_info1.pts < frame_info2.pts; }); std::sort( - streamInfo.allFrames.begin(), - streamInfo.allFrames.end(), - [](const FrameInfo& frameInfo1, const FrameInfo& frameInfo2) { - return frameInfo1.pts < frameInfo2.pts; + stream_info.all_frames.begin(), + stream_info.all_frames.end(), + [](const FrameInfo& frame_info1, const FrameInfo& frame_info2) { + return frame_info1.pts < frame_info2.pts; }); - size_t keyFrameIndex = 0; - for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { - streamInfo.allFrames[i].frameIndex = i; + size_t key_frame_index = 0; + for (size_t i = 0; i < stream_info.all_frames.size(); ++i) { + stream_info.all_frames[i].frame_index = i; if (streamInfo.allFrames[i].isKeyFrame) { TORCH_CHECK( - keyFrameIndex < streamInfo.keyFrames.size(), - "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); - streamInfo.keyFrames[keyFrameIndex].frameIndex = i; + key_frame_index < stream_info.key_frames.size(), + "The all_frames vec claims it has MORE key_frames than the key_frames vec. There's a bug in torchcodec."); + stream_info.key_frames[key_frame_index].frame_index = i; ++keyFrameIndex; } - if (i + 1 < streamInfo.allFrames.size()) { - streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; + if (i + 1 < stream_info.all_frames.size()) { + stream_info.all_frames[i].next_pts = stream_info.all_frames[i + 1].pts; } } TORCH_CHECK( - keyFrameIndex == streamInfo.keyFrames.size(), - "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); + key_frame_index == stream_info.key_frames.size(), + "The all_frames vec claims it has LESS key_frames than the key_frames vec. 
There's a bug in torchcodec."); } - scannedAllStreams_ = true; + scanned_all_streams_ = true; } SingleStreamDecoder::ContainerMetadata SingleStreamDecoder::getContainerMetadata() const { - return containerMetadata_; + return container_metadata_; } torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { - validateActiveStream(AVMEDIA_TYPE_VIDEO); - validateScannedAllStreams("getKeyFrameIndices"); + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); + validate_scanned_all_streams("get_key_frame_indices"); - const std::vector& keyFrames = - streamInfos_[activeStreamIndex_].keyFrames; - torch::Tensor keyFrameIndices = + const std::vector<_frame_info>& key_frames = + stream_infos_[active_stream_index_].key_frames; + torch::Tensor key_frame_indices = torch::empty({static_cast(keyFrames.size())}, {torch::kInt64}); - for (size_t i = 0; i < keyFrames.size(); ++i) { - keyFrameIndices[i] = keyFrames[i].frameIndex; + for (size_t i = 0; i < key_frames.size(); ++i) { + key_frame_indices[i] = key_frames[i].frame_index; } - return keyFrameIndices; + return key_frame_indices; } // -------------------------------------------------------------------------- @@ -387,118 +393,120 @@ torch::Tensor SingleStreamDecoder::getKeyFrameIndices() { // -------------------------------------------------------------------------- void SingleStreamDecoder::addStream( - int streamIndex, - AVMediaType mediaType, + int stream_index, + AVMediaType media_type, const torch::Device& device, - std::optional ffmpegThreadCount) { + std::optional ffmpeg_thread_count) { TORCH_CHECK( - activeStreamIndex_ == NO_ACTIVE_STREAM, + active_stream_index_ == NO_ACTIVE_STREAM, "Can only add one single stream."); TORCH_CHECK( - mediaType == AVMEDIA_TYPE_VIDEO || mediaType == AVMEDIA_TYPE_AUDIO, + media_type == AVMEDIA_TYPE_VIDEO || media_type == AVMEDIA_TYPE_AUDIO, "Can only add video or audio streams."); TORCH_CHECK(formatContext_.get() != nullptr); - AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; + AVCodecOnlyUseForCallingAVFindBestStream av_codec = nullptr; - activeStreamIndex_ = av_find_best_stream( - formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0); + active_stream_index_ = av_find_best_stream( + format_context_.get(), media_type, stream_index, -1, &avCodec, 0); if (activeStreamIndex_ < 0) { throw std::invalid_argument( "No valid stream found in input file. Is " + - std::to_string(streamIndex) + " of the desired media type?"); + std::to_string(stream_index) + " of the desired media type?"); } TORCH_CHECK(avCodec != nullptr); - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - streamInfo.streamIndex = activeStreamIndex_; - streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base; - streamInfo.stream = formatContext_->streams[activeStreamIndex_]; - streamInfo.avMediaType = mediaType; + StreamInfo& stream_info = stream_infos_[active_stream_index_]; + stream_info.stream_index = active_stream_index_; + stream_info.time_base = + format_context_->streams[active_stream_index_]->time_base; + stream_info.stream = format_context_->streams[active_stream_index_]; + stream_info.av_media_type = media_type; // This should never happen, checking just to be safe. 
TORCH_CHECK( - streamInfo.stream->codecpar->codec_type == mediaType, + stream_info.stream->codecpar->codec_type == media_type, "FFmpeg found stream with index ", - activeStreamIndex_, + active_stream_index_, " which is of the wrong media type."); // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - findCudaCodec(device, streamInfo.stream->codecpar->codec_id) + av_codec = make_avcodec_only_use_for_calling_avfind_best_stream( + find_cuda_codec(device, stream_info.stream->codecpar->codec_id) .value_or(avCodec)); } - AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); + AVCodecContext* codec_context = avcodec_alloc_context3(av_codec); TORCH_CHECK(codecContext != nullptr); - streamInfo.codecContext.reset(codecContext); + stream_info.codec_context.reset(codec_context); - int retVal = avcodec_parameters_to_context( - streamInfo.codecContext.get(), streamInfo.stream->codecpar); + int ret_val = avcodec_parameters_to_context( + stream_info.codec_context.get(), stream_info.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); - streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0); - streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; + stream_info.codec_context->thread_count = ffmpeg_thread_count.value_or(0); + stream_info.codec_context->pkt_timebase = stream_info.stream->time_base; // TODO_CODE_QUALITY same as above. if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) { - initializeContextOnCuda(device, codecContext); + initialize_context_on_cuda(device, codec_context); } - retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); + ret_val = avcodec_open2(stream_info.codec_context.get(), av_codec, nullptr); if (retVal < AVSUCCESS) { - throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal)); + throw std::invalid_argument( + get_ffmpeg_error_string_from_error_code(ret_val)); } - codecContext->time_base = streamInfo.stream->time_base; - containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = - std::string(avcodec_get_name(codecContext->codec_id)); + codec_context->time_base = stream_info.stream->time_base; + container_metadata_.all_stream_metadata[active_stream_index_].codec_name = + std::string(avcodec_get_name(codec_context->codec_id)); // We will only need packets from the active stream, so we tell FFmpeg to // discard packets from the other streams. Note that av_read_frame() may still // return some of those un-desired packet under some conditions, so it's still // important to discard/demux correctly in the inner decoding loop. 
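addStream() above strings together the usual FFmpeg calls for opening a decoder: av_find_best_stream(), avcodec_alloc_context3(), avcodec_parameters_to_context(), and avcodec_open2(). Below is a condensed sketch of that sequence, not TorchCodec's actual helper; the const AVCodec** parameter of av_find_best_stream() assumes FFmpeg 5.0+, and error handling is reduced to early returns.

```cpp
// Condensed sketch of the decoder-opening sequence used by addStream() above.
extern "C" {
#include <libavformat/avformat.h>
#include <libavcodec/avcodec.h>
}

AVCodecContext* openBestDecoder(AVFormatContext* formatContext, AVMediaType mediaType) {
  const AVCodec* codec = nullptr;
  int streamIndex =
      av_find_best_stream(formatContext, mediaType, -1, -1, &codec, 0);
  if (streamIndex < 0 || codec == nullptr) {
    return nullptr;  // no stream of that type, or no decoder for it
  }

  AVCodecContext* codecContext = avcodec_alloc_context3(codec);
  if (codecContext == nullptr) {
    return nullptr;
  }
  if (avcodec_parameters_to_context(
          codecContext, formatContext->streams[streamIndex]->codecpar) < 0) {
    avcodec_free_context(&codecContext);
    return nullptr;
  }
  codecContext->pkt_timebase = formatContext->streams[streamIndex]->time_base;
  codecContext->thread_count = 0;  // 0 lets FFmpeg choose the thread count

  if (avcodec_open2(codecContext, codec, nullptr) < 0) {
    avcodec_free_context(&codecContext);
    return nullptr;
  }
  return codecContext;
}
```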
- for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) { - if (i != static_cast(activeStreamIndex_)) { - formatContext_->streams[i]->discard = AVDISCARD_ALL; + for (unsigned int i = 0; i < format_context_->nb_streams; ++i) { + if (i != static_cast(active_stream_index_)) { + format_context_->streams[i]->discard = AVDISCARD_ALL; } } } void SingleStreamDecoder::addVideoStream( - int streamIndex, - const VideoStreamOptions& videoStreamOptions) { + int stream_index, + const VideoStreamOptions& video_stream_options) { TORCH_CHECK( - videoStreamOptions.device.type() == torch::kCPU || - videoStreamOptions.device.type() == torch::kCUDA, - "Invalid device type: " + videoStreamOptions.device.str()); + video_stream_options.device.type() == torch::kCPU || + video_stream_options.device.type() == torch::kCUDA, + "Invalid device type: " + video_stream_options.device.str()); - addStream( - streamIndex, + add_stream( + stream_index, AVMEDIA_TYPE_VIDEO, - videoStreamOptions.device, - videoStreamOptions.ffmpegThreadCount); + video_stream_options.device, + video_stream_options.ffmpeg_thread_count); - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { throw std::runtime_error( "Seek mode is approximate, but stream " + - std::to_string(activeStreamIndex_) + + std::to_string(active_stream_index_) + " does not have an average fps in its metadata."); } - auto& streamInfo = streamInfos_[activeStreamIndex_]; - streamInfo.videoStreamOptions = videoStreamOptions; + auto& stream_info = stream_infos_[active_stream_index_]; + stream_info.video_stream_options = video_stream_options; - streamMetadata.width = streamInfo.codecContext->width; - streamMetadata.height = streamInfo.codecContext->height; + stream_metadata.width = stream_info.codec_context->width; + stream_metadata.height = stream_info.codec_context->height; // By default, we want to use swscale for color conversion because it is // faster. However, it has width requirements, so we may need to fall back @@ -507,43 +515,44 @@ void SingleStreamDecoder::addVideoStream( // swscale's width requirements to be violated. We don't expose the ability to // choose color conversion library publicly; we only use this ability // internally. - int width = videoStreamOptions.width.value_or(streamInfo.codecContext->width); + int width = + video_stream_options.width.value_or(stream_info.codec_context->width); // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements // so we fall back to filtergraph if the width is not a multiple of 32. - auto defaultLibrary = (width % 32 == 0) + auto default_library = (width % 32 == 0) ? 
SingleStreamDecoder::ColorConversionLibrary::SWSCALE : SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH; - streamInfo.colorConversionLibrary = - videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); + stream_info.color_conversion_library = + video_stream_options.color_conversion_library.value_or(default_library); } void SingleStreamDecoder::addAudioStream( - int streamIndex, - const AudioStreamOptions& audioStreamOptions) { + int stream_index, + const AudioStreamOptions& audio_stream_options) { TORCH_CHECK( - seekMode_ == SeekMode::approximate, + seek_mode_ == SeekMode::approximate, "seek_mode must be 'approximate' for audio streams."); - addStream(streamIndex, AVMEDIA_TYPE_AUDIO); + add_stream(stream_index, AVMEDIA_TYPE_AUDIO); - auto& streamInfo = streamInfos_[activeStreamIndex_]; - streamInfo.audioStreamOptions = audioStreamOptions; + auto& stream_info = stream_infos_[active_stream_index_]; + stream_info.audio_stream_options = audio_stream_options; - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - streamMetadata.sampleRate = - static_cast(streamInfo.codecContext->sample_rate); - streamMetadata.numChannels = - static_cast(getNumChannels(streamInfo.codecContext)); + auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; + stream_metadata.sample_rate = + static_cast(stream_info.codec_context->sample_rate); + stream_metadata.num_channels = + static_cast(get_num_channels(stream_info.codec_context)); // FFmpeg docs say that the decoder will try to decode natively in this // format, if it can. Docs don't say what the decoder does when it doesn't // support that format, but it looks like it does nothing, so this probably // doesn't hurt. - streamInfo.codecContext->request_sample_fmt = AV_SAMPLE_FMT_FLTP; + stream_info.codec_context->request_sample_fmt = AV_SAMPLE_FMT_FLTP; } // -------------------------------------------------------------------------- @@ -551,158 +560,162 @@ void SingleStreamDecoder::addAudioStream( // -------------------------------------------------------------------------- SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrame() { - auto output = getNextFrameInternal(); + auto output = get_next_frame_internal(); if (streamInfos_[activeStreamIndex_].avMediaType == AVMEDIA_TYPE_VIDEO) { - output.data = maybePermuteHWC2CHW(output.data); + output.data = maybe_permute_h_w_c2_c_h_w(output.data); } return output; } SingleStreamDecoder::FrameOutput SingleStreamDecoder::getNextFrameInternal( - std::optional preAllocatedOutputTensor) { - validateActiveStream(); - UniqueAVFrame avFrame = decodeAVFrame( - [this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; }); - return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor); + std::optional pre_allocated_output_tensor) { + validate_active_stream(); + UniqueAVFrame avframe = decode_avframe( + [this](const UniqueAVFrame& avframe) { return avframe->pts >= cursor_; }); + return convert_avframe_to_frame_output(avframe, pre_allocated_output_tensor); } SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFrameAtIndex( - int64_t frameIndex) { - auto frameOutput = getFrameAtIndexInternal(frameIndex); - frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); - return frameOutput; + int64_t frame_index) { + auto frame_output = get_frame_at_index_internal(frame_index); + frame_output.data = maybe_permute_h_w_c2_c_h_w(frame_output.data); + return frame_output; } SingleStreamDecoder::FrameOutput 
SingleStreamDecoder::getFrameAtIndexInternal( - int64_t frameIndex, - std::optional preAllocatedOutputTensor) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + int64_t frame_index, + std::optional pre_allocated_output_tensor) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - validateFrameIndex(streamMetadata, frameIndex); + const auto& stream_info = stream_infos_[active_stream_index_]; + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; + validate_frame_index(stream_metadata, frame_index); - int64_t pts = getPts(frameIndex); - setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); - return getNextFrameInternal(preAllocatedOutputTensor); + int64_t pts = get_pts(frame_index); + set_cursor_pts_in_seconds(pts_to_seconds(pts, stream_info.time_base)); + return get_next_frame_internal(pre_allocated_output_tensor); } SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( - const std::vector& frameIndices) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + const std::vector& frame_indices) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); - auto indicesAreSorted = - std::is_sorted(frameIndices.begin(), frameIndices.end()); + auto indices_are_sorted = + std::is_sorted(frame_indices.begin(), frame_indices.end()); std::vector argsort; if (!indicesAreSorted) { // if frameIndices is [13, 10, 12, 11] - // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want - // to use to decode the frames - // and argsort is [ 1, 3, 2, 0] - argsort.resize(frameIndices.size()); + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // to use to decode the frames + // and argsort is [ 1, 3, 2, 0] + argsort.resize(frame_indices.size()); for (size_t i = 0; i < argsort.size(); ++i) { argsort[i] = i; } std::sort( argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { - return frameIndices[a] < frameIndices[b]; + return frame_indices[a] < frame_indices[b]; }); } - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput frameBatchOutput( - frameIndices.size(), videoStreamOptions, streamMetadata); + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; + const auto& stream_info = stream_infos_[active_stream_index_]; + const auto& video_stream_options = stream_info.video_stream_options; + FrameBatchOutput frame_batch_output( + frame_indices.size(), video_stream_options, stream_metadata); - auto previousIndexInVideo = -1; - for (size_t f = 0; f < frameIndices.size(); ++f) { - auto indexInOutput = indicesAreSorted ? f : argsort[f]; - auto indexInVideo = frameIndices[indexInOutput]; + auto previous_index_in_video = -1; + for (size_t f = 0; f < frame_indices.size(); ++f) { + auto index_in_output = indices_are_sorted ? f : argsort[f]; + auto index_in_video = frame_indices[index_in_output]; - validateFrameIndex(streamMetadata, indexInVideo); + validate_frame_index(stream_metadata, index_in_video); - if ((f > 0) && (indexInVideo == previousIndexInVideo)) { + if ((f > 0) && (indexInVideo == previous_index_in_video)) { // Avoid decoding the same frame twice - auto previousIndexInOutput = indicesAreSorted ? 
f - 1 : argsort[f - 1]; - frameBatchOutput.data[indexInOutput].copy_( - frameBatchOutput.data[previousIndexInOutput]); - frameBatchOutput.ptsSeconds[indexInOutput] = - frameBatchOutput.ptsSeconds[previousIndexInOutput]; - frameBatchOutput.durationSeconds[indexInOutput] = - frameBatchOutput.durationSeconds[previousIndexInOutput]; + auto previous_index_in_output = + indices_are_sorted ? f - 1 : argsort[f - 1]; + frame_batch_output.data[index_in_output].copy_( + frame_batch_output.data[previous_index_in_output]); + frame_batch_output.pts_seconds[index_in_output] = + frame_batch_output.pts_seconds[previous_index_in_output]; + frame_batch_output.duration_seconds[index_in_output] = + frame_batch_output.duration_seconds[previous_index_in_output]; } else { - FrameOutput frameOutput = getFrameAtIndexInternal( - indexInVideo, frameBatchOutput.data[indexInOutput]); - frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[indexInOutput] = - frameOutput.durationSeconds; + FrameOutput frame_output = get_frame_at_index_internal( + index_in_video, frame_batch_output.data[index_in_output]); + frame_batch_output.pts_seconds[index_in_output] = + frame_output.pts_seconds; + frame_batch_output.duration_seconds[index_in_output] = + frame_output.duration_seconds; } - previousIndexInVideo = indexInVideo; + previous_index_in_video = index_in_video; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); - return frameBatchOutput; + frame_batch_output.data = maybe_permute_h_w_c2_c_h_w(frame_batch_output.data); + return frame_batch_output; } SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t start, int64_t stop, int64_t step) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - int64_t numFrames = getNumFrames(streamMetadata); + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; + const auto& stream_info = stream_infos_[active_stream_index_]; + int64_t num_frames = get_num_frames(stream_metadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); TORCH_CHECK( - stop <= numFrames, + stop <= num_frames, "Range stop, " + std::to_string(stop) + - ", is more than the number of frames, " + std::to_string(numFrames)); + ", is more than the number of frames, " + std::to_string(num_frames)); TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); - int64_t numOutputFrames = std::ceil((stop - start) / double(step)); - const auto& videoStreamOptions = streamInfo.videoStreamOptions; - FrameBatchOutput frameBatchOutput( - numOutputFrames, videoStreamOptions, streamMetadata); + int64_t num_output_frames = std::ceil((stop - start) / double(step)); + const auto& video_stream_options = stream_info.video_stream_options; + FrameBatchOutput frame_batch_output( + num_output_frames, video_stream_options, stream_metadata); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { - FrameOutput frameOutput = - getFrameAtIndexInternal(i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; + FrameOutput frame_output = + get_frame_at_index_internal(i, frame_batch_output.data[f]); + frame_batch_output.pts_seconds[f] = 
frame_output.pts_seconds; + frame_batch_output.duration_seconds[f] = frame_output.duration_seconds; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); - return frameBatchOutput; + frame_batch_output.data = maybe_permute_h_w_c2_c_h_w(frame_batch_output.data); + return frame_batch_output; } SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt( double seconds) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - double frameStartTime = - ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); - double frameEndTime = ptsToSeconds( - streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration, - streamInfo.timeBase); - if (seconds >= frameStartTime && seconds < frameEndTime) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); + StreamInfo& stream_info = stream_infos_[active_stream_index_]; + double frame_start_time = pts_to_seconds( + stream_info.last_decoded_avframe_pts, stream_info.time_base); + double frame_end_time = pts_to_seconds( + stream_info.last_decoded_avframe_pts + + stream_info.last_decoded_avframe_duration, + stream_info.time_base); + if (seconds >= frame_start_time && seconds < frame_end_time) { // We are in the same frame as the one we just returned. However, since we // don't cache it locally, we have to rewind back. - seconds = frameStartTime; + seconds = frame_start_time; } - setCursorPtsInSeconds(seconds); - UniqueAVFrame avFrame = - decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) { - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase); - double frameEndTime = ptsToSeconds( - avFrame->pts + getDuration(avFrame), streamInfo.timeBase); + set_cursor_pts_in_seconds(seconds); + UniqueAVFrame avframe = + decode_avframe([seconds, this](const UniqueAVFrame& avframe) { + StreamInfo& stream_info = stream_infos_[active_stream_index_]; + double frame_start_time = + pts_to_seconds(avframe->pts, stream_info.time_base); + double frame_end_time = pts_to_seconds( + avframe->pts + get_duration(avframe), stream_info.time_base); if (frameStartTime > seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() @@ -713,71 +726,71 @@ SingleStreamDecoder::FrameOutput SingleStreamDecoder::getFramePlayedAt( // TODO: Maybe log to stderr for Debug builds? return true; } - return seconds >= frameStartTime && seconds < frameEndTime; + return seconds >= frame_start_time && seconds < frame_end_time; }); // Convert the frame to tensor. 
- FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame); - frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); - return frameOutput; + FrameOutput frame_output = convert_avframe_to_frame_output(avframe); + frame_output.data = maybe_permute_h_w_c2_c_h_w(frame_output.data); + return frame_output; } SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( const std::vector& timestamps) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; - double minSeconds = getMinSeconds(streamMetadata); - double maxSeconds = getMaxSeconds(streamMetadata); + double min_seconds = get_min_seconds(stream_metadata); + double max_seconds = get_max_seconds(stream_metadata); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps // to indices, and leverage the de-duplication logic of getFramesAtIndices. - std::vector frameIndices(timestamps.size()); + std::vector frame_indices(timestamps.size()); for (size_t i = 0; i < timestamps.size(); ++i) { - auto frameSeconds = timestamps[i]; + auto frame_seconds = timestamps[i]; TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, - "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); + frame_seconds >= min_seconds && frame_seconds < max_seconds, + "frame pts is " + std::to_string(frame_seconds) + + "; must be in range [" + std::to_string(min_seconds) + ", " + + std::to_string(max_seconds) + ")."); - frameIndices[i] = secondsToIndexLowerBound(frameSeconds); + frame_indices[i] = seconds_to_index_lower_bound(frame_seconds); } - return getFramesAtIndices(frameIndices); + return get_frames_at_indices(frame_indices); } SingleStreamDecoder::FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( - double startSeconds, - double stopSeconds) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + double start_seconds, + double stop_seconds) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; TORCH_CHECK( - startSeconds <= stopSeconds, - "Start seconds (" + std::to_string(startSeconds) + + start_seconds <= stop_seconds, + "Start seconds (" + std::to_string(start_seconds) + ") must be less than or equal to stop seconds (" + - std::to_string(stopSeconds) + "."); + std::to_string(stop_seconds) + "."); - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& videoStreamOptions = streamInfo.videoStreamOptions; + const auto& stream_info = stream_infos_[active_stream_index_]; + const auto& video_stream_options = stream_info.video_stream_options; // Special case needed to implement a half-open range. At first glance, this // may seem unnecessary, as our search for stopFrame can return the end, and // we don't include stopFramIndex in our output. 
However, consider the // following scenario: // - // frame=0, pts=0.0 - // frame=1, pts=0.3 + // frame=0, pts=0.0 + // frame=1, pts=0.3 // - // interval A: [0.2, 0.2) - // interval B: [0.2, 0.15) + // interval A: [0.2, 0.2) + // interval B: [0.2, 0.15) // // Both intervals take place between the pts values for frame 0 and frame 1, // which by our abstract player, means that both intervals map to frame 0. By @@ -785,53 +798,55 @@ SingleStreamDecoder::getFramesPlayedInRange( // Interval B should return frame 0. However, for both A and B, the individual // values of the intervals will map to the same frame indices below. Hence, we // need this special case below. - if (startSeconds == stopSeconds) { - FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); - return frameBatchOutput; + if (startSeconds == stop_seconds) { + FrameBatchOutput frame_batch_output( + 0, video_stream_options, stream_metadata); + frame_batch_output.data = + maybe_permute_h_w_c2_c_h_w(frame_batch_output.data); + return frame_batch_output; } - double minSeconds = getMinSeconds(streamMetadata); - double maxSeconds = getMaxSeconds(streamMetadata); + double min_seconds = get_min_seconds(stream_metadata); + double max_seconds = get_max_seconds(stream_metadata); TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); + start_seconds >= min_seconds && start_seconds < max_seconds, + "Start seconds is " + std::to_string(start_seconds) + + "; must be in range [" + std::to_string(min_seconds) + ", " + + std::to_string(max_seconds) + ")."); TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + + stop_seconds <= max_seconds, + "Stop seconds (" + std::to_string(stop_seconds) + + "; must be less than or equal to " + std::to_string(max_seconds) + ")."); // Note that we look at nextPts for a frame, and not its pts or duration. // Our abstract player displays frames starting at the pts for that frame // until the pts for the next frame. There are two consequences: // - // 1. We ignore the duration for a frame. A frame is played until the - // next frame replaces it. This model is robust to durations being 0 or - // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + - // duration. - // 2. In order to establish if the start of an interval maps to a - // particular frame, we need to figure out if it is ordered after the - // frame's pts, but before the next frames's pts. + // 1. We ignore the duration for a frame. A frame is played until the + // next frame replaces it. This model is robust to durations being 0 or + // incorrect; our source of truth is the pts for frames. If duration is + // accurate, the nextPts for a frame would be equivalent to pts + + // duration. + // 2. In order to establish if the start of an interval maps to a + // particular frame, we need to figure out if it is ordered after the + // frame's pts, but before the next frames's pts. 
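The "a frame is displayed from its pts until the next frame's pts" model described above is what makes mapping a timestamp to a frame index well defined. The sketch below illustrates that idea with a plain sorted vector of pts values; it is only an illustration of the model, not the actual secondsToIndexLowerBound()/secondsToIndexUpperBound() code.

```cpp
// Map a timestamp (in pts units) to the index of the frame on screen at that
// time, under the "displayed from its pts until the next frame's pts" model.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int64_t displayedFrameIndex(const std::vector<int64_t>& sortedPts, int64_t t) {
  // First frame whose pts is strictly greater than t...
  auto it = std::upper_bound(sortedPts.begin(), sortedPts.end(), t);
  // ...the frame displayed at t is the one right before it.
  return static_cast<int64_t>(it - sortedPts.begin()) - 1;
}

int main() {
  std::vector<int64_t> pts = {0, 3000, 6000, 9000};  // hypothetical, sorted by pts
  assert(displayedFrameIndex(pts, 0) == 0);
  assert(displayedFrameIndex(pts, 2999) == 0);  // still frame 0: next pts not reached
  assert(displayedFrameIndex(pts, 3000) == 1);  // frame 1 takes over exactly at its pts
  assert(displayedFrameIndex(pts, 9500) == 3);  // durations are ignored; the last frame "stays"
  return 0;
}
```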
- int64_t startFrameIndex = secondsToIndexLowerBound(startSeconds); - int64_t stopFrameIndex = secondsToIndexUpperBound(stopSeconds); - int64_t numFrames = stopFrameIndex - startFrameIndex; + int64_t start_frame_index = seconds_to_index_lower_bound(start_seconds); + int64_t stop_frame_index = seconds_to_index_upper_bound(stop_seconds); + int64_t num_frames = stop_frame_index - start_frame_index; - FrameBatchOutput frameBatchOutput( - numFrames, videoStreamOptions, streamMetadata); - for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - FrameOutput frameOutput = - getFrameAtIndexInternal(i, frameBatchOutput.data[f]); - frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; - frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; + FrameBatchOutput frame_batch_output( + num_frames, video_stream_options, stream_metadata); + for (int64_t i = start_frame_index, f = 0; i < stop_frame_index; ++i, ++f) { + FrameOutput frame_output = + get_frame_at_index_internal(i, frame_batch_output.data[f]); + frame_batch_output.pts_seconds[f] = frame_output.pts_seconds; + frame_batch_output.duration_seconds[f] = frame_output.duration_seconds; } - frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); + frame_batch_output.data = maybe_permute_h_w_c2_c_h_w(frame_batch_output.data); - return frameBatchOutput; + return frame_batch_output; } // Note [Audio Decoding Design] @@ -848,13 +863,13 @@ SingleStreamDecoder::getFramesPlayedInRange( // we want those to be close to FFmpeg concepts, but the higher-level public // APIs expose samples. As a result: // - We don't expose index-based APIs for audio, because that would mean -// exposing the concept of audio frame. For now, we think exposing time-based -// APIs is more natural. +// exposing the concept of audio frame. For now, we think exposing time-based +// APIs is more natural. // - We never perform a scan for audio streams. We don't need to, since we won't -// be converting timestamps to indices. That's why we enforce the seek_mode -// to be "approximate" (which is slightly misleading, because technically the -// output samples will be at their exact positions. But this incongruence is -// only exposed at the C++/core private levels). +// be converting timestamps to indices. That's why we enforce the seek_mode +// to be "approximate" (which is slightly misleading, because technically the +// output samples will be at their exact positions. But this incongruence is +// only exposed at the C++/core private levels). // // Audio frames are of variable dimensions: in the same stream, a frame can // contain 1024 samples and the next one may contain 512 [1]. This makes it @@ -874,46 +889,47 @@ SingleStreamDecoder::getFramesPlayedInRange( // or Decord do something similar, whether it was intended or not. This has a // few implications: // - The **only** place we're allowed to seek to in an audio stream is the -// stream's beginning. This ensures that if we need a frame, we'll have -// decoded all previous frames. +// stream's beginning. This ensures that if we need a frame, we'll have +// decoded all previous frames. // - Because of that, we don't allow the public APIs to seek. Public APIs can -// call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually -// seek. +// call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually +// seek. // - We try not to seek, when we can avoid it. 
Typically if the next frame we -// need is in the future, we don't seek back to the beginning, we just decode -// all the frames in-between. +// need is in the future, we don't seek back to the beginning, we just decode +// all the frames in-between. // // [2] If you're brave and curious, you can read the long "Seek offset for // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which // sums up past (and failed) attemps at working around this issue. SingleStreamDecoder::AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( - double startSeconds, - std::optional stopSecondsOptional) { - validateActiveStream(AVMEDIA_TYPE_AUDIO); + double start_seconds, + std::optional stop_seconds_optional) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__a_u_d_i_o); if (stopSecondsOptional.has_value()) { TORCH_CHECK( - startSeconds <= *stopSecondsOptional, - "Start seconds (" + std::to_string(startSeconds) + + start_seconds <= *stopSecondsOptional, + "Start seconds (" + std::to_string(start_seconds) + ") must be less than or equal to stop seconds (" + - std::to_string(*stopSecondsOptional) + ")."); + std::to_string(*stop_seconds_optional) + ")."); } - if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) { + if (stopSecondsOptional.has_value() && + start_seconds == *stopSecondsOptional) { // For consistency with video return AudioFramesOutput{torch::empty({0, 0}), 0.0}; } - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + StreamInfo& stream_info = stream_infos_[active_stream_index_]; - auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); - if (startPts < streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration) { + auto start_pts = seconds_to_closest_pts(start_seconds, stream_info.time_base); + if (startPts < stream_info.last_decoded_avframe_pts + + stream_info.last_decoded_avframe_duration) { // If we need to seek backwards, then we have to seek back to the beginning // of the stream. // See [Audio Decoding Design]. - setCursor(INT64_MIN); + set_cursor(_i_n_t64__m_i_n); } // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + @@ -921,22 +937,22 @@ SingleStreamDecoder::getFramesPlayedInRangeAudio( // sample rate, so in theory we know the number of output samples. std::vector frames; - std::optional firstFramePtsSeconds = std::nullopt; - auto stopPts = stopSecondsOptional.has_value() - ? secondsToClosestPts(*stopSecondsOptional, streamInfo.timeBase) + std::optional first_frame_pts_seconds = std::nullopt; + auto stop_pts = stop_seconds_optional.has_value() + ? seconds_to_closest_pts(*stop_seconds_optional, stream_info.time_base) : INT64_MAX; auto finished = false; while (!finished) { try { - UniqueAVFrame avFrame = - decodeAVFrame([startPts](const UniqueAVFrame& avFrame) { - return startPts < avFrame->pts + getDuration(avFrame); + UniqueAVFrame avframe = + decode_avframe([start_pts](const UniqueAVFrame& avframe) { + return start_pts < avframe->pts + get_duration(avframe); }); - auto frameOutput = convertAVFrameToFrameOutput(avFrame); + auto frame_output = convert_avframe_to_frame_output(avframe); if (!firstFramePtsSeconds.has_value()) { - firstFramePtsSeconds = frameOutput.ptsSeconds; + first_frame_pts_seconds = frame_output.pts_seconds; } - frames.push_back(frameOutput.data); + frames.push_back(frame_output.data); } catch (const EndOfFileException& e) { finished = true; } @@ -945,23 +961,23 @@ SingleStreamDecoder::getFramesPlayedInRangeAudio( // stop decoding more frames. 
Note that if we were to use [begin, end), // which may seem more natural, then we would decode the frame starting at // stopSeconds, which isn't what we want! - auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration; - finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts && - (stopPts <= lastDecodedAvFrameEnd); + auto last_decoded_avframe_end = stream_info.last_decoded_avframe_pts + + stream_info.last_decoded_avframe_duration; + finished |= (streamInfo.lastDecodedAvFramePts) <= stop_pts && + (stopPts <= last_decoded_avframe_end); } - auto lastSamples = maybeFlushSwrBuffers(); + auto last_samples = maybe_flush_swr_buffers(); if (lastSamples.has_value()) { - frames.push_back(*lastSamples); + frames.push_back(*last_samples); } TORCH_CHECK( - frames.size() > 0 && firstFramePtsSeconds.has_value(), + frames.size() > 0 && first_frame_pts_seconds.has_value(), "No audio frames were decoded. ", "This is probably because start_seconds is too high? ", "Current value is ", - startSeconds); + start_seconds); return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds}; } @@ -973,18 +989,18 @@ SingleStreamDecoder::getFramesPlayedInRangeAudio( void SingleStreamDecoder::setCursorPtsInSeconds(double seconds) { // We don't allow public audio decoding APIs to seek, see [Audio Decoding // Design] - validateActiveStream(AVMEDIA_TYPE_VIDEO); - setCursor( - secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase)); + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); + set_cursor(seconds_to_closest_pts( + seconds, stream_infos_[active_stream_index_].time_base)); } void SingleStreamDecoder::setCursor(int64_t pts) { - cursorWasJustSet_ = true; + cursor_was_just_set_ = true; cursor_ = pts; } /* -Videos have I frames and non-I frames (P and B frames). Non-I frames need data +Videos have I frames and non-_i frames (P and B frames). Non-I frames need data from the previous I frame to be decoded. Imagine the cursor is at a random frame with PTS=lastDecodedAvFramePts (x for @@ -997,28 +1013,28 @@ If y > x, we have two choices: 1. We could keep decoding forward until we hit y. Illustrated below: -I P P P I P P P I P P I P P I P - x y +I P P P I P P P I P P I P P I P +x y 2. We could try to jump to an I frame between x and y (indicated by j below). And then start decoding until we encounter y. Illustrated below: -I P P P I P P P I P P I P P I P - x j y +I P P P I P P P I P P I P P I P +x j y (2) is more efficient than (1) if there is an I frame between x and y. */ bool SingleStreamDecoder::canWeAvoidSeeking() const { - const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); + const StreamInfo& stream_info = stream_infos_.at(active_stream_index_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { // For audio, we only need to seek if a backwards seek was requested within // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called. // For more context, see [Audio Decoding Design] return !cursorWasJustSet_; } - int64_t lastDecodedAvFramePts = - streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; - if (cursor_ < lastDecodedAvFramePts) { + int64_t last_decoded_avframe_pts = + stream_infos_.at(active_stream_index_).last_decoded_avframe_pts; + if (cursor_ < last_decoded_avframe_pts) { // We can never skip a seek if we are seeking backwards. return false; } @@ -1032,26 +1048,27 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // We are seeking forwards. 
// We can only skip a seek if both lastDecodedAvFramePts and // cursor_ share the same keyframe. - int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts); - int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_); - return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 && - lastDecodedAvFrameIndex == targetKeyFrameIndex; + int last_decoded_avframe_index = + get_key_frame_index_for_pts(last_decoded_avframe_pts); + int target_key_frame_index = get_key_frame_index_for_pts(cursor_); + return last_decoded_avframe_index >= 0 && target_key_frame_index >= 0 && + last_decoded_avframe_index == target_key_frame_index; } // This method looks at currentPts and desiredPts and seeks in the // AVFormatContext if it is needed. We can skip seeking in certain cases. See // the comment of canWeAvoidSeeking() for details. void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { - validateActiveStream(); - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + validate_active_stream(); + StreamInfo& stream_info = stream_infos_[active_stream_index_]; - decodeStats_.numSeeksAttempted++; + decode_stats_.num_seeks_attempted++; if (canWeAvoidSeeking()) { - decodeStats_.numSeeksSkipped++; + decode_stats_.num_seeks_skipped++; return; } - int64_t desiredPts = cursor_; + int64_t desired_pts = cursor_; // For some encodings like H265, FFMPEG sometimes seeks past the point we // set as the max_ts. So we use our own index to give it the exact pts of @@ -1059,26 +1076,27 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { // See https://github.com/pytorch/torchcodec/issues/179 for more details. // See https://trac.ffmpeg.org/ticket/11137 for the underlying ffmpeg bug. if (!streamInfo.keyFrames.empty()) { - int desiredKeyFrameIndex = getKeyFrameIndexForPtsUsingScannedIndex( - streamInfo.keyFrames, desiredPts); - desiredKeyFrameIndex = std::max(desiredKeyFrameIndex, 0); - desiredPts = streamInfo.keyFrames[desiredKeyFrameIndex].pts; + int desired_key_frame_index = + get_key_frame_index_for_pts_using_scanned_index( + stream_info.key_frames, desired_pts); + desired_key_frame_index = std::max(desired_key_frame_index, 0); + desired_pts = stream_info.key_frames[desired_key_frame_index].pts; } int status = avformat_seek_file( - formatContext_.get(), - streamInfo.streamIndex, + format_context_.get(), + stream_info.stream_index, INT64_MIN, - desiredPts, - desiredPts, + desired_pts, + desired_pts, 0); if (status < 0) { throw std::runtime_error( - "Could not seek file to pts=" + std::to_string(desiredPts) + ": " + - getFFMPEGErrorStringFromErrorCode(status)); + "Could not seek file to pts=" + std::to_string(desired_pts) + ": " + + get_ffmpeg_error_string_from_error_code(status)); } - decodeStats_.numFlushes++; - avcodec_flush_buffers(streamInfo.codecContext.get()); + decode_stats_.num_flushes++; + avcodec_flush_buffers(stream_info.codec_context.get()); } // -------------------------------------------------------------------------- @@ -1086,35 +1104,35 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { // -------------------------------------------------------------------------- UniqueAVFrame SingleStreamDecoder::decodeAVFrame( - std::function filterFunction) { - validateActiveStream(); + std::function filter_function) { + validate_active_stream(); - resetDecodeStats(); + reset_decode_stats(); if (cursorWasJustSet_) { - maybeSeekToBeforeDesiredPts(); - cursorWasJustSet_ = false; + maybe_seek_to_before_desired_pts(); + cursor_was_just_set_ = false; } - StreamInfo& streamInfo = 
streamInfos_[activeStreamIndex_]; + StreamInfo& stream_info = stream_infos_[active_stream_index_]; // Need to get the next frame or error from PopFrame. - UniqueAVFrame avFrame(av_frame_alloc()); - AutoAVPacket autoAVPacket; + UniqueAVFrame avframe(avframe_alloc()); + AutoAVPacket auto_avpacket; int status = AVSUCCESS; - bool reachedEOF = false; + bool reached_e_o_f = false; while (true) { status = - avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get()); + avcodec_receive_frame(stream_info.codec_context.get(), avframe.get()); if (status != AVSUCCESS && status != AVERROR(EAGAIN)) { // Non-retriable error break; } - decodeStats_.numFramesReceivedByDecoder++; + decode_stats_.num_frames_received_by_decoder++; // Is this the kind of frame we're looking for? - if (status == AVSUCCESS && filterFunction(avFrame)) { + if (status == AVSUCCESS && filter_function(avframe)) { // Yes, this is the frame we'll return; break out of the decoding loop. break; } else if (status == AVSUCCESS) { @@ -1134,33 +1152,33 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We still haven't found the frame we're looking for. So let's read more // packets and send them to the decoder. - ReferenceAVPacket packet(autoAVPacket); + ReferenceAVPacket packet(auto_avpacket); do { - status = av_read_frame(formatContext_.get(), packet.get()); - decodeStats_.numPacketsRead++; + status = av_read_frame(format_context_.get(), packet.get()); + decode_stats_.num_packets_read++; if (status == AVERROR_EOF) { // End of file reached. We must drain the codec by sending a nullptr // packet. status = avcodec_send_packet( - streamInfo.codecContext.get(), + stream_info.codec_context.get(), /*avpkt=*/nullptr); if (status < AVSUCCESS) { throw std::runtime_error( "Could not flush decoder: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - reachedEOF = true; + reached_e_o_f = true; break; } if (status < AVSUCCESS) { throw std::runtime_error( "Could not read frame from input file: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - } while (packet->stream_index != activeStreamIndex_); + } while (packet->stream_index != active_stream_index_); if (reachedEOF) { // We don't have any more packets to send to the decoder. So keep on @@ -1170,14 +1188,14 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. - status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); + status = avcodec_send_packet(stream_info.codec_context.get(), packet.get()); if (status < AVSUCCESS) { throw std::runtime_error( "Could not push packet to decoder: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - decodeStats_.numPacketsSentToDecoder++; + decode_stats_.num_packets_sent_to_decoder++; } if (status < AVSUCCESS) { @@ -1188,7 +1206,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( } throw std::runtime_error( "Could not receive frame from decoder: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } // Note that we don't flush the decoder when we reach EOF (even though that's @@ -1197,10 +1215,10 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // haven't received as frames. 
Eventually we will either hit AVERROR_EOF from // av_receive_frame() or the user will have seeked to a different location in // the file and that will flush the decoder. - streamInfo.lastDecodedAvFramePts = avFrame->pts; - streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame); + stream_info.last_decoded_avframe_pts = avframe->pts; + stream_info.last_decoded_avframe_duration = get_duration(avframe); - return avFrame; + return avframe; } // -------------------------------------------------------------------------- @@ -1209,34 +1227,35 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( SingleStreamDecoder::FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( - UniqueAVFrame& avFrame, - std::optional preAllocatedOutputTensor) { + UniqueAVFrame& avframe, + std::optional pre_allocated_output_tensor) { // Convert the frame to tensor. - FrameOutput frameOutput; - auto& streamInfo = streamInfos_[activeStreamIndex_]; - frameOutput.ptsSeconds = ptsToSeconds( - avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base); - frameOutput.durationSeconds = ptsToSeconds( - getDuration(avFrame), - formatContext_->streams[activeStreamIndex_]->time_base); + FrameOutput frame_output; + auto& stream_info = stream_infos_[active_stream_index_]; + frame_output.pts_seconds = pts_to_seconds( + avframe->pts, format_context_->streams[active_stream_index_]->time_base); + frame_output.duration_seconds = pts_to_seconds( + get_duration(avframe), + format_context_->streams[active_stream_index_]->time_base); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); + convert_audio_avframe_to_frame_output_on_c_p_u(avframe, frame_output); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { - convertAVFrameToFrameOutputOnCPU( - avFrame, frameOutput, preAllocatedOutputTensor); + convert_avframe_to_frame_output_on_c_p_u( + avframe, frame_output, pre_allocated_output_tensor); } else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) { - convertAVFrameToFrameOutputOnCuda( - streamInfo.videoStreamOptions.device, - streamInfo.videoStreamOptions, - avFrame, - frameOutput, - preAllocatedOutputTensor); + convert_avframe_to_frame_output_on_cuda( + stream_info.video_stream_options.device, + stream_info.video_stream_options, + avframe, + frame_output, + pre_allocated_output_tensor); } else { TORCH_CHECK( false, - "Invalid device type: " + streamInfo.videoStreamOptions.device.str()); + "Invalid device type: " + + stream_info.video_stream_options.device.str()); } - return frameOutput; + return frame_output; } // Note [preAllocatedOutputTensor with swscale and filtergraph]: @@ -1249,30 +1268,30 @@ SingleStreamDecoder::convertAVFrameToFrameOutput( // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. 
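The decoding loop in decodeAVFrame() above follows FFmpeg's standard send/receive pattern: drain frames with avcodec_receive_frame(), feed packets with av_read_frame()/avcodec_send_packet(), and send a nullptr packet at EOF so the decoder flushes what it still holds. The sketch below is a stripped-down version of that pattern, assuming already-opened contexts and a single stream of interest; it is not TorchCodec's implementation.

```cpp
// Minimal send/receive decode loop, mirroring the structure of decodeAVFrame().
extern "C" {
#include <libavformat/avformat.h>
#include <libavcodec/avcodec.h>
}

bool decodeNextFrame(AVFormatContext* fmt, AVCodecContext* dec, int streamIndex, AVFrame* out) {
  AVPacket* packet = av_packet_alloc();
  bool gotFrame = false;
  while (!gotFrame) {
    int status = avcodec_receive_frame(dec, out);
    if (status == 0) {
      gotFrame = true;  // a decoded frame is ready
      break;
    }
    if (status != AVERROR(EAGAIN)) {
      break;  // AVERROR_EOF after draining, or a real error
    }
    // Decoder needs more input: read packets until one belongs to our stream.
    while ((status = av_read_frame(fmt, packet)) >= 0 &&
           packet->stream_index != streamIndex) {
      av_packet_unref(packet);
    }
    if (status == AVERROR_EOF) {
      avcodec_send_packet(dec, nullptr);  // start draining the decoder
    } else if (status >= 0) {
      avcodec_send_packet(dec, packet);
      av_packet_unref(packet);
    } else {
      break;  // read error
    }
  }
  av_packet_free(&packet);
  return gotFrame;
}
```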
void SingleStreamDecoder::convertAVFrameToFrameOutputOnCPU( - UniqueAVFrame& avFrame, - FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; + UniqueAVFrame& avframe, + FrameOutput& frame_output, + std::optional pre_allocated_output_tensor) { + auto& stream_info = stream_infos_[active_stream_index_]; - auto frameDims = getHeightAndWidthFromOptionsOrAVFrame( - streamInfo.videoStreamOptions, avFrame); - int expectedOutputHeight = frameDims.height; - int expectedOutputWidth = frameDims.width; + auto frame_dims = get_height_and_width_from_options_or_avframe( + stream_info.video_stream_options, avframe); + int expected_output_height = frame_dims.height; + int expected_output_width = frame_dims.width; if (preAllocatedOutputTensor.has_value()) { - auto shape = preAllocatedOutputTensor.value().sizes(); + auto shape = pre_allocated_output_tensor.value().sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == expected_output_height) && + (shape[1] == expected_output_width) && (shape[2] == 3), "Expected pre-allocated tensor of shape ", - expectedOutputHeight, + expected_output_height, "x", - expectedOutputWidth, + expected_output_width, "x3, got ", shape); } - torch::Tensor outputTensor; + torch::Tensor output_tensor; // We need to compare the current frame context with our previous frame // context. If they are different, then we need to re-create our colorspace // conversion objects. We create our colorspace conversion objects late so @@ -1280,203 +1299,207 @@ void SingleStreamDecoder::convertAVFrameToFrameOutputOnCPU( // And we sometimes re-create them because it's possible for frame // resolution to change mid-stream. Finally, we want to reuse the colorspace // conversion objects as much as possible for performance reasons. - enum AVPixelFormat frameFormat = + enum AVPixelFormat frame_format = static_cast(avFrame->format); - auto frameContext = DecodedFrameContext{ - avFrame->width, - avFrame->height, - frameFormat, - expectedOutputWidth, - expectedOutputHeight}; + auto frame_context = DecodedFrameContext{ + avframe->width, + avframe->height, + frame_format, + expected_output_width, + expected_output_height}; if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor( - expectedOutputHeight, expectedOutputWidth, torch::kCPU)); - - if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) { - createSwsContext(streamInfo, frameContext, avFrame->colorspace); - streamInfo.prevFrameContext = frameContext; + output_tensor = + pre_allocated_output_tensor.value_or(allocate_empty_h_w_c_tensor( + expected_output_height, expected_output_width, torch::kCPU)); + + if (!streamInfo.swsContext || + stream_info.prev_frame_context != frame_context) { + create_sws_context(stream_info, frame_context, avframe->colorspace); + stream_info.prev_frame_context = frame_context; } - int resultHeight = - convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor); + int result_height = + convert_avframe_to_tensor_using_sws_scale(avframe, output_tensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? 
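The hunk above only rebuilds the swscale/filtergraph state when the decoded frame's properties differ from the previous frame's. The sketch below shows that lazy re-creation idea in isolation; FrameProps, Converter, and rebuildState() are hypothetical stand-ins, not TorchCodec's DecodedFrameContext or its conversion code.

```cpp
// Remember the properties of the last decoded frame and only rebuild the
// (expensive) colorspace-conversion state when they change.
#include <optional>

struct FrameProps {
  int width = 0;
  int height = 0;
  int pixelFormat = 0;  // e.g. an AVPixelFormat value
  bool operator==(const FrameProps& other) const {
    return width == other.width && height == other.height &&
        pixelFormat == other.pixelFormat;
  }
};

class Converter {
 public:
  void convert(const FrameProps& current) {
    if (!prevProps_.has_value() || !(*prevProps_ == current)) {
      rebuildState(current);  // e.g. a new SwsContext or filtergraph
      prevProps_ = current;
    }
    // ... run the actual conversion with the cached state ...
  }

 private:
  void rebuildState(const FrameProps&) { /* allocate conversion objects here */ }
  std::optional<FrameProps> prevProps_;
};
```

Creating the conversion objects late, and keying them on the frame's actual properties rather than the stream's declared ones, also handles streams whose resolution changes mid-file.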
TORCH_CHECK( - resultHeight == expectedOutputHeight, - "resultHeight != expectedOutputHeight: ", - resultHeight, + result_height == expected_output_height, + "resultHeight != expected_output_height: ", + result_height, " != ", - expectedOutputHeight); + expected_output_height); - frameOutput.data = outputTensor; + frame_output.data = output_tensor; } else if ( - streamInfo.colorConversionLibrary == + stream_info.color_conversion_library == ColorConversionLibrary::FILTERGRAPH) { if (!streamInfo.filterGraphContext.filterGraph || - streamInfo.prevFrameContext != frameContext) { - createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth); - streamInfo.prevFrameContext = frameContext; + stream_info.prev_frame_context != frame_context) { + create_filter_graph( + stream_info, expected_output_height, expected_output_width); + stream_info.prev_frame_context = frame_context; } - outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame); + output_tensor = convert_avframe_to_tensor_using_filter_graph(avframe); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. - auto shape = outputTensor.sizes(); + auto shape = output_tensor.sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == expected_output_height) && + (shape[1] == expected_output_width) && (shape[2] == 3), "Expected output tensor of shape ", - expectedOutputHeight, + expected_output_height, "x", - expectedOutputWidth, + expected_output_width, "x3, got ", shape); if (preAllocatedOutputTensor.has_value()) { // We have already validated that preAllocatedOutputTensor and // outputTensor have the same shape. 
- preAllocatedOutputTensor.value().copy_(outputTensor); - frameOutput.data = preAllocatedOutputTensor.value(); + pre_allocated_output_tensor.value().copy_(output_tensor); + frame_output.data = pre_allocated_output_tensor.value(); } else { - frameOutput.data = outputTensor; + frame_output.data = output_tensor; } } else { throw std::runtime_error( "Invalid color conversion library: " + - std::to_string(static_cast(streamInfo.colorConversionLibrary))); + std::to_string(static_cast(stream_info.color_conversion_library))); } } int SingleStreamDecoder::convertAVFrameToTensorUsingSwsScale( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor) { - StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_]; - SwsContext* swsContext = activeStreamInfo.swsContext.get(); + const UniqueAVFrame& avframe, + torch::Tensor& output_tensor) { + StreamInfo& active_stream_info = stream_infos_[active_stream_index_]; + SwsContext* sws_context = active_stream_info.sws_context.get(); uint8_t* pointers[4] = { - outputTensor.data_ptr(), nullptr, nullptr, nullptr}; - int expectedOutputWidth = outputTensor.sizes()[1]; + output_tensor.data_ptr(), nullptr, nullptr, nullptr}; + int expected_output_width = output_tensor.sizes()[1]; int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0}; - int resultHeight = sws_scale( - swsContext, - avFrame->data, - avFrame->linesize, + int result_height = sws_scale( + sws_context, + avframe->data, + avframe->linesize, 0, - avFrame->height, + avframe->height, pointers, linesizes); - return resultHeight; + return result_height; } torch::Tensor SingleStreamDecoder::convertAVFrameToTensorUsingFilterGraph( - const UniqueAVFrame& avFrame) { - FilterGraphContext& filterGraphContext = - streamInfos_[activeStreamIndex_].filterGraphContext; - int status = - av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get()); + const UniqueAVFrame& avframe) { + FilterGraphContext& filter_graph_context = + stream_infos_[active_stream_index_].filter_graph_context; + int status = av_buffersrc_write_frame( + filter_graph_context.source_context, avframe.get()); if (status < AVSUCCESS) { - throw std::runtime_error("Failed to add frame to buffer source context"); + throw std::runtime_error("_failed to add frame to buffer source context"); } - UniqueAVFrame filteredAVFrame(av_frame_alloc()); + UniqueAVFrame filtered_avframe(avframe_alloc()); status = av_buffersink_get_frame( - filterGraphContext.sinkContext, filteredAVFrame.get()); + filter_graph_context.sink_context, filtered_avframe.get()); TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); - auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get()); - int height = frameDims.height; - int width = frameDims.width; + auto frame_dims = + get_height_and_width_from_resized_avframe(*filtered_avframe.get()); + int height = frame_dims.height; + int width = frame_dims.width; std::vector shape = {height, width, 3}; std::vector strides = {filteredAVFrame->linesize[0], 3, 1}; - AVFrame* filteredAVFramePtr = filteredAVFrame.release(); + AVFrame* filtered_avframe_ptr = filtered_avframe.release(); auto deleter = [filteredAVFramePtr](void*) { - UniqueAVFrame avFrameToDelete(filteredAVFramePtr); + UniqueAVFrame avframe_to_delete(filtered_avframe_ptr); }; return torch::from_blob( - filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8}); + filtered_avframe_ptr->data[0], shape, strides, deleter, {torch::kUInt8}); } void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - 
FrameOutput& frameOutput) { - AVSampleFormat sourceSampleFormat = - static_cast<AVSampleFormat>(srcAVFrame->format); - AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP; - - int sourceSampleRate = srcAVFrame->sample_rate; - int desiredSampleRate = - streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or( - sourceSampleRate); - - bool mustConvert = - (sourceSampleFormat != desiredSampleFormat || - sourceSampleRate != desiredSampleRate); - - UniqueAVFrame convertedAVFrame; + UniqueAVFrame& src_avframe, + FrameOutput& frame_output) { + AVSampleFormat source_sample_format = + static_cast<AVSampleFormat>(src_avframe->format); + AVSampleFormat desired_sample_format = AV_SAMPLE_FMT_FLTP; + + int source_sample_rate = src_avframe->sample_rate; + int desired_sample_rate = + stream_infos_[active_stream_index_] + .audio_stream_options.sample_rate.value_or(source_sample_rate); + + bool must_convert = + (source_sample_format != desired_sample_format || + source_sample_rate != desired_sample_rate); + + UniqueAVFrame converted_avframe; - if (mustConvert) { + if (must_convert) { - convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( - srcAVFrame, - sourceSampleFormat, - desiredSampleFormat, - sourceSampleRate, - desiredSampleRate); + converted_avframe = convert_audio_avframe_sample_format_and_sample_rate( + src_avframe, + source_sample_format, + desired_sample_format, + source_sample_rate, + desired_sample_rate); } - const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; + const UniqueAVFrame& avframe = must_convert ? converted_avframe : src_avframe; - AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); + AVSampleFormat format = static_cast<AVSampleFormat>(avframe->format); TORCH_CHECK( - format == desiredSampleFormat, + format == desired_sample_format, "Something went wrong, the frame didn't get converted to the desired format.
", "Desired format = ", - av_get_sample_fmt_name(desiredSampleFormat), + av_get_sample_fmt_name(desired_sample_format), "source format = ", av_get_sample_fmt_name(format)); - auto numSamples = avFrame->nb_samples; // per channel - auto numChannels = getNumChannels(avFrame); + auto num_samples = avframe->nb_samples; // per channel + auto num_channels = get_num_channels(avframe); - frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); + frame_output.data = torch::empty({numChannels, num_samples}, torch::kFloat32); if (numSamples > 0) { - uint8_t* outputChannelData = - static_cast(frameOutput.data.data_ptr()); - auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format); - for (auto channel = 0; channel < numChannels; - ++channel, outputChannelData += numBytesPerChannel) { + uint8_t* output_channel_data = + static_cast(frame_output.data.data_ptr()); + auto num_bytes_per_channel = num_samples * av_get_bytes_per_sample(format); + for (auto channel = 0; channel < num_channels; + ++channel, output_channel_data += num_bytes_per_channel) { memcpy( - outputChannelData, - avFrame->extended_data[channel], - numBytesPerChannel); + output_channel_data, + avframe->extended_data[channel], + num_bytes_per_channel); } } } UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( - const UniqueAVFrame& srcAVFrame, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; + const UniqueAVFrame& src_avframe, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate) { + auto& stream_info = stream_infos_[active_stream_index_]; if (!streamInfo.swrContext) { - createSwrContext( - streamInfo, - sourceSampleFormat, - desiredSampleFormat, - sourceSampleRate, - desiredSampleRate); + create_swr_context( + stream_info, + source_sample_format, + desired_sample_format, + source_sample_rate, + desired_sample_rate); } - UniqueAVFrame convertedAVFrame(av_frame_alloc()); + UniqueAVFrame converted_avframe(avframe_alloc()); TORCH_CHECK( - convertedAVFrame, + converted_avframe, "Could not allocate frame for sample format conversion."); - setChannelLayout(convertedAVFrame, srcAVFrame); - convertedAVFrame->format = static_cast(desiredSampleFormat); - convertedAVFrame->sample_rate = desiredSampleRate; - if (sourceSampleRate != desiredSampleRate) { + set_channel_layout(converted_avframe, src_avframe); + converted_avframe->format = static_cast(desired_sample_format); + converted_avframe->sample_rate = desired_sample_rate; + if (sourceSampleRate != desired_sample_rate) { // Note that this is an upper bound on the number of output samples. // `swr_convert()` will likely not fill convertedAVFrame with that many // samples if sample rate conversion is needed. It will buffer the last few @@ -1485,41 +1508,41 @@ UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate( // We could also use `swr_get_out_samples()` to determine the number of // output samples, but empirically `av_rescale_rnd()` seems to provide a // tighter bound. 
- convertedAVFrame->nb_samples = av_rescale_rnd( - swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) + - srcAVFrame->nb_samples, - desiredSampleRate, - sourceSampleRate, + converted_avframe->nb_samples = av_rescale_rnd( + swr_get_delay(stream_info.swr_context.get(), source_sample_rate) + + src_avframe->nb_samples, + desired_sample_rate, + source_sample_rate, AV_ROUND_UP); } else { - convertedAVFrame->nb_samples = srcAVFrame->nb_samples; + converted_avframe->nb_samples = src_avframe->nb_samples; } - auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); + auto status = av_frame_get_buffer(converted_avframe.get(), 0); TORCH_CHECK( status == AVSUCCESS, "Could not allocate frame buffers for sample format conversion: ", - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); - auto numConvertedSamples = swr_convert( - streamInfo.swrContext.get(), - convertedAVFrame->data, - convertedAVFrame->nb_samples, + auto num_converted_samples = swr_convert( + stream_info.swr_context.get(), + converted_avframe->data, + converted_avframe->nb_samples, static_cast( - const_cast(srcAVFrame->data)), - srcAVFrame->nb_samples); + const_cast(src_avframe->data)), + src_avframe->nb_samples); // numConvertedSamples can be 0 if we're downsampling by a great factor and // the first frame doesn't contain a lot of samples. It should be handled // properly by the caller. TORCH_CHECK( - numConvertedSamples >= 0, + num_converted_samples >= 0, "Error in swr_convert: ", - getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); + get_ffmpeg_error_string_from_error_code(num_converted_samples)); // See comment above about nb_samples - convertedAVFrame->nb_samples = numConvertedSamples; + converted_avframe->nb_samples = num_converted_samples; - return convertedAVFrame; + return converted_avframe; } std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() { @@ -1528,34 +1551,34 @@ std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() { // That's because the last few samples in a given frame require future samples // from the next frame to be properly converted. This function flushes out the // samples that are stored in swresample's buffers.
- auto& streamInfo = streamInfos_[activeStreamIndex_]; + auto& stream_info = stream_infos_[active_stream_index_]; if (!streamInfo.swrContext) { return std::nullopt; } - auto numRemainingSamples = // this is an upper bound - swr_get_out_samples(streamInfo.swrContext.get(), 0); + auto num_remaining_samples = // this is an upper bound + swr_get_out_samples(stream_info.swr_context.get(), 0); if (numRemainingSamples == 0) { return std::nullopt; } - auto numChannels = getNumChannels(streamInfo.codecContext); - torch::Tensor lastSamples = - torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); + auto num_channels = get_num_channels(stream_info.codec_context); + torch::Tensor last_samples = + torch::empty({numChannels, num_remaining_samples}, torch::kFloat32); - std::vector outputBuffers(numChannels); - for (auto i = 0; i < numChannels; i++) { - outputBuffers[i] = static_cast(lastSamples[i].data_ptr()); + std::vector output_buffers(num_channels); + for (auto i = 0; i < num_channels; i++) { + output_buffers[i] = static_cast(last_samples[i].data_ptr()); } - auto actualNumRemainingSamples = swr_convert( - streamInfo.swrContext.get(), - outputBuffers.data(), - numRemainingSamples, + auto actual_num_remaining_samples = swr_convert( + stream_info.swr_context.get(), + output_buffers.data(), + num_remaining_samples, nullptr, 0); - return lastSamples.narrow( + return last_samples.narrow( /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); } @@ -1564,37 +1587,39 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // -------------------------------------------------------------------------- SingleStreamDecoder::FrameBatchOutput::FrameBatchOutput( - int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata) + int64_t num_frames, + const VideoStreamOptions& video_stream_options, + const StreamMetadata& stream_metadata) : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata( - videoStreamOptions, streamMetadata); - int height = frameDims.height; - int width = frameDims.width; - data = allocateEmptyHWCTensor( - height, width, videoStreamOptions.device, numFrames); + auto frame_dims = get_height_and_width_from_options_or_metadata( + video_stream_options, stream_metadata); + int height = frame_dims.height; + int width = frame_dims.width; + data = allocate_empty_h_w_c_tensor( + height, width, video_stream_options.device, num_frames); } -torch::Tensor allocateEmptyHWCTensor( +torch::Tensor allocate_empty_h_w_c_tensor( int height, int width, torch::Device device, - std::optional numFrames) { - auto tensorOptions = torch::TensorOptions() - .dtype(torch::kUInt8) - .layout(torch::kStrided) - .device(device); + std::optional num_frames) { + auto tensor_options = torch::TensorOptions() + .dtype(torch::kUInt8) + .layout(torch::kStrided) + .device(device); TORCH_CHECK(height > 0, "height must be > 0, got: ", height); TORCH_CHECK(width > 0, "width must be > 0, got: ", width); if (numFrames.has_value()) { - auto numFramesValue = numFrames.value(); + auto num_frames_value = num_frames.value(); TORCH_CHECK( - numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue); - return torch::empty({numFramesValue, height, width, 3}, tensorOptions); + num_frames_value >= 0, + "numFrames must be >= 0, got: ", + num_frames_value); + return torch::empty({numFramesValue, height, width, 3}, tensor_options); } else { - return 
torch::empty({height, width, 3}, tensorOptions); + return torch::empty({height, width, 3}, tensor_options); } } @@ -1604,22 +1629,22 @@ torch::Tensor allocateEmptyHWCTensor( // Calling permute() is guaranteed to return a view as per the docs: // https://pytorch.org/docs/stable/generated/torch.permute.html torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( - torch::Tensor& hwcTensor) { + torch::Tensor& hwc_tensor) { if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder == "NHWC") { - return hwcTensor; + return hwc_tensor; } - auto numDimensions = hwcTensor.dim(); - auto shape = hwcTensor.sizes(); + auto num_dimensions = hwc_tensor.dim(); + auto shape = hwc_tensor.sizes(); if (numDimensions == 3) { TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape); - return hwcTensor.permute({2, 0, 1}); + return hwc_tensor.permute({2, 0, 1}); } else if (numDimensions == 4) { TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape); - return hwcTensor.permute({0, 3, 1, 2}); + return hwc_tensor.permute({0, 3, 1, 2}); } else { TORCH_CHECK( - false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions); + false, "Expected tensor with 3 or 4 dimensions, got ", num_dimensions); } } @@ -1629,11 +1654,11 @@ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( bool SingleStreamDecoder::DecodedFrameContext::operator==( const SingleStreamDecoder::DecodedFrameContext& other) { - return decodedWidth == other.decodedWidth && - decodedHeight == other.decodedHeight && - decodedFormat == other.decodedFormat && - expectedWidth == other.expectedWidth && - expectedHeight == other.expectedHeight; + return decoded_width == other.decoded_width && + decoded_height == other.decoded_height && + decoded_format == other.decoded_format && + expected_width == other.expected_width && + expected_height == other.expected_height; } bool SingleStreamDecoder::DecodedFrameContext::operator!=( @@ -1642,42 +1667,42 @@ bool SingleStreamDecoder::DecodedFrameContext::operator!=( } void SingleStreamDecoder::createFilterGraph( - StreamInfo& streamInfo, - int expectedOutputHeight, - int expectedOutputWidth) { - FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext; - filterGraphContext.filterGraph.reset(avfilter_graph_alloc()); + StreamInfo& stream_info, + int expected_output_height, + int expected_output_width) { + FilterGraphContext& filter_graph_context = stream_info.filter_graph_context; + filter_graph_context.filter_graph.reset(avfilter_graph_alloc()); TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr); if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) { - filterGraphContext.filterGraph->nb_threads = - streamInfo.videoStreamOptions.ffmpegThreadCount.value(); + filter_graph_context.filter_graph->nb_threads = + stream_info.video_stream_options.ffmpeg_thread_count.value(); } const AVFilter* buffersrc = avfilter_get_by_name("buffer"); const AVFilter* buffersink = avfilter_get_by_name("buffersink"); - AVCodecContext* codecContext = streamInfo.codecContext.get(); + AVCodecContext* codec_context = stream_info.codec_context.get(); - std::stringstream filterArgs; - filterArgs << "video_size=" << codecContext->width << "x" - << codecContext->height; - filterArgs << ":pix_fmt=" << codecContext->pix_fmt; - filterArgs << ":time_base=" << streamInfo.stream->time_base.num << "/" - << streamInfo.stream->time_base.den; - filterArgs << ":pixel_aspect=" << codecContext->sample_aspect_ratio.num << "/" - << codecContext->sample_aspect_ratio.den; + std::stringstream filter_args; + 
filter_args << "video_size=" << codec_context->width << "x" + << codec_context->height; + filter_args << ":pix_fmt=" << codec_context->pix_fmt; + filter_args << ":time_base=" << stream_info.stream->time_base.num << "/" + << stream_info.stream->time_base.den; + filter_args << ":pixel_aspect=" << codec_context->sample_aspect_ratio.num + << "/" << codec_context->sample_aspect_ratio.den; int status = avfilter_graph_create_filter( &filterGraphContext.sourceContext, buffersrc, "in", - filterArgs.str().c_str(), + filter_args.str().c_str(), nullptr, - filterGraphContext.filterGraph.get()); + filter_graph_context.filter_graph.get()); if (status < 0) { throw std::runtime_error( - std::string("Failed to create filter graph: ") + filterArgs.str() + - ": " + getFFMPEGErrorStringFromErrorCode(status)); + std::string("_failed to create filter graph: ") + filter_args.str() + + ": " + get_ffmpeg_error_string_from_error_code(status)); } status = avfilter_graph_create_filter( @@ -1686,17 +1711,17 @@ void SingleStreamDecoder::createFilterGraph( "out", nullptr, nullptr, - filterGraphContext.filterGraph.get()); + filter_graph_context.filter_graph.get()); if (status < 0) { throw std::runtime_error( "Failed to create filter graph: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE}; status = av_opt_set_int_list( - filterGraphContext.sinkContext, + filter_graph_context.sink_context, "pix_fmts", pix_fmts, AV_PIX_FMT_NONE, @@ -1704,59 +1729,61 @@ void SingleStreamDecoder::createFilterGraph( if (status < 0) { throw std::runtime_error( "Failed to set output pixel formats: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } UniqueAVFilterInOut outputs(avfilter_inout_alloc()); UniqueAVFilterInOut inputs(avfilter_inout_alloc()); outputs->name = av_strdup("in"); - outputs->filter_ctx = filterGraphContext.sourceContext; + outputs->filter_ctx = filter_graph_context.source_context; outputs->pad_idx = 0; outputs->next = nullptr; inputs->name = av_strdup("out"); - inputs->filter_ctx = filterGraphContext.sinkContext; + inputs->filter_ctx = filter_graph_context.sink_context; inputs->pad_idx = 0; inputs->next = nullptr; std::stringstream description; - description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight; + description << "scale=" << expected_output_width << ":" + << expected_output_height; description << ":sws_flags=bilinear"; - AVFilterInOut* outputsTmp = outputs.release(); - AVFilterInOut* inputsTmp = inputs.release(); + AVFilterInOut* outputs_tmp = outputs.release(); + AVFilterInOut* inputs_tmp = inputs.release(); status = avfilter_graph_parse_ptr( - filterGraphContext.filterGraph.get(), + filter_graph_context.filter_graph.get(), description.str().c_str(), &inputsTmp, &outputsTmp, nullptr); - outputs.reset(outputsTmp); - inputs.reset(inputsTmp); + outputs.reset(outputs_tmp); + inputs.reset(inputs_tmp); if (status < 0) { throw std::runtime_error( "Failed to parse filter description: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } - status = avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr); + status = + avfilter_graph_config(filter_graph_context.filter_graph.get(), nullptr); if (status < 0) { throw std::runtime_error( "Failed to configure filter graph: " + - getFFMPEGErrorStringFromErrorCode(status)); + get_ffmpeg_error_string_from_error_code(status)); } } 
void SingleStreamDecoder::createSwsContext( - StreamInfo& streamInfo, - const DecodedFrameContext& frameContext, + StreamInfo& stream_info, + const DecodedFrameContext& frame_context, const enum AVColorSpace colorspace) { - SwsContext* swsContext = sws_getContext( - frameContext.decodedWidth, - frameContext.decodedHeight, - frameContext.decodedFormat, - frameContext.expectedWidth, - frameContext.expectedHeight, + SwsContext* sws_context = sws_getContext( + frame_context.decoded_width, + frame_context.decoded_height, + frame_context.decoded_format, + frame_context.expected_width, + frame_context.expected_height, AV_PIX_FMT_RGB24, SWS_BILINEAR, nullptr, @@ -1764,11 +1791,11 @@ void SingleStreamDecoder::createSwsContext( nullptr); - TORCH_CHECK(swsContext, "sws_getContext() returned nullptr"); + TORCH_CHECK(sws_context, "sws_getContext() returned nullptr"); - int* invTable = nullptr; + int* inv_table = nullptr; int* table = nullptr; - int srcRange, dstRange, brightness, contrast, saturation; - int ret = sws_getColorspaceDetails( - swsContext, + int src_range, dst_range, brightness, contrast, saturation; + int ret = sws_getColorspaceDetails( + sws_context, - &invTable, - &srcRange, + &inv_table, + &src_range, &table, @@ -1778,43 +1805,43 @@ void SingleStreamDecoder::createSwsContext( &saturation); TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1"); - const int* colorspaceTable = sws_getCoefficients(colorspace); - ret = sws_setColorspaceDetails( - swsContext, - colorspaceTable, - srcRange, - colorspaceTable, - dstRange, + const int* colorspace_table = sws_getCoefficients(colorspace); + ret = sws_setColorspaceDetails( + sws_context, + colorspace_table, + src_range, + colorspace_table, + dst_range, brightness, contrast, saturation); TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1"); - streamInfo.swsContext.reset(swsContext); + stream_info.sws_context.reset(sws_context); } void SingleStreamDecoder::createSwrContext( - StreamInfo& streamInfo, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { - auto swrContext = allocateSwrContext( - streamInfo.codecContext, - sourceSampleFormat, - desiredSampleFormat, - sourceSampleRate, - desiredSampleRate); - - auto status = swr_init(swrContext); + StreamInfo& stream_info, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate) { + auto swr_context = allocate_swr_context( + stream_info.codec_context, + source_sample_format, + desired_sample_format, + source_sample_rate, + desired_sample_rate); + + auto status = swr_init(swr_context); TORCH_CHECK( status == AVSUCCESS, "Couldn't initialize SwrContext: ", - getFFMPEGErrorStringFromErrorCode(status), + get_ffmpeg_error_string_from_error_code(status), ". If the error says 'Invalid argument', it's likely that you are using " "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " "valid scenarios.
Try to upgrade FFmpeg?"); - streamInfo.swrContext.reset(swrContext); + stream_info.swr_context.reset(swr_context); } // -------------------------------------------------------------------------- @@ -1822,101 +1849,104 @@ void SingleStreamDecoder::createSwrContext( // -------------------------------------------------------------------------- int SingleStreamDecoder::getKeyFrameIndexForPts(int64_t pts) const { - const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); + const StreamInfo& stream_info = stream_infos_.at(active_stream_index_); if (streamInfo.keyFrames.empty()) { return av_index_search_timestamp( - streamInfo.stream, pts, AVSEEK_FLAG_BACKWARD); + stream_info.stream, pts, AVSEEK_FLAG_BACKWARD); } else { - return getKeyFrameIndexForPtsUsingScannedIndex(streamInfo.keyFrames, pts); + return get_key_frame_index_for_pts_using_scanned_index( + stream_info.key_frames, pts); } } int SingleStreamDecoder::getKeyFrameIndexForPtsUsingScannedIndex( - const std::vector& keyFrames, + const std::vector<_single_stream_decoder::_frame_info>& key_frames, int64_t pts) const { - auto upperBound = std::upper_bound( - keyFrames.begin(), - keyFrames.end(), + auto upper_bound = std::upper_bound( + key_frames.begin(), + key_frames.end(), pts, - [](int64_t pts, const SingleStreamDecoder::FrameInfo& frameInfo) { - return pts < frameInfo.pts; + [](int64_t pts, const SingleStreamDecoder::FrameInfo& frame_info) { + return pts < frame_info.pts; }); - if (upperBound == keyFrames.begin()) { + if (upperBound == key_frames.begin()) { return -1; } - return upperBound - 1 - keyFrames.begin(); + return upper_bound - 1 - key_frames.begin(); } int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; + auto& stream_info = stream_infos_[active_stream_index_]; switch (seekMode_) { case SeekMode::exact: { auto frame = std::lower_bound( - streamInfo.allFrames.begin(), - streamInfo.allFrames.end(), + stream_info.all_frames.begin(), + stream_info.all_frames.end(), seconds, [&streamInfo](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start; + return pts_to_seconds(info.next_pts, stream_info.time_base) <= + start; }); - return frame - streamInfo.allFrames.begin(); + return frame - stream_info.all_frames.begin(); } case SeekMode::approximate: { - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; TORCH_CHECK( - streamMetadata.averageFps.has_value(), + stream_metadata.average_fps.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); - return std::floor(seconds * streamMetadata.averageFps.value()); + return std::floor(seconds * stream_metadata.average_fps.value()); } default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; + auto& stream_info = stream_infos_[active_stream_index_]; switch (seekMode_) { case SeekMode::exact: { auto frame = std::upper_bound( - streamInfo.allFrames.begin(), - streamInfo.allFrames.end(), + stream_info.all_frames.begin(), + stream_info.all_frames.end(), seconds, [&streamInfo](double stop, const FrameInfo& info) { - return stop <= ptsToSeconds(info.pts, streamInfo.timeBase); + return stop <= pts_to_seconds(info.pts, 
stream_info.time_base); }); - return frame - streamInfo.allFrames.begin(); + return frame - stream_info.all_frames.begin(); } case SeekMode::approximate: { - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; TORCH_CHECK( - streamMetadata.averageFps.has_value(), + stream_metadata.average_fps.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); - return std::ceil(seconds * streamMetadata.averageFps.value()); + return std::ceil(seconds * stream_metadata.average_fps.value()); } default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } -int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { - auto& streamInfo = streamInfos_[activeStreamIndex_]; +int64_t SingleStreamDecoder::getPts(int64_t frame_index) { + auto& stream_info = stream_infos_[active_stream_index_]; switch (seekMode_) { case SeekMode::exact: - return streamInfo.allFrames[frameIndex].pts; + return stream_info.all_frames[frame_index].pts; case SeekMode::approximate: { - auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; + auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; TORCH_CHECK( - streamMetadata.averageFps.has_value(), + stream_metadata.average_fps.has_value(), "Cannot use approximate mode since we couldn't find the average fps from the metadata."); - return secondsToClosestPts( - frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase); + return seconds_to_closest_pts( + frame_index / stream_metadata.average_fps.value(), + stream_info.time_base); } default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } @@ -1925,46 +1955,46 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { // -------------------------------------------------------------------------- int64_t SingleStreamDecoder::getNumFrames( - const StreamMetadata& streamMetadata) { + const StreamMetadata& stream_metadata) { switch (seekMode_) { case SeekMode::exact: - return streamMetadata.numFramesFromScan.value(); + return stream_metadata.num_frames_from_scan.value(); case SeekMode::approximate: { TORCH_CHECK( - streamMetadata.numFrames.has_value(), + stream_metadata.num_frames.has_value(), "Cannot use approximate mode since we couldn't find the number of frames from the metadata."); - return streamMetadata.numFrames.value(); + return stream_metadata.num_frames.value(); } default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } double SingleStreamDecoder::getMinSeconds( - const StreamMetadata& streamMetadata) { + const StreamMetadata& stream_metadata) { switch (seekMode_) { case SeekMode::exact: - return streamMetadata.minPtsSecondsFromScan.value(); + return stream_metadata.min_pts_seconds_from_scan.value(); case SeekMode::approximate: return 0; default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } double SingleStreamDecoder::getMaxSeconds( - const StreamMetadata& streamMetadata) { + const StreamMetadata& stream_metadata) { switch (seekMode_) { case SeekMode::exact: - return streamMetadata.maxPtsSecondsFromScan.value(); + return stream_metadata.max_pts_seconds_from_scan.value(); case SeekMode::approximate: { TORCH_CHECK( - streamMetadata.durationSeconds.has_value(), + 
stream_metadata.duration_seconds.has_value(), "Cannot use approximate mode since we couldn't find the duration from the metadata."); - return streamMetadata.durationSeconds.value(); + return stream_metadata.duration_seconds.value(); } default: - throw std::runtime_error("Unknown SeekMode"); + throw std::runtime_error("_unknown SeekMode"); } } @@ -1973,24 +2003,26 @@ double SingleStreamDecoder::getMaxSeconds( // -------------------------------------------------------------------------- void SingleStreamDecoder::validateActiveStream( - std::optional avMediaType) { - auto errorMsg = - "Provided stream index=" + std::to_string(activeStreamIndex_) + + std::optional<_avmedia_type> av_media_type) { + auto error_msg = + "Provided stream index=" + std::to_string(active_stream_index_) + " was not previously added."; - TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, errorMsg); - TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, errorMsg); + TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, error_msg); + TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, error_msg); - int allStreamMetadataSize = - static_cast(containerMetadata_.allStreamMetadata.size()); + int all_stream_metadata_size = + static_cast(container_metadata_.all_stream_metadata.size()); TORCH_CHECK( - activeStreamIndex_ >= 0 && activeStreamIndex_ < allStreamMetadataSize, - "Invalid stream index=" + std::to_string(activeStreamIndex_) + + active_stream_index_ >= 0 && + active_stream_index_ < all_stream_metadata_size, + "Invalid stream index=" + std::to_string(active_stream_index_) + "; valid indices are in the range [0, " + - std::to_string(allStreamMetadataSize) + ")."); + std::to_string(all_stream_metadata_size) + ")."); if (avMediaType.has_value()) { TORCH_CHECK( - streamInfos_[activeStreamIndex_].avMediaType == avMediaType.value(), + stream_infos_[active_stream_index_].av_media_type == + av_media_type.value(), "The method you called isn't supported. 
", "If you're seeing this error, you are probably trying to call an ", "unsupported method on an audio stream."); @@ -2005,14 +2037,14 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { } void SingleStreamDecoder::validateFrameIndex( - const StreamMetadata& streamMetadata, - int64_t frameIndex) { - int64_t numFrames = getNumFrames(streamMetadata); + const StreamMetadata& stream_metadata, + int64_t frame_index) { + int64_t num_frames = get_num_frames(stream_metadata); TORCH_CHECK( - frameIndex >= 0 && frameIndex < numFrames, - "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - " numFrames=" + std::to_string(numFrames)); + frame_index >= 0 && frame_index < num_frames, + "Invalid frame index=" + std::to_string(frame_index) + + " for stream_index=" + std::to_string(stream_metadata.stream_index) + + " num_frames=" + std::to_string(num_frames)); } // -------------------------------------------------------------------------- @@ -2020,71 +2052,73 @@ void SingleStreamDecoder::validateFrameIndex( // -------------------------------------------------------------------------- SingleStreamDecoder::DecodeStats SingleStreamDecoder::getDecodeStats() const { - return decodeStats_; + return decode_stats_; } std::ostream& operator<<( std::ostream& os, const SingleStreamDecoder::DecodeStats& stats) { os << "DecodeStats{" - << "numFramesReceivedByDecoder=" << stats.numFramesReceivedByDecoder - << ", numPacketsRead=" << stats.numPacketsRead - << ", numPacketsSentToDecoder=" << stats.numPacketsSentToDecoder - << ", numSeeksAttempted=" << stats.numSeeksAttempted - << ", numSeeksSkipped=" << stats.numSeeksSkipped - << ", numFlushes=" << stats.numFlushes << "}"; + << "numFramesReceivedByDecoder=" << stats.num_frames_received_by_decoder + << ", num_packets_read=" << stats.num_packets_read + << ", num_packets_sent_to_decoder=" << stats.num_packets_sent_to_decoder + << ", num_seeks_attempted=" << stats.num_seeks_attempted + << ", num_seeks_skipped=" << stats.num_seeks_skipped + << ", num_flushes=" << stats.num_flushes << "}"; return os; } void SingleStreamDecoder::resetDecodeStats() { - decodeStats_ = DecodeStats{}; + decode_stats_ = DecodeStats{}; } -double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) { - validateActiveStream(AVMEDIA_TYPE_VIDEO); - validateScannedAllStreams("getPtsSecondsForFrame"); +double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frame_index) { + validate_active_stream(_avm_e_d_i_a__t_y_p_e__v_i_d_e_o); + validate_scanned_all_streams("get_pts_seconds_for_frame"); - const auto& streamInfo = streamInfos_[activeStreamIndex_]; - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; - validateFrameIndex(streamMetadata, frameIndex); + const auto& stream_info = stream_infos_[active_stream_index_]; + const auto& stream_metadata = + container_metadata_.all_stream_metadata[active_stream_index_]; + validate_frame_index(stream_metadata, frame_index); - return ptsToSeconds( - streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); + return pts_to_seconds( + stream_info.all_frames[frame_index].pts, stream_info.time_base); } // -------------------------------------------------------------------------- // FrameDims APIs // -------------------------------------------------------------------------- -FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { - return FrameDims(resizedAVFrame.height, resizedAVFrame.width); 
+FrameDims get_height_and_width_from_resized_avframe( + const AVFrame& resized_avframe) { + return FrameDims(resizedAVFrame.height, resized_avframe.width); } -FrameDims getHeightAndWidthFromOptionsOrMetadata( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - const SingleStreamDecoder::StreamMetadata& streamMetadata) { +FrameDims get_height_and_width_from_options_or_metadata( + const SingleStreamDecoder::VideoStreamOptions& video_stream_options, + const SingleStreamDecoder::StreamMetadata& stream_metadata) { return FrameDims( - videoStreamOptions.height.value_or(*streamMetadata.height), - videoStreamOptions.width.value_or(*streamMetadata.width)); + video_stream_options.height.value_or(*stream_metadata.height), + video_stream_options.width.value_or(*stream_metadata.width)); } -FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame) { +FrameDims get_height_and_width_from_options_or_avframe( + const SingleStreamDecoder::VideoStreamOptions& video_stream_options, + const UniqueAVFrame& avframe) { return FrameDims( - videoStreamOptions.height.value_or(avFrame->height), - videoStreamOptions.width.value_or(avFrame->width)); + video_stream_options.height.value_or(avframe->height), + video_stream_options.width.value_or(avframe->width)); } -SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) { +SingleStreamDecoder::SeekMode seek_mode_from_string( + std::string_view seek_mode) { if (seekMode == "exact") { return SingleStreamDecoder::SeekMode::exact; } else if (seekMode == "approximate") { return SingleStreamDecoder::SeekMode::approximate; } else { - TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode)); + TORCH_CHECK(false, "Invalid seek mode: " + std::string(seek_mode)); } } diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index df333d03..92152701 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -32,16 +32,16 @@ class SingleStreamDecoder { // Creates a SingleStreamDecoder from the video at videoFilePath. explicit SingleStreamDecoder( - const std::string& videoFilePath, - SeekMode seekMode = SeekMode::exact); + const std::string& video_file_path, + SeekMode seek_mode = SeekMode::exact); // Creates a SingleStreamDecoder using the provided AVIOContext inside the // AVIOContextHolder. The AVIOContextHolder is the base class, and the // derived class will have specialized how the custom read, seek and writes // work. explicit SingleStreamDecoder( - std::unique_ptr context, - SeekMode seekMode = SeekMode::exact); + std::unique_ptr context, + SeekMode seek_mode = SeekMode::exact); // -------------------------------------------------------------------------- // VIDEO METADATA QUERY API @@ -50,64 +50,64 @@ class SingleStreamDecoder { // Updates the metadata of the video to accurate values obtained by scanning // the contents of the video file. Also updates each StreamInfo's index, i.e. // the allFrames and keyFrames vectors. - void scanFileAndUpdateMetadataAndIndex(); + void scan_file_and_update_metadata_and_index(); struct StreamMetadata { // Common (video and audio) fields derived from the AVStream. 
- int streamIndex; + int stream_index; // See this link for what various values are available: // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48 - AVMediaType mediaType; - std::optional codecId; - std::optional codecName; - std::optional durationSeconds; - std::optional beginStreamFromHeader; - std::optional numFrames; - std::optional numKeyFrames; - std::optional averageFps; - std::optional bitRate; + AVMediaType media_type; + std::optional<_avcodec_i_d> codec_id; + std::optional codec_name; + std::optional duration_seconds; + std::optional begin_stream_from_header; + std::optional num_frames; + std::optional num_key_frames; + std::optional average_fps; + std::optional bit_rate; // More accurate duration, obtained by scanning the file. // These presentation timestamps are in time base. - std::optional minPtsFromScan; - std::optional maxPtsFromScan; + std::optional min_pts_from_scan; + std::optional max_pts_from_scan; // These presentation timestamps are in seconds. - std::optional minPtsSecondsFromScan; - std::optional maxPtsSecondsFromScan; + std::optional min_pts_seconds_from_scan; + std::optional max_pts_seconds_from_scan; // This can be useful for index-based seeking. - std::optional numFramesFromScan; + std::optional num_frames_from_scan; // Video-only fields derived from the AVCodecContext. std::optional width; std::optional height; // Audio-only fields - std::optional sampleRate; - std::optional numChannels; - std::optional sampleFormat; + std::optional sample_rate; + std::optional num_channels; + std::optional sample_format; }; struct ContainerMetadata { - std::vector allStreamMetadata; - int numAudioStreams = 0; - int numVideoStreams = 0; + std::vector<_stream_metadata> all_stream_metadata; + int num_audio_streams = 0; + int num_video_streams = 0; // Note that this is the container-level duration, which is usually the max // of all stream durations available in the container. - std::optional durationSeconds; + std::optional duration_seconds; // Total BitRate level information at the container level in bit/s - std::optional bitRate; + std::optional bit_rate; // If set, this is the index to the default audio stream. - std::optional bestAudioStreamIndex; + std::optional best_audio_stream_index; // If set, this is the index to the default video stream. - std::optional bestVideoStreamIndex; + std::optional best_video_stream_index; }; // Returns the metadata for the container. - ContainerMetadata getContainerMetadata() const; + ContainerMetadata get_container_metadata() const; // Returns the key frame indices as a tensor. The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. - torch::Tensor getKeyFrameIndices(); + torch::Tensor get_key_frame_indices(); // -------------------------------------------------------------------------- // ADDING STREAMS API @@ -128,15 +128,15 @@ class SingleStreamDecoder { // 0 means FFMPEG will choose the number of threads automatically to fully // utilize all cores. If not set, it will be the default FFMPEG behavior for // the given codec. - std::optional ffmpegThreadCount; + std::optional ffmpeg_thread_count; // Currently the dimension order can be either NHWC or NCHW. // H=height, W=width, C=channel. - std::string dimensionOrder = "NCHW"; + std::string dimension_order = "NCHW"; // The output height and width of the frame. If not specified, the output // is the same as the original video. 
std::optional width; std::optional height; - std::optional colorConversionLibrary; + std::optional<_color_conversion_library> color_conversion_library; // By default we use CPU for decoding for both C++ and python users. torch::Device device = torch::kCPU; }; @@ -144,15 +144,15 @@ class SingleStreamDecoder { struct AudioStreamOptions { AudioStreamOptions() {} - std::optional sampleRate; + std::optional sample_rate; }; - void addVideoStream( - int streamIndex, - const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); - void addAudioStream( - int streamIndex, - const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); + void add_video_stream( + int stream_index, + const VideoStreamOptions& video_stream_options = VideoStreamOptions()); + void add_audio_stream( + int stream_index, + const AudioStreamOptions& audio_stream_options = AudioStreamOptions()); // -------------------------------------------------------------------------- // DECODING AND SEEKING APIs @@ -170,45 +170,47 @@ class SingleStreamDecoder { // - 3D (C, H, W) or (H, W, C) for videos // - 2D (numChannels, numSamples) for audio torch::Tensor data; - double ptsSeconds; - double durationSeconds; + double pts_seconds; + double duration_seconds; }; struct FrameBatchOutput { torch::Tensor data; // 4D: of shape NCHW or NHWC. - torch::Tensor ptsSeconds; // 1D of shape (N,) - torch::Tensor durationSeconds; // 1D of shape (N,) + torch::Tensor pts_seconds; // 1D of shape (N,) + torch::Tensor duration_seconds; // 1D of shape (N,) explicit FrameBatchOutput( - int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); + int64_t num_frames, + const VideoStreamOptions& video_stream_options, + const StreamMetadata& stream_metadata); }; struct AudioFramesOutput { - torch::Tensor data; // shape is (numChannels, numSamples) - double ptsSeconds; + torch::Tensor data; // shape is (numChannels, num_samples) + double pts_seconds; }; // Places the cursor at the first frame on or after the position in seconds. // Calling getNextFrame() will return the first frame at // or after this position. - void setCursorPtsInSeconds(double seconds); + void set_cursor_pts_in_seconds(double seconds); // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. - FrameOutput getNextFrame(); + FrameOutput get_next_frame(); - FrameOutput getFrameAtIndex(int64_t frameIndex); + FrameOutput get_frame_at_index(int64_t frame_index); // Returns frames at the given indices for a given stream as a single stacked // Tensor. - FrameBatchOutput getFramesAtIndices(const std::vector& frameIndices); + FrameBatchOutput get_frames_at_indices( + const std::vector& frame_indices); // Returns frames within a given range. The range is defined by [start, stop). // The values retrieved from the range are: [start, start+step, // start+(2*step), start+(3*step), ..., stop). The default for step is 1. - FrameBatchOutput getFramesInRange(int64_t start, int64_t stop, int64_t step); + FrameBatchOutput + get_frames_in_range(int64_t start, int64_t stop, int64_t step); // Decodes the first frame in any added stream that is visible at a given // timestamp. Frames in the video have a presentation timestamp and a @@ -216,9 +218,9 @@ class SingleStreamDecoder { // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0). // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. 
- FrameOutput getFramePlayedAt(double seconds); + FrameOutput get_frame_played_at(double seconds); - FrameBatchOutput getFramesPlayedAt(const std::vector& timestamps); + FrameBatchOutput get_frames_played_at(const std::vector& timestamps); // Returns frames within a given pts range. The range is defined by // [startSeconds, stopSeconds) with respect to the pts values for frames. The @@ -236,14 +238,14 @@ class SingleStreamDecoder { // // Valid values for startSeconds and stopSeconds are: // - // [minPtsSecondsFromScan, maxPtsSecondsFromScan) - FrameBatchOutput getFramesPlayedInRange( - double startSeconds, - double stopSeconds); + // [minPtsSecondsFromScan, maxPtsSecondsFromScan) + FrameBatchOutput get_frames_played_in_range( + double start_seconds, + double stop_seconds); - AudioFramesOutput getFramesPlayedInRangeAudio( - double startSeconds, - std::optional stopSecondsOptional = std::nullopt); + AudioFramesOutput get_frames_played_in_range_audio( + double start_seconds, + std::optional stop_seconds_optional = std::nullopt); class EndOfFileException : public std::runtime_error { public: @@ -259,27 +261,27 @@ class SingleStreamDecoder { // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we // can move it back to private. - FrameOutput getFrameAtIndexInternal( - int64_t frameIndex, - std::optional preAllocatedOutputTensor = std::nullopt); + FrameOutput get_frame_at_index_internal( + int64_t frame_index, + std::optional pre_allocated_output_tensor = std::nullopt); // Exposed for _test_frame_pts_equality, which is used to test non-regression // of pts resolution (64 to 32 bit floats) - double getPtsSecondsForFrame(int64_t frameIndex); + double get_pts_seconds_for_frame(int64_t frame_index); // Exposed for performance testing. struct DecodeStats { - int64_t numSeeksAttempted = 0; - int64_t numSeeksDone = 0; - int64_t numSeeksSkipped = 0; - int64_t numPacketsRead = 0; - int64_t numPacketsSentToDecoder = 0; - int64_t numFramesReceivedByDecoder = 0; - int64_t numFlushes = 0; + int64_t num_seeks_attempted = 0; + int64_t num_seeks_done = 0; + int64_t num_seeks_skipped = 0; + int64_t num_packets_read = 0; + int64_t num_packets_sent_to_decoder = 0; + int64_t num_frames_received_by_decoder = 0; + int64_t num_flushes = 0; }; - DecodeStats getDecodeStats() const; - void resetDecodeStats(); + DecodeStats get_decode_stats() const; + void reset_decode_stats(); private: // -------------------------------------------------------------------------- @@ -296,171 +298,171 @@ class SingleStreamDecoder { // typically done during pts -> index conversions). // TODO: This field is unset (left to the default) for entries in the // keyFrames vec! - int64_t nextPts = INT64_MAX; + int64_t next_pts = INT64_MAX; // Note that frameIndex is ALWAYS the index into all of the frames in that // stream, even when the FrameInfo is part of the key frame index. Given a // FrameInfo for a key frame, the frameIndex allows us to know which frame // that is in the stream. - int64_t frameIndex = 0; + int64_t frame_index = 0; // Indicates whether a frame is a key frame. It may appear redundant as it's // only true for FrameInfos in the keyFrames index, but it is needed to // correctly map frames between allFrames and keyFrames during the scan. 
- bool isKeyFrame = false; + bool is_key_frame = false; }; struct FilterGraphContext { - UniqueAVFilterGraph filterGraph; - AVFilterContext* sourceContext = nullptr; - AVFilterContext* sinkContext = nullptr; + UniqueAVFilterGraph filter_graph; + AVFilterContext* source_context = nullptr; + AVFilterContext* sink_context = nullptr; }; struct DecodedFrameContext { - int decodedWidth; - int decodedHeight; - AVPixelFormat decodedFormat; - int expectedWidth; - int expectedHeight; + int decoded_width; + int decoded_height; + AVPixelFormat decoded_format; + int expected_width; + int expected_height; bool operator==(const DecodedFrameContext&); bool operator!=(const DecodedFrameContext&); }; struct StreamInfo { - int streamIndex = -1; + int stream_index = -1; AVStream* stream = nullptr; - AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN; + AVMediaType av_media_type = AVMEDIA_TYPE_UNKNOWN; - AVRational timeBase = {}; - UniqueAVCodecContext codecContext; + AVRational time_base = {}; + UniqueAVCodecContext codec_context; // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was // called. - std::vector keyFrames; - std::vector allFrames; + std::vector<_frame_info> key_frames; + std::vector<_frame_info> all_frames; // TODO since the decoder is single-stream, these should be decoder fields, // not streamInfo fields. And they should be defined right next to // `cursor_`, with joint documentation. - int64_t lastDecodedAvFramePts = 0; - int64_t lastDecodedAvFrameDuration = 0; - VideoStreamOptions videoStreamOptions; - AudioStreamOptions audioStreamOptions; + int64_t last_decoded_avframe_pts = 0; + int64_t last_decoded_avframe_duration = 0; + VideoStreamOptions video_stream_options; + AudioStreamOptions audio_stream_options; // color-conversion fields. Only one of FilterGraphContext and // UniqueSwsContext should be non-null. - FilterGraphContext filterGraphContext; - ColorConversionLibrary colorConversionLibrary = FILTERGRAPH; - UniqueSwsContext swsContext; - UniqueSwrContext swrContext; + FilterGraphContext filter_graph_context; + ColorConversionLibrary color_conversion_library = FILTERGRAPH; + UniqueSwsContext sws_context; + UniqueSwrContext swr_context; // Used to know whether a new FilterGraphContext or UniqueSwsContext should // be created before decoding a new frame. 
- DecodedFrameContext prevFrameContext; + DecodedFrameContext prev_frame_context; }; // -------------------------------------------------------------------------- // INITIALIZERS // -------------------------------------------------------------------------- - void initializeDecoder(); - void setFFmpegLogLevel(); + void initialize_decoder(); + void set_ffmpeg_log_level(); // -------------------------------------------------------------------------- // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- - void setCursor(int64_t pts); - void setCursor(double) = delete; // prevent calls with doubles and floats - bool canWeAvoidSeeking() const; + void set_cursor(int64_t pts); + void set_cursor(double) = delete; // prevent calls with doubles and floats + bool can_we_avoid_seeking() const; - void maybeSeekToBeforeDesiredPts(); + void maybe_seek_to_before_desired_pts(); - UniqueAVFrame decodeAVFrame( - std::function filterFunction); + UniqueAVFrame decode_avframe( + std::function filter_function); - FrameOutput getNextFrameInternal( - std::optional preAllocatedOutputTensor = std::nullopt); + FrameOutput get_next_frame_internal( + std::optional pre_allocated_output_tensor = std::nullopt); - torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor); + torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwc_tensor); - FrameOutput convertAVFrameToFrameOutput( - UniqueAVFrame& avFrame, - std::optional preAllocatedOutputTensor = std::nullopt); + FrameOutput convert_avframe_to_frame_output( + UniqueAVFrame& avframe, + std::optional pre_allocated_output_tensor = std::nullopt); - void convertAVFrameToFrameOutputOnCPU( - UniqueAVFrame& avFrame, - FrameOutput& frameOutput, - std::optional preAllocatedOutputTensor = std::nullopt); + void convert_avframe_to_frame_output_on_c_p_u( + UniqueAVFrame& avframe, + FrameOutput& frame_output, + std::optional pre_allocated_output_tensor = std::nullopt); - void convertAudioAVFrameToFrameOutputOnCPU( - UniqueAVFrame& srcAVFrame, - FrameOutput& frameOutput); + void convert_audio_avframe_to_frame_output_on_c_p_u( + UniqueAVFrame& src_avframe, + FrameOutput& frame_output); - torch::Tensor convertAVFrameToTensorUsingFilterGraph( - const UniqueAVFrame& avFrame); + torch::Tensor convert_avframe_to_tensor_using_filter_graph( + const UniqueAVFrame& avframe); - int convertAVFrameToTensorUsingSwsScale( - const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor); + int convert_avframe_to_tensor_using_sws_scale( + const UniqueAVFrame& avframe, + torch::Tensor& output_tensor); - UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( - const UniqueAVFrame& srcAVFrame, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); + UniqueAVFrame convert_audio_avframe_sample_format_and_sample_rate( + const UniqueAVFrame& src_avframe, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate); - std::optional maybeFlushSwrBuffers(); + std::optional maybe_flush_swr_buffers(); // -------------------------------------------------------------------------- // COLOR CONVERSION LIBRARIES HANDLERS CREATION // -------------------------------------------------------------------------- - void createFilterGraph( - StreamInfo& streamInfo, - int expectedOutputHeight, - int expectedOutputWidth); + void create_filter_graph( + StreamInfo& stream_info, + int expected_output_height, + int 
expected_output_width); - void createSwsContext( - StreamInfo& streamInfo, - const DecodedFrameContext& frameContext, + void create_sws_context( + StreamInfo& stream_info, + const DecodedFrameContext& frame_context, const enum AVColorSpace colorspace); - void createSwrContext( - StreamInfo& streamInfo, - AVSampleFormat sourceSampleFormat, - AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); + void create_swr_context( + StreamInfo& stream_info, + AVSampleFormat source_sample_format, + AVSampleFormat desired_sample_format, + int source_sample_rate, + int desired_sample_rate); // -------------------------------------------------------------------------- // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- - int getKeyFrameIndexForPts(int64_t pts) const; + int get_key_frame_index_for_pts(int64_t pts) const; // Returns the key frame index of the presentation timestamp using our index. // We build this index by scanning the file in // scanFileAndUpdateMetadataAndIndex - int getKeyFrameIndexForPtsUsingScannedIndex( - const std::vector& keyFrames, + int get_key_frame_index_for_pts_using_scanned_index( + const std::vector<_single_stream_decoder::_frame_info>& key_frames, int64_t pts) const; - int64_t secondsToIndexLowerBound(double seconds); + int64_t seconds_to_index_lower_bound(double seconds); - int64_t secondsToIndexUpperBound(double seconds); + int64_t seconds_to_index_upper_bound(double seconds); - int64_t getPts(int64_t frameIndex); + int64_t get_pts(int64_t frame_index); // -------------------------------------------------------------------------- // STREAM AND METADATA APIS // -------------------------------------------------------------------------- - void addStream( - int streamIndex, - AVMediaType mediaType, + void add_stream( + int stream_index, + AVMediaType media_type, const torch::Device& device = torch::kCPU, - std::optional ffmpegThreadCount = std::nullopt); + std::optional ffmpeg_thread_count = std::nullopt); // Returns the "best" stream index for a given media type. The "best" is // determined by various heuristics in FFMPEG. @@ -469,44 +471,44 @@ class SingleStreamDecoder { // for more details about the heuristics. // Returns the key frame index of the presentation timestamp using FFMPEG's // index. Note that this index may be truncated for some files. 
- int getBestStreamIndex(AVMediaType mediaType); + int get_best_stream_index(_avmedia_type media_type); - int64_t getNumFrames(const StreamMetadata& streamMetadata); - double getMinSeconds(const StreamMetadata& streamMetadata); - double getMaxSeconds(const StreamMetadata& streamMetadata); + int64_t get_num_frames(const StreamMetadata& stream_metadata); + double get_min_seconds(const StreamMetadata& stream_metadata); + double get_max_seconds(const StreamMetadata& stream_metadata); // -------------------------------------------------------------------------- // VALIDATION UTILS // -------------------------------------------------------------------------- - void validateActiveStream( - std::optional avMediaType = std::nullopt); - void validateScannedAllStreams(const std::string& msg); - void validateFrameIndex( - const StreamMetadata& streamMetadata, - int64_t frameIndex); + void validate_active_stream( + std::optional<_avmedia_type> av_media_type = std::nullopt); + void validate_scanned_all_streams(const std::string& msg); + void validate_frame_index( + const StreamMetadata& stream_metadata, + int64_t frame_index); // -------------------------------------------------------------------------- // ATTRIBUTES // -------------------------------------------------------------------------- - SeekMode seekMode_; - ContainerMetadata containerMetadata_; - UniqueAVFormatContext formatContext_; - std::map streamInfos_; + SeekMode seek_mode_; + ContainerMetadata container_metadata_; + UniqueAVFormatContext format_context_; + std::map stream_infos_; const int NO_ACTIVE_STREAM = -2; - int activeStreamIndex_ = NO_ACTIVE_STREAM; + int active_stream_index_ = NO_ACTIVE_STREAM; - bool cursorWasJustSet_ = false; + bool cursor_was_just_set_ = false; // The desired position of the cursor in the stream. We send frames >= this // pts to the user when they request a frame. int64_t cursor_ = INT64_MIN; // Stores various internal decoding stats. - DecodeStats decodeStats_; + DecodeStats decode_stats_; // Stores the AVIOContext for the input buffer. - std::unique_ptr avioContextHolder_; + std::unique_ptr avio_context_holder_; // Whether or not we have already scanned all streams to update the metadata. - bool scannedAllStreams_ = false; + bool scanned_all_streams_ = false; // Tracks that we've already been initialized. bool initialized_ = false; }; @@ -527,19 +529,19 @@ class SingleStreamDecoder { // *decreasing order of accuracy*, we use the following sources for determining // height and width: // - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the -// AVframe, *post*-resizing. This is only used for single-frame decoding APIs, -// on CPU, with filtergraph. +// AVframe, *post*-resizing. This is only used for single-frame decoding APIs, +// on CPU, with filtergraph. // - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from -// the user-specified options if they exist, or the height and width of the -// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within -// our code or within FFmpeg code, this should be exactly the same as -// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame -// decoding APIs, on CPU with swscale, and on GPU. +// the user-specified options if they exist, or the height and width of the +// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within +// our code or within FFmpeg code, this should be exactly the same as +// getHeightAndWidthFromResizedAVFrame(). 
+//   decoding APIs, on CPU with swscale, and on GPU.
 // - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from
-//   the user-specified options if they exist, or the height and width form the
-//   stream metadata, which itself got its value from the CodecContext, when the
-//   stream was added. This is used by batch decoding APIs, for both GPU and
-//   CPU.
+//   the user-specified options if they exist, or the height and width from the
+//   stream metadata, which itself got its value from the CodecContext, when the
+//   stream was added. This is used by batch decoding APIs, for both GPU and
+//   CPU.
 //
 // The source of truth for height and width really is the (resized) AVFrame: it
 // comes from the decoded ouptut of FFmpeg. The info from the metadata (i.e.
@@ -565,27 +567,28 @@ struct FrameDims {
 // There's nothing preventing you from calling this on a non-resized frame, but
 // please don't.
-FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
+FrameDims get_height_and_width_from_resized_avframe(
+    const AVFrame& resized_avframe);
-FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
-    const SingleStreamDecoder::StreamMetadata& streamMetadata);
+FrameDims get_height_and_width_from_options_or_metadata(
+    const SingleStreamDecoder::VideoStreamOptions& video_stream_options,
+    const SingleStreamDecoder::StreamMetadata& stream_metadata);
-FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const SingleStreamDecoder::VideoStreamOptions& videoStreamOptions,
-    const UniqueAVFrame& avFrame);
+FrameDims get_height_and_width_from_options_or_avframe(
+    const SingleStreamDecoder::VideoStreamOptions& video_stream_options,
+    const UniqueAVFrame& avframe);
-torch::Tensor allocateEmptyHWCTensor(
+torch::Tensor allocate_empty_h_w_c_tensor(
     int height,
     int width,
     torch::Device device,
-    std::optional<int64_t> numFrames = std::nullopt);
+    std::optional<int64_t> num_frames = std::nullopt);
 // Prints the SingleStreamDecoder::DecodeStats to the ostream.
 std::ostream& operator<<(
     std::ostream& os,
     const SingleStreamDecoder::DecodeStats& stats);
-SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode);
+SingleStreamDecoder::SeekMode seek_mode_from_string(std::string_view seek_mode);
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index 45324908..a2fad8d4 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -22,22 +22,22 @@ namespace facebook::torchcodec {
 // `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being
 // mutated in place. We need it to make sure that torch.compile does not reorder
 // calls to these functions. For more detail, see:
-// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
+// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
 TORCH_LIBRARY(torchcodec_ns, m) {
   m.impl_abstract_pystub(
       "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
   m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
   m.def(
       "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
   m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
   m.def(
       "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()");
   m.def(
       "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
   m.def(
       "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()");
   m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
   m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
   m.def(
@@ -45,55 +45,56 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
   m.def(
       "get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
   m.def(
       "get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
   m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
   m.def("get_json_metadata(Tensor(a!) decoder) -> str");
   m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
   m.def(
       "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
   m.def("_get_json_ffmpeg_library_versions() -> str");
   m.def(
       "_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
   m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
decoder) -> ()"); } namespace { -at::Tensor wrapDecoderPointerToTensor( - std::unique_ptr uniqueDecoder) { - SingleStreamDecoder* decoder = uniqueDecoder.release(); +at::Tensor wrap_decoder_pointer_to_tensor( + std::unique_ptr<_single_stream_decoder> unique_decoder) { + SingleStreamDecoder* decoder = unique_decoder.release(); auto deleter = [decoder](void*) { delete decoder; }; at::Tensor tensor = at::from_blob( decoder, {sizeof(SingleStreamDecoder*)}, deleter, {at::kLong}); - auto videoDecoder = - static_cast(tensor.mutable_data_ptr()); - TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << videoDecoder; + auto video_decoder = + static_cast<_single_stream_decoder*>(tensor.mutable_data_ptr()); + TORCH_CHECK_EQ(videoDecoder, decoder) << "videoDecoder=" << video_decoder; return tensor; } SingleStreamDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); void* buffer = tensor.mutable_data_ptr(); - SingleStreamDecoder* decoder = static_cast(buffer); + SingleStreamDecoder* decoder = static_cast<_single_stream_decoder*>(buffer); return decoder; } // The elements of this tuple are all tensors that represent a single frame: -// 1. The frame data, which is a multidimensional tensor. -// 2. A single float value for the pts in seconds. -// 3. A single float value for the duration in seconds. +// 1. The frame data, which is a multidimensional tensor. +// 2. A single float value for the pts in seconds. +// 3. A single float value for the duration in seconds. // The reason we use Tensors for the second and third values is so we can run // under torch.compile(). using OpsFrameOutput = std::tuple; -OpsFrameOutput makeOpsFrameOutput(SingleStreamDecoder::FrameOutput& frame) { +OpsFrameOutput make_ops_frame_output( + _single_stream_decoder::_frame_output& frame) { return std::make_tuple( frame.data, torch::tensor(frame.ptsSeconds, torch::dtype(torch::kFloat64)), @@ -103,44 +104,45 @@ OpsFrameOutput makeOpsFrameOutput(SingleStreamDecoder::FrameOutput& frame) { // All elements of this tuple are tensors of the same leading dimension. The // tuple represents the frames for N total frames, where N is the dimension of // each stacked tensor. The elments are: -// 1. Stacked tensor of data for all N frames. Each frame is also a -// multidimensional tensor. -// 2. Tensor of N pts values in seconds, where each pts is a single -// float. -// 3. Tensor of N durationis in seconds, where each duration is a -// single float. +// 1. Stacked tensor of data for all N frames. Each frame is also a +// multidimensional tensor. +// 2. Tensor of N pts values in seconds, where each pts is a single +// float. +// 3. Tensor of N durationis in seconds, where each duration is a +// single float. using OpsFrameBatchOutput = std::tuple; -OpsFrameBatchOutput makeOpsFrameBatchOutput( +OpsFrameBatchOutput make_ops_frame_batch_output( SingleStreamDecoder::FrameBatchOutput& batch) { - return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); + return std::make_tuple(batch.data, batch.pts_seconds, batch.duration_seconds); } // The elements of this tuple are all tensors that represent the concatenation // of multiple audio frames: -// 1. The frames data (concatenated) -// 2. A single float value for the pts of the first frame, in seconds. +// 1. The frames data (concatenated) +// 2. A single float value for the pts of the first frame, in seconds. 
 using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
-OpsAudioFramesOutput makeOpsAudioFramesOutput(
-    SingleStreamDecoder::AudioFramesOutput& audioFrames) {
+OpsAudioFramesOutput make_ops_audio_frames_output(
+    SingleStreamDecoder::AudioFramesOutput& audio_frames) {
   return std::make_tuple(
-      audioFrames.data,
-      torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
+      audio_frames.data,
+      torch::tensor(audio_frames.pts_seconds, torch::dtype(torch::kFloat64)));
 }
-std::string quoteValue(const std::string& value) {
+std::string quote_value(const std::string& value) {
   return "\"" + value + "\"";
 }
-std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
+std::string map_to_json(
+    const std::map<std::string, std::string>& metadata_map) {
   std::stringstream ss;
   ss << "{\n";
-  auto it = metadataMap.begin();
-  while (it != metadataMap.end()) {
+  auto it = metadata_map.begin();
+  while (it != metadata_map.end()) {
     ss << "\"" << it->first << "\": " << it->second;
     ++it;
-    if (it != metadataMap.end()) {
+    if (it != metadata_map.end()) {
       ss << ",\n";
     } else {
       ss << "\n";
@@ -161,17 +163,18 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
 at::Tensor create_from_file(
     std::string_view filename,
     std::optional seek_mode = std::nullopt) {
-  std::string filenameStr(filename);
+  std::string filename_str(filename);
-  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
+  SingleStreamDecoder::SeekMode real_seek =
+      SingleStreamDecoder::SeekMode::exact;
   if (seek_mode.has_value()) {
-    realSeek = seekModeFromString(seek_mode.value());
+    real_seek = seek_mode_from_string(seek_mode.value());
   }
-  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
-      std::make_unique<SingleStreamDecoder>(filenameStr, realSeek);
+  std::unique_ptr<SingleStreamDecoder> unique_decoder =
+      std::make_unique<SingleStreamDecoder>(filename_str, real_seek);
-  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+  return wrap_decoder_pointer_to_tensor(std::move(unique_decoder));
 }
 // Create a SingleStreamDecoder from the actual bytes of a video and wrap the
@@ -182,26 +185,28 @@ at::Tensor create_from_tensor(
   TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous");
   TORCH_CHECK(
       video_tensor.scalar_type() == torch::kUInt8,
       "video_tensor must be kUInt8");
   void* data = video_tensor.mutable_data_ptr();
   size_t length = video_tensor.numel();
-  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
+  SingleStreamDecoder::SeekMode real_seek =
+      SingleStreamDecoder::SeekMode::exact;
   if (seek_mode.has_value()) {
-    realSeek = seekModeFromString(seek_mode.value());
+    real_seek = seek_mode_from_string(seek_mode.value());
   }
-  auto contextHolder = std::make_unique<AVIOBytesContext>(data, length);
+  auto context_holder = std::make_unique<AVIOBytesContext>(data, length);
-  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
-      std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
-  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+  std::unique_ptr<SingleStreamDecoder> unique_decoder =
+      std::make_unique<SingleStreamDecoder>(
+          std::move(context_holder), real_seek);
+  return wrap_decoder_pointer_to_tensor(std::move(unique_decoder));
 }
 at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
-  auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
-  std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
-  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
+  auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
+  std::unique_ptr<SingleStreamDecoder> unique_decoder(decoder);
+  return wrap_decoder_pointer_to_tensor(std::move(unique_decoder));
 }
 void _add_video_stream(
@@ -213,36 +218,36 @@ void _add_video_stream(
     std::optional stream_index = std::nullopt,
     std::optional device = std::nullopt,
     std::optional color_conversion_library = std::nullopt) {
-  SingleStreamDecoder::VideoStreamOptions videoStreamOptions;
-  videoStreamOptions.width = width;
-  videoStreamOptions.height = height;
-  videoStreamOptions.ffmpegThreadCount = num_threads;
+  SingleStreamDecoder::VideoStreamOptions video_stream_options;
+  video_stream_options.width = width;
+  video_stream_options.height = height;
+  video_stream_options.ffmpeg_thread_count = num_threads;
   if (dimension_order.has_value()) {
-    std::string stdDimensionOrder{dimension_order.value()};
-    TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW");
-    videoStreamOptions.dimensionOrder = stdDimensionOrder;
+    std::string std_dimension_order{dimension_order.value()};
+    TORCH_CHECK(std_dimension_order == "NHWC" || std_dimension_order == "NCHW");
+    video_stream_options.dimension_order = std_dimension_order;
   }
   if (color_conversion_library.has_value()) {
-    std::string stdColorConversionLibrary{color_conversion_library.value()};
+    std::string std_color_conversion_library{color_conversion_library.value()};
-    if (stdColorConversionLibrary == "filtergraph") {
-      videoStreamOptions.colorConversionLibrary =
+    if (std_color_conversion_library == "filtergraph") {
+      video_stream_options.color_conversion_library =
           SingleStreamDecoder::ColorConversionLibrary::FILTERGRAPH;
-    } else if (stdColorConversionLibrary == "swscale") {
-      videoStreamOptions.colorConversionLibrary =
+    } else if (std_color_conversion_library == "swscale") {
+      video_stream_options.color_conversion_library =
          SingleStreamDecoder::ColorConversionLibrary::SWSCALE;
     } else {
       throw std::runtime_error(
-          "Invalid color_conversion_library=" + stdColorConversionLibrary +
+          "Invalid color_conversion_library=" + std_color_conversion_library +
           ". color_conversion_library must be either filtergraph or swscale.");
     }
   }
   if (device.has_value()) {
     if (device.value() == "cpu") {
-      videoStreamOptions.device = torch::Device(torch::kCPU);
+      video_stream_options.device = torch::Device(torch::kCPU);
     } else if (device.value().rfind("cuda", 0) == 0) { // starts with "cuda"
-      std::string deviceStr(device.value());
-      videoStreamOptions.device = torch::Device(deviceStr);
+      std::string device_str(device.value());
+      video_stream_options.device = torch::Device(device_str);
     } else {
       throw std::runtime_error(
           "Invalid device=" + std::string(device.value()) +
@@ -250,8 +255,9 @@ void _add_video_stream(
     }
   }
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  video_decoder->add_video_stream(
+      stream_index.value_or(-1), video_stream_options);
 }
 // Add a new video stream at `stream_index` using the provided options.
@@ -277,63 +283,64 @@ void add_audio_stream(
     at::Tensor& decoder,
     std::optional stream_index = std::nullopt,
     std::optional sample_rate = std::nullopt) {
-  SingleStreamDecoder::AudioStreamOptions audioStreamOptions;
-  audioStreamOptions.sampleRate = sample_rate;
+  SingleStreamDecoder::AudioStreamOptions audio_stream_options;
+  audio_stream_options.sample_rate = sample_rate;
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  video_decoder->add_audio_stream(
+      stream_index.value_or(-1), audio_stream_options);
 }
 // Seek to a particular presentation timestamp in the video in seconds.
 void seek_to_pts(at::Tensor& decoder, double seconds) {
-  auto videoDecoder =
-      static_cast<SingleStreamDecoder*>(decoder.mutable_data_ptr());
-  videoDecoder->setCursorPtsInSeconds(seconds);
+  auto video_decoder =
+      static_cast<SingleStreamDecoder*>(decoder.mutable_data_ptr());
+  video_decoder->set_cursor_pts_in_seconds(seconds);
 }
 // Get the next frame from the video as a tuple that has the frame data, pts and
 // duration as tensors.
 OpsFrameOutput get_next_frame(at::Tensor& decoder) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
   SingleStreamDecoder::FrameOutput result;
   try {
-    result = videoDecoder->getNextFrame();
+    result = video_decoder->get_next_frame();
   } catch (const SingleStreamDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
-  return makeOpsFrameOutput(result);
+  return make_ops_frame_output(result);
 }
 // Return the frame that is visible at a given timestamp in seconds. Each frame
 // in FFMPEG has a presentation timestamp and a duration. The frame visible at a
 // given timestamp T has T >= PTS and T < PTS + Duration.
 OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
   SingleStreamDecoder::FrameOutput result;
   try {
-    result = videoDecoder->getFramePlayedAt(seconds);
+    result = video_decoder->get_frame_played_at(seconds);
   } catch (const SingleStreamDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
-  return makeOpsFrameOutput(result);
+  return make_ops_frame_output(result);
 }
 // Return the frame that is visible at a given index in the video.
 OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto result = videoDecoder->getFrameAtIndex(frame_index);
-  return makeOpsFrameOutput(result);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  auto result = video_decoder->get_frame_at_index(frame_index);
+  return make_ops_frame_output(result);
 }
 // Return the frames at given indices for a given stream
 OpsFrameBatchOutput get_frames_at_indices(
     at::Tensor& decoder,
     at::IntArrayRef frame_indices) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  std::vector<int64_t> frameIndicesVec(
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  std::vector<int64_t> frame_indices_vec(
       frame_indices.begin(), frame_indices.end());
-  auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
-  return makeOpsFrameBatchOutput(result);
+  auto result = video_decoder->get_frames_at_indices(frame_indices_vec);
+  return make_ops_frame_batch_output(result);
 }
 // Return the frames inside a range as a single stacked Tensor. The range is
@@ -343,19 +350,20 @@ OpsFrameBatchOutput get_frames_in_range(
     int64_t start,
     int64_t stop,
     std::optional step = std::nullopt) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto result = videoDecoder->getFramesInRange(start, stop, step.value_or(1));
-  return makeOpsFrameBatchOutput(result);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  auto result =
+      video_decoder->get_frames_in_range(start, stop, step.value_or(1));
+  return make_ops_frame_batch_output(result);
 }
 // Return the frames at given ptss for a given stream
 OpsFrameBatchOutput get_frames_by_pts(
     at::Tensor& decoder,
     at::ArrayRef<double> timestamps) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
-  auto result = videoDecoder->getFramesPlayedAt(timestampsVec);
-  return makeOpsFrameBatchOutput(result);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  std::vector<double> timestamps_vec(timestamps.begin(), timestamps.end());
+  auto result = video_decoder->get_frames_played_at(timestamps_vec);
+  return make_ops_frame_batch_output(result);
 }
 // Return the frames inside the range as a single stacked Tensor. The range is
@@ -365,20 +373,20 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
     at::Tensor& decoder,
     double start_seconds,
     double stop_seconds) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
   auto result =
-      videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds);
-  return makeOpsFrameBatchOutput(result);
+      video_decoder->get_frames_played_in_range(start_seconds, stop_seconds);
+  return make_ops_frame_batch_output(result);
 }
 OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
     at::Tensor& decoder,
     double start_seconds,
     std::optional stop_seconds = std::nullopt) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto result =
-      videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
-  return makeOpsAudioFramesOutput(result);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  auto result = video_decoder->get_frames_played_in_range_audio(
+      start_seconds, stop_seconds);
+  return make_ops_audio_frames_output(result);
 }
 // For testing only. We need to implement this operation as a core library
@@ -394,186 +402,188 @@ bool _test_frame_pts_equality(
     at::Tensor& decoder,
     int64_t frame_index,
     double pts_seconds_to_test) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
   return pts_seconds_to_test ==
-      videoDecoder->getPtsSecondsForFrame(frame_index);
+      video_decoder->get_pts_seconds_for_frame(frame_index);
 }
 torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  return videoDecoder->getKeyFrameIndices();
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  return video_decoder->get_key_frame_indices();
 }
 // Get the metadata from the video as a string.
 std::string get_json_metadata(at::Tensor& decoder) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
-  SingleStreamDecoder::ContainerMetadata videoMetadata =
-      videoDecoder->getContainerMetadata();
-  auto maybeBestVideoStreamIndex = videoMetadata.bestVideoStreamIndex;
+  SingleStreamDecoder::ContainerMetadata video_metadata =
+      video_decoder->get_container_metadata();
+  auto maybe_best_video_stream_index = video_metadata.best_video_stream_index;
-  std::map<std::string, std::string> metadataMap;
+  std::map<std::string, std::string> metadata_map;
   // serialize the metadata into a string
   std::stringstream ss;
-  double durationSeconds = 0;
-  if (maybeBestVideoStreamIndex.has_value() &&
-      videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
-          .durationSeconds.has_value()) {
-    durationSeconds =
-        videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex]
-            .durationSeconds.value_or(0);
+  double duration_seconds = 0;
+  if (maybe_best_video_stream_index.has_value() &&
+      video_metadata.all_stream_metadata[*maybe_best_video_stream_index]
+          .duration_seconds.has_value()) {
+    duration_seconds =
+        video_metadata.all_stream_metadata[*maybe_best_video_stream_index]
+            .duration_seconds.value_or(0);
   } else {
     // Fallback to container-level duration if stream duration is not found.
-    durationSeconds = videoMetadata.durationSeconds.value_or(0);
+    duration_seconds = video_metadata.duration_seconds.value_or(0);
   }
-  metadataMap["durationSeconds"] = std::to_string(durationSeconds);
+  metadata_map["duration_seconds"] = std::to_string(duration_seconds);
-  if (videoMetadata.bitRate.has_value()) {
-    metadataMap["bitRate"] = std::to_string(videoMetadata.bitRate.value());
+  if (video_metadata.bit_rate.has_value()) {
+    metadata_map["bit_rate"] = std::to_string(video_metadata.bit_rate.value());
   }
-  if (maybeBestVideoStreamIndex.has_value()) {
-    auto streamMetadata =
-        videoMetadata.allStreamMetadata[*maybeBestVideoStreamIndex];
+  if (maybe_best_video_stream_index.has_value()) {
+    auto stream_metadata =
+        video_metadata.all_stream_metadata[*maybe_best_video_stream_index];
-    if (streamMetadata.numFramesFromScan.has_value()) {
-      metadataMap["numFrames"] =
-          std::to_string(*streamMetadata.numFramesFromScan);
-    } else if (streamMetadata.numFrames.has_value()) {
-      metadataMap["numFrames"] = std::to_string(*streamMetadata.numFrames);
+    if (stream_metadata.num_frames_from_scan.has_value()) {
+      metadata_map["num_frames"] =
+          std::to_string(*stream_metadata.num_frames_from_scan);
+    } else if (stream_metadata.num_frames.has_value()) {
+      metadata_map["num_frames"] = std::to_string(*stream_metadata.num_frames);
     }
-    if (streamMetadata.minPtsSecondsFromScan.has_value()) {
-      metadataMap["minPtsSecondsFromScan"] =
-          std::to_string(*streamMetadata.minPtsSecondsFromScan);
+    if (stream_metadata.min_pts_seconds_from_scan.has_value()) {
+      metadata_map["min_pts_seconds_from_scan"] =
+          std::to_string(*stream_metadata.min_pts_seconds_from_scan);
    }
-    if (streamMetadata.maxPtsSecondsFromScan.has_value()) {
-      metadataMap["maxPtsSecondsFromScan"] =
-          std::to_string(*streamMetadata.maxPtsSecondsFromScan);
+    if (stream_metadata.max_pts_seconds_from_scan.has_value()) {
+      metadata_map["max_pts_seconds_from_scan"] =
+          std::to_string(*stream_metadata.max_pts_seconds_from_scan);
     }
-    if (streamMetadata.codecName.has_value()) {
-      metadataMap["codec"] = quoteValue(streamMetadata.codecName.value());
+    if (stream_metadata.codec_name.has_value()) {
+      metadata_map["codec"] = quote_value(stream_metadata.codec_name.value());
     }
-    if (streamMetadata.width.has_value()) {
-      metadataMap["width"] = std::to_string(*streamMetadata.width);
+    if (stream_metadata.width.has_value()) {
+      metadata_map["width"] = std::to_string(*stream_metadata.width);
    }
-    if (streamMetadata.height.has_value()) {
-      metadataMap["height"] = std::to_string(*streamMetadata.height);
+    if (stream_metadata.height.has_value()) {
+      metadata_map["height"] = std::to_string(*stream_metadata.height);
     }
-    if (streamMetadata.averageFps.has_value()) {
-      metadataMap["averageFps"] = std::to_string(*streamMetadata.averageFps);
+    if (stream_metadata.average_fps.has_value()) {
+      metadata_map["average_fps"] =
+          std::to_string(*stream_metadata.average_fps);
     }
   }
-  if (videoMetadata.bestVideoStreamIndex.has_value()) {
-    metadataMap["bestVideoStreamIndex"] =
-        std::to_string(*videoMetadata.bestVideoStreamIndex);
+  if (video_metadata.best_video_stream_index.has_value()) {
+    metadata_map["best_video_stream_index"] =
+        std::to_string(*video_metadata.best_video_stream_index);
   }
-  if (videoMetadata.bestAudioStreamIndex.has_value()) {
-    metadataMap["bestAudioStreamIndex"] =
-        std::to_string(*videoMetadata.bestAudioStreamIndex);
+  if (video_metadata.best_audio_stream_index.has_value()) {
+    metadata_map["best_audio_stream_index"] =
+        std::to_string(*video_metadata.best_audio_stream_index);
   }
-  return mapToJson(metadataMap);
+  return map_to_json(metadata_map);
 }
 // Get the container metadata as a string.
 std::string get_container_json_metadata(at::Tensor& decoder) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
-  auto containerMetadata = videoDecoder->getContainerMetadata();
+  auto container_metadata = video_decoder->get_container_metadata();
   std::map<std::string, std::string> map;
-  if (containerMetadata.durationSeconds.has_value()) {
-    map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds);
+  if (container_metadata.duration_seconds.has_value()) {
+    map["duration_seconds"] =
+        std::to_string(*container_metadata.duration_seconds);
   }
-  if (containerMetadata.bitRate.has_value()) {
-    map["bitRate"] = std::to_string(*containerMetadata.bitRate);
+  if (container_metadata.bit_rate.has_value()) {
+    map["bit_rate"] = std::to_string(*container_metadata.bit_rate);
   }
-  if (containerMetadata.bestVideoStreamIndex.has_value()) {
-    map["bestVideoStreamIndex"] =
-        std::to_string(*containerMetadata.bestVideoStreamIndex);
+  if (container_metadata.best_video_stream_index.has_value()) {
+    map["best_video_stream_index"] =
+        std::to_string(*container_metadata.best_video_stream_index);
   }
-  if (containerMetadata.bestAudioStreamIndex.has_value()) {
-    map["bestAudioStreamIndex"] =
-        std::to_string(*containerMetadata.bestAudioStreamIndex);
+  if (container_metadata.best_audio_stream_index.has_value()) {
    map["best_audio_stream_index"] =
+        std::to_string(*container_metadata.best_audio_stream_index);
   }
-  map["numStreams"] =
-      std::to_string(containerMetadata.allStreamMetadata.size());
+  map["num_streams"] =
+      std::to_string(container_metadata.all_stream_metadata.size());
-  return mapToJson(map);
+  return map_to_json(map);
 }
 // Get the stream metadata as a string.
 std::string get_stream_json_metadata(
     at::Tensor& decoder,
     int64_t stream_index) {
-  auto videoDecoder = unwrapTensorToGetDecoder(decoder);
-  auto allStreamMetadata =
-      videoDecoder->getContainerMetadata().allStreamMetadata;
+  auto video_decoder = unwrap_tensor_to_get_decoder(decoder);
+  auto all_stream_metadata =
+      video_decoder->get_container_metadata().all_stream_metadata;
   if (stream_index < 0 ||
-      stream_index >= static_cast<int64_t>(allStreamMetadata.size())) {
+      stream_index >= static_cast<int64_t>(all_stream_metadata.size())) {
     throw std::out_of_range(
         "stream_index out of bounds: " + std::to_string(stream_index));
   }
-  auto streamMetadata = allStreamMetadata[stream_index];
+  auto stream_metadata = all_stream_metadata[stream_index];
   std::map<std::string, std::string> map;
-  if (streamMetadata.durationSeconds.has_value()) {
-    map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds);
+  if (stream_metadata.duration_seconds.has_value()) {
+    map["duration_seconds"] = std::to_string(*stream_metadata.duration_seconds);
   }
-  if (streamMetadata.bitRate.has_value()) {
-    map["bitRate"] = std::to_string(*streamMetadata.bitRate);
+  if (stream_metadata.bit_rate.has_value()) {
+    map["bit_rate"] = std::to_string(*stream_metadata.bit_rate);
   }
-  if (streamMetadata.numFramesFromScan.has_value()) {
-    map["numFramesFromScan"] =
-        std::to_string(*streamMetadata.numFramesFromScan);
+  if (stream_metadata.num_frames_from_scan.has_value()) {
+    map["num_frames_from_scan"] =
+        std::to_string(*stream_metadata.num_frames_from_scan);
   }
-  if (streamMetadata.numFrames.has_value()) {
-    map["numFrames"] = std::to_string(*streamMetadata.numFrames);
+  if (stream_metadata.num_frames.has_value()) {
    map["num_frames"] = std::to_string(*stream_metadata.num_frames);
   }
-  if (streamMetadata.beginStreamFromHeader.has_value()) {
-    map["beginStreamFromHeader"] =
-        std::to_string(*streamMetadata.beginStreamFromHeader);
+  if (stream_metadata.begin_stream_from_header.has_value()) {
+    map["begin_stream_from_header"] =
+        std::to_string(*stream_metadata.begin_stream_from_header);
   }
-  if (streamMetadata.minPtsSecondsFromScan.has_value()) {
-    map["minPtsSecondsFromScan"] =
-        std::to_string(*streamMetadata.minPtsSecondsFromScan);
+  if (stream_metadata.min_pts_seconds_from_scan.has_value()) {
+    map["min_pts_seconds_from_scan"] =
+        std::to_string(*stream_metadata.min_pts_seconds_from_scan);
   }
-  if (streamMetadata.maxPtsSecondsFromScan.has_value()) {
-    map["maxPtsSecondsFromScan"] =
-        std::to_string(*streamMetadata.maxPtsSecondsFromScan);
+  if (stream_metadata.max_pts_seconds_from_scan.has_value()) {
+    map["max_pts_seconds_from_scan"] =
+        std::to_string(*stream_metadata.max_pts_seconds_from_scan);
   }
-  if (streamMetadata.codecName.has_value()) {
-    map["codec"] = quoteValue(streamMetadata.codecName.value());
+  if (stream_metadata.codec_name.has_value()) {
    map["codec"] = quote_value(stream_metadata.codec_name.value());
   }
-  if (streamMetadata.width.has_value()) {
-    map["width"] = std::to_string(*streamMetadata.width);
+  if (stream_metadata.width.has_value()) {
    map["width"] = std::to_string(*stream_metadata.width);
   }
-  if (streamMetadata.height.has_value()) {
-    map["height"] = std::to_string(*streamMetadata.height);
+  if (stream_metadata.height.has_value()) {
    map["height"] = std::to_string(*stream_metadata.height);
   }
-  if (streamMetadata.averageFps.has_value()) {
-    map["averageFps"] = std::to_string(*streamMetadata.averageFps);
+  if (stream_metadata.average_fps.has_value()) {
    map["average_fps"] = std::to_string(*stream_metadata.average_fps);
   }
-  if (streamMetadata.sampleRate.has_value()) {
-    map["sampleRate"] = std::to_string(*streamMetadata.sampleRate);
+  if (stream_metadata.sample_rate.has_value()) {
    map["sample_rate"] = std::to_string(*stream_metadata.sample_rate);
   }
-  if (streamMetadata.numChannels.has_value()) {
-    map["numChannels"] = std::to_string(*streamMetadata.numChannels);
+  if (stream_metadata.num_channels.has_value()) {
    map["num_channels"] = std::to_string(*stream_metadata.num_channels);
   }
-  if (streamMetadata.sampleFormat.has_value()) {
-    map["sampleFormat"] = quoteValue(streamMetadata.sampleFormat.value());
+  if (stream_metadata.sample_format.has_value()) {
    map["sample_format"] = quote_value(stream_metadata.sample_format.value());
   }
-  if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) {
+  if (stream_metadata.media_type == AVMEDIA_TYPE_VIDEO) {
- map["mediaType"] = quoteValue("video"); + map["media_type"] = quote_value("video"); } else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) { - map["mediaType"] = quoteValue("audio"); + map["media_type"] = quote_value("audio"); } else { - map["mediaType"] = quoteValue("other"); + map["media_type"] = quote_value("other"); } - return mapToJson(map); + return map_to_json(map); } // Returns version information about the various FFMPEG libraries that are @@ -609,8 +619,8 @@ std::string _get_json_ffmpeg_library_versions() { // accurate seeking. Note that this function reads the entire video but it does // not decode frames. Reading a video file is much cheaper than decoding it. void scan_all_streams_to_update_metadata(at::Tensor& decoder) { - auto videoDecoder = unwrapTensorToGetDecoder(decoder); - videoDecoder->scanFileAndUpdateMetadataAndIndex(); + auto video_decoder = unwrap_tensor_to_get_decoder(decoder); + video_decoder->scan_file_and_update_metadata_and_index(); } TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { diff --git a/src/torchcodec/_core/pybind_ops.cpp b/src/torchcodec/_core/pybind_ops.cpp index 6f873f5a..6f273307 100644 --- a/src/torchcodec/_core/pybind_ops.cpp +++ b/src/torchcodec/_core/pybind_ops.cpp @@ -19,22 +19,24 @@ namespace facebook::torchcodec { // In principle, this should be able to return a tensor. But when we try that, // we run into the bug reported here: // -// https://github.com/pytorch/pytorch/issues/136664 +// https://github.com/pytorch/pytorch/issues/136664 // // So we instead launder the pointer through an int, and then use a conversion // function on the custom ops side to launder that int into a tensor. int64_t create_from_file_like( py::object file_like, std::optional seek_mode) { - SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact; + SingleStreamDecoder::SeekMode real_seek = + SingleStreamDecoder::SeekMode::exact; if (seek_mode.has_value()) { - realSeek = seekModeFromString(seek_mode.value()); + real_seek = seek_mode_from_string(seek_mode.value()); } - auto avioContextHolder = std::make_unique(file_like); + auto avio_context_holder = + std::make_unique<_avio_file_like_context>(file_like); SingleStreamDecoder* decoder = - new SingleStreamDecoder(std::move(avioContextHolder), realSeek); + new SingleStreamDecoder(std::move(avioContextHolder), real_seek); return reinterpret_cast(decoder); }