diff --git a/src/torchcodec/_core/AVIOBytesContext.cpp b/src/torchcodec/_core/AVIOBytesContext.cpp
index 3e1481be..af1c8170 100644
--- a/src/torchcodec/_core/AVIOBytesContext.cpp
+++ b/src/torchcodec/_core/AVIOBytesContext.cpp
@@ -13,7 +13,7 @@ AVIOBytesContext::AVIOBytesContext(const void* data, int64_t dataSize)
     : dataContext_{static_cast<const uint8_t*>(data), dataSize, 0} {
   TORCH_CHECK(data != nullptr, "Video data buffer cannot be nullptr!");
   TORCH_CHECK(dataSize > 0, "Video data size must be positive");
-  createAVIOContext(&read, &seek, &dataContext_);
+  createAVIOContext(&read, nullptr, &seek, &dataContext_);
 }
 
 // The signature of this function is defined by FFMPEG.
@@ -67,4 +67,71 @@ int64_t AVIOBytesContext::seek(void* opaque, int64_t offset, int whence) {
   return ret;
 }
 
+AVIOToTensorContext::AVIOToTensorContext()
+    : dataContext_{
+          torch::empty(
+              {AVIOToTensorContext::INITIAL_TENSOR_SIZE},
+              {torch::kUInt8}),
+          0} {
+  createAVIOContext(nullptr, &write, &seek, &dataContext_);
+}
+
+// The signature of this function is defined by FFMPEG.
+int AVIOToTensorContext::write(void* opaque, const uint8_t* buf, int buf_size) {
+  auto dataContext = static_cast<DataContext*>(opaque);
+
+  int64_t bufSize = static_cast<int64_t>(buf_size);
+  if (dataContext->current + bufSize > dataContext->outputTensor.numel()) {
+    TORCH_CHECK(
+        dataContext->outputTensor.numel() * 2 <=
+            AVIOToTensorContext::MAX_TENSOR_SIZE,
+        "We tried to allocate an output encoded tensor larger than ",
+        AVIOToTensorContext::MAX_TENSOR_SIZE,
+        " bytes. If you think this should be supported, please report.");
+
+    // We double the size of the output tensor. Calling cat() may not be the
+    // most efficient, but it's simple.
+    dataContext->outputTensor =
+        torch::cat({dataContext->outputTensor, dataContext->outputTensor});
+  }
+
+  TORCH_CHECK(
+      dataContext->current + bufSize <= dataContext->outputTensor.numel(),
+      "Re-allocation of the output tensor didn't work. ",
+      "This should not happen, please report on the TorchCodec bug tracker.");
+
+  uint8_t* outputTensorData = dataContext->outputTensor.data_ptr<uint8_t>();
+  std::memcpy(outputTensorData + dataContext->current, buf, bufSize);
+  dataContext->current += bufSize;
+  return buf_size;
+}
+
+// The signature of this function is defined by FFMPEG.
+// Note: This `seek()` implementation is very similar to that of
+// AVIOBytesContext. We could consider merging both classes or doing some
+// refactoring, but it doesn't seem worth it at the moment.
+int64_t AVIOToTensorContext::seek(void* opaque, int64_t offset, int whence) {
+  auto dataContext = static_cast<DataContext*>(opaque);
+  int64_t ret = -1;
+
+  switch (whence) {
+    case AVSEEK_SIZE:
+      ret = dataContext->outputTensor.numel();
+      break;
+    case SEEK_SET:
+      dataContext->current = offset;
+      ret = offset;
+      break;
+    default:
+      break;
+  }
+
+  return ret;
+}
+
+torch::Tensor AVIOToTensorContext::getOutputTensor() {
+  return dataContext_.outputTensor.narrow(
+      /*dim=*/0, /*start=*/0, /*length=*/dataContext_.current);
+}
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/AVIOBytesContext.h b/src/torchcodec/_core/AVIOBytesContext.h
index c4fb7185..fad12b61 100644
--- a/src/torchcodec/_core/AVIOBytesContext.h
+++ b/src/torchcodec/_core/AVIOBytesContext.h
@@ -6,12 +6,13 @@
 
 #pragma once
 
+#include <torch/types.h>
 #include "src/torchcodec/_core/AVIOContextHolder.h"
 
 namespace facebook::torchcodec {
 
-// Enables users to pass in the entire video as bytes. Our read and seek
-// functions then traverse the bytes in memory.
+// For decoding: enables users to pass in the entire video or audio as bytes.
+// Our read and seek functions then traverse the bytes in memory.
 class AVIOBytesContext : public AVIOContextHolder {
  public:
   explicit AVIOBytesContext(const void* data, int64_t dataSize);
@@ -29,4 +30,25 @@ class AVIOBytesContext : public AVIOContextHolder {
   DataContext dataContext_;
 };
 
+// For encoding: used to encode into an output uint8 (bytes) tensor.
+class AVIOToTensorContext : public AVIOContextHolder {
+ public:
+  explicit AVIOToTensorContext();
+  torch::Tensor getOutputTensor();
+
+ private:
+  struct DataContext {
+    torch::Tensor outputTensor;
+    int64_t current;
+  };
+
+  static constexpr int64_t INITIAL_TENSOR_SIZE = 10'000'000; // 10 MB
+  static constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
+  static int write(void* opaque, const uint8_t* buf, int buf_size);
+  // We need to expose seek() for some formats like mp3.
+  static int64_t seek(void* opaque, int64_t offset, int whence);
+
+  DataContext dataContext_;
+};
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/AVIOContextHolder.cpp b/src/torchcodec/_core/AVIOContextHolder.cpp
index f0ef095f..e0462c28 100644
--- a/src/torchcodec/_core/AVIOContextHolder.cpp
+++ b/src/torchcodec/_core/AVIOContextHolder.cpp
@@ -11,6 +11,7 @@ namespace facebook::torchcodec {
 
 void AVIOContextHolder::createAVIOContext(
     AVIOReadFunction read,
+    AVIOWriteFunction write,
     AVIOSeekFunction seek,
     void* heldData,
     int bufferSize) {
@@ -22,13 +23,17 @@ void AVIOContextHolder::createAVIOContext(
       buffer != nullptr,
       "Failed to allocate buffer of size " + std::to_string(bufferSize));
 
-  avioContext_.reset(avio_alloc_context(
+  TORCH_CHECK(
+      (seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
+      "The seek method must be defined, and either write or read must be "
+      "defined, but not both!");
+  avioContext_.reset(avioAllocContext(
       buffer,
       bufferSize,
-      0,
+      /*write_flag=*/write != nullptr,
       heldData,
       read,
-      nullptr, // write function; not supported yet
+      write,
       seek));
 
   if (!avioContext_) {
diff --git a/src/torchcodec/_core/AVIOContextHolder.h b/src/torchcodec/_core/AVIOContextHolder.h
index 3b094c26..93fe4930 100644
--- a/src/torchcodec/_core/AVIOContextHolder.h
+++ b/src/torchcodec/_core/AVIOContextHolder.h
@@ -19,9 +19,9 @@ namespace facebook::torchcodec {
 //      freed.
 //   2. It is a base class for AVIOContext specializations. When specializing a
 //      AVIOContext, we need to provide four things:
-//        1. A read callback function.
-//        2. A seek callback function.
-//        3. A write callback function. (Not supported yet; it's for encoding.)
+//        1. A read callback function, for decoding.
+//        2. A seek callback function, for decoding and encoding.
+//        3. A write callback function, for encoding.
 //        4. A pointer to some context object that has the same lifetime as the
 //           AVIOContext itself. This context object holds the custom state that
 //           tracks the custom behavior of reading, seeking and writing. It is
@@ -44,13 +44,10 @@ class AVIOContextHolder {
   //       enforced by having a pure virtual methods, but we don't have any.)
   AVIOContextHolder() = default;
 
-  // These signatures are defined by FFmpeg.
-  using AVIOReadFunction = int (*)(void*, uint8_t*, int);
-  using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
-
   // Deriving classes should call this function in their constructor.
   void createAVIOContext(
       AVIOReadFunction read,
+      AVIOWriteFunction write,
       AVIOSeekFunction seek,
       void* heldData,
       int bufferSize = defaultBufferSize);
diff --git a/src/torchcodec/_core/AVIOFileLikeContext.cpp b/src/torchcodec/_core/AVIOFileLikeContext.cpp
index 4a905b93..5497f89b 100644
--- a/src/torchcodec/_core/AVIOFileLikeContext.cpp
+++ b/src/torchcodec/_core/AVIOFileLikeContext.cpp
@@ -23,7 +23,7 @@ AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
         py::hasattr(fileLike, "seek"),
         "File like object must implement a seek method.");
   }
-  createAVIOContext(&read, &seek, &fileLike_);
+  createAVIOContext(&read, nullptr, &seek, &fileLike_);
 }
 
 int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt
index 2c1fffe3..ecec520b 100644
--- a/src/torchcodec/_core/CMakeLists.txt
+++ b/src/torchcodec/_core/CMakeLists.txt
@@ -65,8 +65,9 @@ function(make_torchcodec_libraries
   set(decoder_library_name "libtorchcodec_decoder${ffmpeg_major_version}")
   set(decoder_sources
       AVIOContextHolder.cpp
+      AVIOBytesContext.cpp
       FFMPEGCommon.cpp
-      DeviceInterface.cpp 
+      DeviceInterface.cpp
       SingleStreamDecoder.cpp
       # TODO: lib name should probably not be "*_decoder*" now that it also
       # contains an encoder
diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 08af3402..114e8600 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -1,5 +1,6 @@
 #include <sstream>
 
+#include "src/torchcodec/_core/AVIOBytesContext.h"
 #include "src/torchcodec/_core/Encoder.h"
 #include "torch/types.h"
@@ -7,6 +8,17 @@ namespace facebook::torchcodec {
 
 namespace {
 
+torch::Tensor validateWf(torch::Tensor wf) {
+  TORCH_CHECK(
+      wf.dtype() == torch::kFloat32,
+      "waveform must have float32 dtype, got ",
+      wf.dtype());
+  // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
+  // planar (fltp).
+  TORCH_CHECK(wf.dim() == 2, "waveform must have 2 dimensions, got ", wf.dim());
+  return wf;
+}
+
 void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
   if (avCodec.supported_samplerates == nullptr) {
     return;
@@ -80,20 +92,12 @@ AudioEncoder::AudioEncoder(
     int sampleRate,
     std::string_view fileName,
     std::optional<int64_t> bitRate)
-    : wf_(wf) {
-  TORCH_CHECK(
-      wf_.dtype() == torch::kFloat32,
-      "waveform must have float32 dtype, got ",
-      wf_.dtype());
-  // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
-  // planar (fltp).
-  TORCH_CHECK(
-      wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
-
+    : wf_(validateWf(wf)) {
   setFFmpegLogLevel();
   AVFormatContext* avFormatContext = nullptr;
-  auto status = avformat_alloc_output_context2(
+  int status = avformat_alloc_output_context2(
       &avFormatContext, nullptr, nullptr, fileName.data());
+
   TORCH_CHECK(
       avFormatContext != nullptr,
       "Couldn't allocate AVFormatContext. ",
@@ -101,17 +105,42 @@ AudioEncoder::AudioEncoder(
       getFFMPEGErrorStringFromErrorCode(status));
   avFormatContext_.reset(avFormatContext);
 
-  // TODO-ENCODING: Should also support encoding into bytes (use
-  // AVIOBytesContext)
-  TORCH_CHECK(
-      !(avFormatContext->oformat->flags & AVFMT_NOFILE),
-      "AVFMT_NOFILE is set. We only support writing to a file.");
   status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
   TORCH_CHECK(
       status >= 0,
       "avio_open failed: ",
       getFFMPEGErrorStringFromErrorCode(status));
 
+  initializeEncoder(sampleRate, bitRate);
+}
+
+AudioEncoder::AudioEncoder(
+    const torch::Tensor wf,
+    int sampleRate,
+    std::string_view formatName,
+    std::unique_ptr<AVIOToTensorContext> avioContextHolder,
+    std::optional<int64_t> bitRate)
+    : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
+  setFFmpegLogLevel();
+  AVFormatContext* avFormatContext = nullptr;
+  int status = avformat_alloc_output_context2(
+      &avFormatContext, nullptr, formatName.data(), nullptr);
+
+  TORCH_CHECK(
+      avFormatContext != nullptr,
+      "Couldn't allocate AVFormatContext. ",
+      "Check the desired extension? ",
+      getFFMPEGErrorStringFromErrorCode(status));
+  avFormatContext_.reset(avFormatContext);
+
+  avFormatContext_->pb = avioContextHolder_->getAVIOContext();
+
+  initializeEncoder(sampleRate, bitRate);
+}
+
+void AudioEncoder::initializeEncoder(
+    int sampleRate,
+    std::optional<int64_t> bitRate) {
   // We use the AVFormatContext's default codec for that
   // specific format/container.
   const AVCodec* avCodec =
@@ -150,7 +179,7 @@ AudioEncoder::AudioEncoder(
 
   setDefaultChannelLayout(avCodecContext_, numChannels);
 
-  status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
+  int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
   TORCH_CHECK(
       status == AVSUCCESS,
       "avcodec_open2 failed: ",
@@ -170,7 +199,18 @@ AudioEncoder::AudioEncoder(
   streamIndex_ = avStream->index;
 }
 
+torch::Tensor AudioEncoder::encodeToTensor() {
+  TORCH_CHECK(
+      avioContextHolder_ != nullptr,
+      "Cannot encode to tensor: the AVIO context doesn't exist.");
+  encode();
+  return avioContextHolder_->getOutputTensor();
+}
+
 void AudioEncoder::encode() {
+  // TODO-ENCODING: Need to check, but consecutive calls to encode() are
+  // probably invalid. We can address this once we (re)design the public and
+  // private encoding APIs.
   UniqueAVFrame avFrame(av_frame_alloc());
   TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
   // Default to 256 like in torchaudio
diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h
index 3e1abeac..17f09d59 100644
--- a/src/torchcodec/_core/Encoder.h
+++ b/src/torchcodec/_core/Encoder.h
@@ -1,5 +1,6 @@
 #pragma once
 #include <torch/types.h>
+#include "src/torchcodec/_core/AVIOBytesContext.h"
 #include "src/torchcodec/_core/FFMPEGCommon.h"
 
 namespace facebook::torchcodec {
@@ -21,9 +22,19 @@ class AudioEncoder {
       int sampleRate,
       std::string_view fileName,
       std::optional<int64_t> bitRate = std::nullopt);
+  AudioEncoder(
+      const torch::Tensor wf,
+      int sampleRate,
+      std::string_view formatName,
+      std::unique_ptr<AVIOToTensorContext> avioContextHolder,
+      std::optional<int64_t> bitRate = std::nullopt);
   void encode();
+  torch::Tensor encodeToTensor();
 
  private:
+  void initializeEncoder(
+      int sampleRate,
+      std::optional<int64_t> bitRate = std::nullopt);
   void encodeInnerLoop(
       AutoAVPacket& autoAVPacket,
       const UniqueAVFrame& srcAVFrame);
@@ -35,5 +46,8 @@ class AudioEncoder {
   UniqueSwrContext swrContext_;
 
   const torch::Tensor wf_;
+
+  // Stores the AVIOContext for the output tensor buffer.
+  std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
 };
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index 19722108..a8da49e8 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -261,4 +261,27 @@ void setFFmpegLogLevel() {
   av_log_set_level(logLevel);
 }
 
+AVIOContext* avioAllocContext(
+    uint8_t* buffer,
+    int buffer_size,
+    int write_flag,
+    void* opaque,
+    AVIOReadFunction read_packet,
+    AVIOWriteFunction write_packet,
+    AVIOSeekFunction seek) {
+  return avio_alloc_context(
+      buffer,
+      buffer_size,
+      write_flag,
+      opaque,
+      read_packet,
+// The buf parameter of the write function is not const before FFmpeg 7.
+#if LIBAVFILTER_VERSION_MAJOR >= 10 // FFmpeg >= 7
+      write_packet,
+#else
+      reinterpret_cast<AVIOWriteFunctionOld>(write_packet),
+#endif
+      seek);
+}
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 8c4abd13..308dec48 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -177,4 +177,19 @@ bool canSwsScaleHandleUnalignedData();
 
 void setFFmpegLogLevel();
 
+// These signatures are defined by FFmpeg.
+using AVIOReadFunction = int (*)(void*, uint8_t*, int);
+using AVIOWriteFunction = int (*)(void*, const uint8_t*, int); // FFmpeg >= 7
+using AVIOWriteFunctionOld = int (*)(void*, uint8_t*, int); // FFmpeg < 7
+using AVIOSeekFunction = int64_t (*)(void*, int64_t, int);
+
+AVIOContext* avioAllocContext(
+    uint8_t* buffer,
+    int buffer_size,
+    int write_flag,
+    void* opaque,
+    AVIOReadFunction read_packet,
+    AVIOWriteFunction write_packet,
+    AVIOSeekFunction seek);
+
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py
index 4be8a7de..77fc7b85 100644
--- a/src/torchcodec/_core/__init__.py
+++ b/src/torchcodec/_core/__init__.py
@@ -18,12 +18,12 @@
     _test_frame_pts_equality,
     add_audio_stream,
     add_video_stream,
-    create_audio_encoder,
     create_from_bytes,
     create_from_file,
     create_from_file_like,
     create_from_tensor,
-    encode_audio,
+    encode_audio_to_file,
+    encode_audio_to_tensor,
     get_ffmpeg_library_versions,
     get_frame_at_index,
     get_frame_at_pts,
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index e0f3e9b9..2f470617 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -29,8 +29,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
       "torchcodec._core.ops", "//pytorch/torchcodec:torchcodec");
   m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
   m.def(
-      "create_audio_encoder(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> Tensor");
-  m.def("encode_audio(Tensor(a!) encoder) -> ()");
+      "encode_audio_to_file(Tensor wf, int sample_rate, str filename, int? bit_rate=None) -> ()");
+  m.def(
+      "encode_audio_to_tensor(Tensor wf, int sample_rate, str format, int? bit_rate=None) -> Tensor");
   m.def(
       "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
   m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
@@ -153,6 +154,16 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
   return ss.str();
 }
 
+int validateSampleRate(int64_t sampleRate) {
+  TORCH_CHECK(
+      sampleRate <= std::numeric_limits<int>::max(),
+      "sample_rate=",
+      sampleRate,
+      " is too large to be cast to an int.");
+
+  return static_cast<int>(sampleRate);
+}
+
 } // namespace
 
 // ==============================
@@ -374,44 +385,30 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
   return makeOpsAudioFramesOutput(result);
 }
 
-at::Tensor wrapAudioEncoderPointerToTensor(
-    std::unique_ptr<AudioEncoder> uniqueAudioEncoder) {
-  AudioEncoder* encoder = uniqueAudioEncoder.release();
-
-  auto deleter = [encoder](void*) { delete encoder; };
-  at::Tensor tensor =
-      at::from_blob(encoder, {sizeof(AudioEncoder*)}, deleter, {at::kLong});
-  auto encoder_ = static_cast<AudioEncoder*>(tensor.mutable_data_ptr());
-  TORCH_CHECK_EQ(encoder_, encoder) << "AudioEncoder=" << encoder_;
-  return tensor;
-}
-
-AudioEncoder* unwrapTensorToGetAudioEncoder(at::Tensor& tensor) {
-  TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
-  void* buffer = tensor.mutable_data_ptr();
-  AudioEncoder* encoder = static_cast<AudioEncoder*>(buffer);
-  return encoder;
-}
-
-at::Tensor create_audio_encoder(
+void encode_audio_to_file(
     const at::Tensor wf,
     int64_t sample_rate,
     std::string_view file_name,
    std::optional<int64_t> bit_rate = std::nullopt) {
-  TORCH_CHECK(
-      sample_rate <= std::numeric_limits<int>::max(),
-      "sample_rate=",
-      sample_rate,
-      " is too large to be cast to an int.");
-  std::unique_ptr<AudioEncoder> uniqueAudioEncoder =
-      std::make_unique<AudioEncoder>(
-          wf, static_cast<int>(sample_rate), file_name, bit_rate);
-  return wrapAudioEncoderPointerToTensor(std::move(uniqueAudioEncoder));
+  AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate)
+      .encode();
 }
 
-void encode_audio(at::Tensor& encoder) {
-  auto encoder_ = unwrapTensorToGetAudioEncoder(encoder);
-  encoder_->encode();
+// TODO-ENCODING: is "format" a good parameter name? It conflicts somewhat with
+// "sample_format", which we may eventually want to expose.
+at::Tensor encode_audio_to_tensor(
+    const at::Tensor wf,
+    int64_t sample_rate,
+    std::string_view format,
+    std::optional<int64_t> bit_rate = std::nullopt) {
+  auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
+  return AudioEncoder(
+             wf,
+             validateSampleRate(sample_rate),
+             format,
+             std::move(avioContextHolder),
+             bit_rate)
+      .encodeToTensor();
 }
 
 // For testing only. We need to implement this operation as a core library
@@ -647,7 +644,6 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
 
 TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
   m.impl("create_from_file", &create_from_file);
-  m.impl("create_audio_encoder", &create_audio_encoder);
   m.impl("create_from_tensor", &create_from_tensor);
   m.impl("_convert_to_tensor", &_convert_to_tensor);
   m.impl(
@@ -655,7 +651,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
 }
 
 TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
-  m.impl("encode_audio", &encode_audio);
+  m.impl("encode_audio_to_file", &encode_audio_to_file);
+  m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
   m.impl("seek_to_pts", &seek_to_pts);
   m.impl("add_video_stream", &add_video_stream);
   m.impl("_add_video_stream", &_add_video_stream);
diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py
index 166ebe55..e9b4faec 100644
--- a/src/torchcodec/_core/ops.py
+++ b/src/torchcodec/_core/ops.py
@@ -91,11 +91,11 @@ def load_torchcodec_shared_libraries():
 create_from_file = torch._dynamo.disallow_in_graph(
     torch.ops.torchcodec_ns.create_from_file.default
 )
-create_audio_encoder = torch._dynamo.disallow_in_graph(
-    torch.ops.torchcodec_ns.create_audio_encoder.default
+encode_audio_to_file = torch._dynamo.disallow_in_graph(
+    torch.ops.torchcodec_ns.encode_audio_to_file.default
 )
-encode_audio = torch._dynamo.disallow_in_graph(
-    torch.ops.torchcodec_ns.encode_audio.default
+encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
+    torch.ops.torchcodec_ns.encode_audio_to_tensor.default
 )
 create_from_tensor = torch._dynamo.disallow_in_graph(
     torch.ops.torchcodec_ns.create_from_tensor.default
@@ -161,15 +161,17 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)
 
 
-@register_fake("torchcodec_ns::create_audio_encoder")
-def create_audio_encoder_abstract(
+@register_fake("torchcodec_ns::encode_audio_to_file")
+def encode_audio_to_file_abstract(
     wf: torch.Tensor, sample_rate: int, filename: str, bit_rate: Optional[int] = None
-) -> torch.Tensor:
-    return torch.empty([], dtype=torch.long)
+) -> None:
+    return
 
 
-@register_fake("torchcodec_ns::encode_audio")
-def encode_audio_abstract(encoder: torch.Tensor) -> torch.Tensor:
+@register_fake("torchcodec_ns::encode_audio_to_tensor")
+def encode_audio_to_tensor_abstract(
+    wf: torch.Tensor, sample_rate: int, format: str, bit_rate: Optional[int] = None
+) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)
 
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 158e3d08..b41be538 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -22,12 +22,12 @@
     _test_frame_pts_equality,
     add_audio_stream,
     add_video_stream,
-    create_audio_encoder,
     create_from_bytes,
     create_from_file,
     create_from_file_like,
     create_from_tensor,
-    encode_audio,
+    encode_audio_to_file,
+    encode_audio_to_tensor,
     get_ffmpeg_library_versions,
     get_frame_at_index,
     get_frame_at_pts,
@@ -1117,11 +1117,15 @@ def seek(self, offset: int, whence: int) -> bytes:
 
 class TestAudioEncoderOps:
     def decode(self, source) -> torch.Tensor:
-        if isinstance(source, TestContainerFile):
-            source = str(source.path)
+        if isinstance(source, torch.Tensor):
+            decoder = create_from_tensor(source, seek_mode="approximate")
         else:
-            source = str(source)
-        decoder = create_from_file(source, seek_mode="approximate")
+            if isinstance(source, TestContainerFile):
+                source = str(source.path)
+            else:
+                source = str(source)
+            decoder = create_from_file(source, seek_mode="approximate")
+
         add_audio_stream(decoder)
         frames, *_ = get_frames_by_pts_in_range_audio(
             decoder, start_seconds=0, stop_seconds=None
@@ -1133,41 +1137,44 @@ def test_bad_input(self, tmp_path):
         valid_output_file = str(tmp_path / ".mp3")
 
         with pytest.raises(RuntimeError, match="must have float32 dtype, got int"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=torch.arange(10, dtype=torch.int),
                 sample_rate=10,
                 filename=valid_output_file,
             )
 
         with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=torch.rand(3), sample_rate=10, filename=valid_output_file
             )
         with pytest.raises(RuntimeError, match="No such file or directory"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3"
             )
 
         with pytest.raises(RuntimeError, match="Check the desired extension"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
             )
         with pytest.raises(RuntimeError, match="invalid sample rate=10"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=self.decode(NASA_AUDIO_MP3),
                 sample_rate=10,
                 filename=valid_output_file,
             )
 
         with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"):
-            create_audio_encoder(
+            encode_audio_to_file(
                 wf=self.decode(NASA_AUDIO_MP3),
                 sample_rate=NASA_AUDIO_MP3.sample_rate,
                 filename=valid_output_file,
                 bit_rate=-1,  # bad
             )
 
+    @pytest.mark.parametrize(
+        "encode_method", (encode_audio_to_file, encode_audio_to_tensor)
+    )
     @pytest.mark.parametrize("output_format", ("wav", "flac"))
-    def test_round_trip(self, output_format, tmp_path):
+    def test_round_trip(self, encode_method, output_format, tmp_path):
         # Check that decode(encode(samples)) == samples on lossless formats
         if get_ffmpeg_major_version() == 4 and output_format == "wav":
@@ -1176,15 +1183,24 @@ def test_round_trip(self, encode_method, output_format, tmp_path):
         asset = NASA_AUDIO_MP3
         source_samples = self.decode(asset)
 
-        encoded_path = tmp_path / f"output.{output_format}"
-        encoder = create_audio_encoder(
-            wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
-        )
-        encode_audio(encoder)
+        if encode_method is encode_audio_to_file:
+            encoded_path = tmp_path / f"output.{output_format}"
+            encode_audio_to_file(
+                wf=source_samples,
+                sample_rate=asset.sample_rate,
+                filename=str(encoded_path),
+            )
+            encoded_source = encoded_path
+        else:
+            encoded_source = encode_audio_to_tensor(
+                wf=source_samples, sample_rate=asset.sample_rate, format=output_format
+            )
+            assert encoded_source.dtype == torch.uint8
+            assert encoded_source.ndim == 1
 
         rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
         torch.testing.assert_close(
-            self.decode(encoded_path), source_samples, rtol=rtol, atol=atol
+            self.decode(encoded_source), source_samples, rtol=rtol, atol=atol
         )
 
     @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
@@ -1199,8 +1215,6 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
             pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
 
         encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
-        encoded_by_us = tmp_path / f"our_output.{output_format}"
-
         subprocess.run(
             ["ffmpeg", "-i", str(asset.path)]
             + (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])
@@ -1211,13 +1225,13 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
             check=True,
         )
 
-        encoder = create_audio_encoder(
+        encoded_by_us = tmp_path / f"our_output.{output_format}"
+        encode_audio_to_file(
             wf=self.decode(asset),
             sample_rate=asset.sample_rate,
             filename=str(encoded_by_us),
             bit_rate=bit_rate,
         )
-        encode_audio(encoder)
 
         rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
         torch.testing.assert_close(
@@ -1227,6 +1241,49 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
             atol=atol,
         )
 
+    @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
+    @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
+    @pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
+    def test_tensor_against_file(self, asset, bit_rate, output_format, tmp_path):
+        if get_ffmpeg_major_version() == 4 and output_format == "wav":
+            pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
+
+        encoded_file = tmp_path / f"our_output.{output_format}"
+        encode_audio_to_file(
+            wf=self.decode(asset),
+            sample_rate=asset.sample_rate,
+            filename=str(encoded_file),
+            bit_rate=bit_rate,
+        )
+
+        encoded_tensor = encode_audio_to_tensor(
+            wf=self.decode(asset),
+            sample_rate=asset.sample_rate,
+            format=output_format,
+            bit_rate=bit_rate,
+        )
+
+        torch.testing.assert_close(
+            self.decode(encoded_file), self.decode(encoded_tensor)
+        )
+
+    def test_encode_to_tensor_long_output(self):
+        # Check that we support re-allocating the output tensor when the encoded
+        # data is large.
+        samples = torch.rand(1, int(1e7))
+        encoded_tensor = encode_audio_to_tensor(
+            wf=samples,
+            sample_rate=16_000,
+            format="flac",
+            bit_rate=44_000,
+        )
+        # Note: this should be in sync with its C++ counterpart for the test to
+        # be meaningful.
+        INITIAL_TENSOR_SIZE = 10_000_000
+        assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE
+
+        torch.testing.assert_close(self.decode(encoded_tensor), samples)
+
 
 if __name__ == "__main__":
     pytest.main()
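
Reviewer note: a minimal round-trip sketch of the new in-memory encoding path, pieced together from the ops registered in custom_ops.cpp and the tests above. The `torchcodec._core` import path and the `(num_channels, num_samples)` float32 waveform layout are assumptions inferred from `__init__.py` and the existing decode helper; they are not introduced by this diff.

    import torch
    from torchcodec._core import (
        add_audio_stream,
        create_from_tensor,
        encode_audio_to_tensor,
        get_frames_by_pts_in_range_audio,
    )

    # Encode a float32 (num_channels, num_samples) waveform into an in-memory
    # FLAC file; the result is a 1D uint8 tensor holding the encoded bytes.
    samples = torch.rand(2, 16_000)
    encoded = encode_audio_to_tensor(wf=samples, sample_rate=16_000, format="flac")

    # Decode it back without touching the filesystem, mirroring
    # TestAudioEncoderOps.decode().
    decoder = create_from_tensor(encoded, seek_mode="approximate")
    add_audio_stream(decoder)
    decoded_samples, *_ = get_frames_by_pts_in_range_audio(
        decoder, start_seconds=0, stop_seconds=None
    )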