@@ -608,6 +608,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
608608 // we have scanned all packets and sorted by pts.
609609 FrameInfo frameInfo = {packet->pts };
610610 if (packet->flags & AV_PKT_FLAG_KEY) {
611+ frameInfo.isKeyFrame = true ;
611612 streamInfos_[streamIndex].keyFrames .push_back (frameInfo);
612613 }
613614 streamInfos_[streamIndex].allFrames .push_back (frameInfo);
@@ -658,25 +659,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
658659 return frameInfo1.pts < frameInfo2.pts ;
659660 });
660661
661- size_t keyIndex = 0 ;
662+ size_t keyFrameIndex = 0 ;
662663 for (size_t i = 0 ; i < streamInfo.allFrames .size (); ++i) {
663664 streamInfo.allFrames [i].frameIndex = i;
664-
665- // For correctly encoded files, we shouldn't need to ensure that keyIndex
666- // is less than the number of key frames. That is, the relationship
667- // between the frames in allFrames and keyFrames should be such that
668- // keyIndex is always a valid index into keyFrames. But we're being
669- // defensive in case we encounter incorrectly encoded files.
670- if (keyIndex < streamInfo.keyFrames .size () &&
671- streamInfo.keyFrames [keyIndex].pts == streamInfo.allFrames [i].pts ) {
672- streamInfo.keyFrames [keyIndex].frameIndex = i;
673- ++keyIndex;
665+ if (streamInfo.allFrames [i].isKeyFrame ) {
666+ TORCH_CHECK (
667+ keyFrameIndex < streamInfo.keyFrames .size (),
668+ " The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec." );
669+ streamInfo.keyFrames [keyFrameIndex].frameIndex = i;
670+ ++keyFrameIndex;
674671 }
675-
676672 if (i + 1 < streamInfo.allFrames .size ()) {
677673 streamInfo.allFrames [i].nextPts = streamInfo.allFrames [i + 1 ].pts ;
678674 }
679675 }
676+ TORCH_CHECK (
677+ keyFrameIndex == streamInfo.keyFrames .size (),
678+ " The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec." );
680679 }
681680
682681 scannedAllStreams_ = true ;
0 commit comments