Skip to content

Commit 374d950

Browse files
authored
Simplify seeking and cursor logic (#543)
1 parent 05a29a5 commit 374d950

File tree

2 files changed

+25
-30
lines changed

2 files changed

+25
-30
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

+16-19
Original file line numberDiff line numberDiff line change
@@ -567,10 +567,8 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567567

568568
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
569569
std::optional<torch::Tensor> preAllocatedOutputTensor) {
570-
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
571-
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
572-
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
573-
});
570+
AVFrameStream avFrameStream = decodeAVFrame(
571+
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
574572
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
575573
}
576574

@@ -842,7 +840,9 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
842840
// --------------------------------------------------------------------------
843841

844842
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
845-
desiredPtsSeconds_ = seconds;
843+
cursorWasJustSet_ = true;
844+
cursor_ =
845+
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);
846846
}
847847

848848
/*
@@ -870,25 +870,25 @@ I P P P I P P P I P P I P P I P
870870
871871
(2) is more efficient than (1) if there is an I frame between x and y.
872872
*/
873-
bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
873+
bool VideoDecoder::canWeAvoidSeeking() const {
874874
int64_t lastDecodedAvFramePts =
875875
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
876-
if (targetPts < lastDecodedAvFramePts) {
876+
if (cursor_ < lastDecodedAvFramePts) {
877877
// We can never skip a seek if we are seeking backwards.
878878
return false;
879879
}
880-
if (lastDecodedAvFramePts == targetPts) {
880+
if (lastDecodedAvFramePts == cursor_) {
881881
// We are seeking to the exact same frame as we are currently at. Without
882882
// caching we have to rewind back and decode the frame again.
883883
// TODO: https://github.com/pytorch-labs/torchcodec/issues/84 we could
884884
// implement caching.
885885
return false;
886886
}
887887
// We are seeking forwards.
888-
// We can only skip a seek if both lastDecodedAvFramePts and targetPts share
889-
// the same keyframe.
888+
// We can only skip a seek if both lastDecodedAvFramePts and
889+
// cursor_ share the same keyframe.
890890
int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
891-
int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts);
891+
int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_);
892892
return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
893893
lastDecodedAvFrameIndex == targetKeyFrameIndex;
894894
}
@@ -900,16 +900,14 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
900900
validateActiveStream(AVMEDIA_TYPE_VIDEO);
901901
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902902

903-
int64_t desiredPts =
904-
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
905-
streamInfo.discardFramesBeforePts = desiredPts;
906-
907903
decodeStats_.numSeeksAttempted++;
908-
if (canWeAvoidSeeking(desiredPts)) {
904+
if (canWeAvoidSeeking()) {
909905
decodeStats_.numSeeksSkipped++;
910906
return;
911907
}
912908

909+
int64_t desiredPts = cursor_;
910+
913911
// For some encodings like H265, FFMPEG sometimes seeks past the point we
914912
// set as the max_ts. So we use our own index to give it the exact pts of
915913
// the key frame that we want to seek to.
@@ -948,10 +946,9 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
948946

949947
resetDecodeStats();
950948

951-
// Seek if needed.
952-
if (desiredPtsSeconds_.has_value()) {
949+
if (cursorWasJustSet_) {
953950
maybeSeekToBeforeDesiredPts();
954-
desiredPtsSeconds_ = std::nullopt;
951+
cursorWasJustSet_ = false;
955952
}
956953

957954
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

src/torchcodec/decoders/_core/VideoDecoder.h

+9-11
Original file line numberDiff line numberDiff line change
@@ -332,15 +332,11 @@ class VideoDecoder {
332332
std::vector<FrameInfo> keyFrames;
333333
std::vector<FrameInfo> allFrames;
334334

335-
// The current position of the cursor in the stream, and associated frame
336-
// duration.
335+
// TODO since the decoder is single-stream, these should be decoder fields,
336+
// not streamInfo fields. And they should be defined right next to
337+
// `cursor_`, with joint documentation.
337338
int64_t lastDecodedAvFramePts = 0;
338339
int64_t lastDecodedAvFrameDuration = 0;
339-
// The desired position of the cursor in the stream. We send frames >=
340-
// this pts to the user when they request a frame.
341-
// We update this field if the user requested a seek. This typically
342-
// corresponds to the decoder's desiredPts_ attribute.
343-
int64_t discardFramesBeforePts = INT64_MIN;
344340
VideoStreamOptions videoStreamOptions;
345341

346342
// color-conversion fields. Only one of FilterGraphContext and
@@ -363,7 +359,7 @@ class VideoDecoder {
363359
// DECODING APIS AND RELATED UTILS
364360
// --------------------------------------------------------------------------
365361

366-
bool canWeAvoidSeeking(int64_t targetPts) const;
362+
bool canWeAvoidSeeking() const;
367363

368364
void maybeSeekToBeforeDesiredPts();
369365

@@ -466,9 +462,11 @@ class VideoDecoder {
466462
std::map<int, StreamInfo> streamInfos_;
467463
const int NO_ACTIVE_STREAM = -2;
468464
int activeStreamIndex_ = NO_ACTIVE_STREAM;
469-
// Set when the user wants to seek and stores the desired pts that the user
470-
// wants to seek to.
471-
std::optional<double> desiredPtsSeconds_;
465+
466+
bool cursorWasJustSet_ = false;
467+
// The desired position of the cursor in the stream. We send frames >= this
468+
// pts to the user when they request a frame.
469+
int64_t cursor_ = INT64_MIN;
472470
// Stores various internal decoding stats.
473471
DecodeStats decodeStats_;
474472
// Stores the AVIOContext for the input buffer.

0 commit comments

Comments
 (0)