Skip to content

Commit 61c29f8

Browse files
authored
[PLT-0] Fix VideoClassificationText (#2044)
Co-authored-by: paulnoirel <87332996+paulnoirel@users.noreply.github.com>
1 parent 78cfe93 commit 61c29f8

3 files changed

Lines changed: 174 additions & 3 deletions

File tree

libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,61 @@ def from_common(
209209
)
210210

211211

212+
class NDVideoTextAnswer(BaseModel):
    """A single text value of a video text classification plus its frames.

    In the NDJSON payload each instance becomes one entry of the parent
    row's ``answer`` list, e.g.
    ``{"value": "text", "frames": [{"start": 1, "end": 5}]}``.
    """

    model_config = ConfigDict(populate_by_name=True)

    # The free-text answer itself.
    value: str
    # Frame ranges the text applies to, as string-keyed integer mappings.
    frames: List[Dict[str, int]]
217+
218+
219+
class NDVideoText(BaseModel):
    """NDJSON row for a text classification applied to a video.

    Every entry of ``answer`` pairs one text value with the frame ranges it
    covers, so a serialized row looks like:

        {"name": "...", "answer": [{"value": "text", "frames": [{"start": 1, "end": 5}]}], ...}

    At least one of ``name`` / ``schema_id`` has to be set (see
    ``must_set_one``); ``name`` and ``schemaId`` keys whose value is ``None``
    are removed from the serialized output (see ``serialize_model``).
    """

    model_config = ConfigDict(populate_by_name=True)

    # Field order is kept as-is because it determines serialized key order.
    name: Optional[str] = None
    schema_id: Optional[str] = Field(default=None, alias="schemaId")
    answer: List[NDVideoTextAnswer]
    data_row: DataRow = Field(alias="dataRow")

    @model_validator(mode="after")
    def must_set_one(self):
        """Ensure the row identifies its feature by name or by schema id.

        Raises:
            ValueError: if both ``name`` and ``schema_id`` are falsy.
        """
        if not (self.name or self.schema_id):
            raise ValueError("Schema id or name are not set. Set either one.")
        return self

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
        """Run the default serializer, then drop identifier keys that are None."""
        serialized = handler(self)
        for key in ("name", "schemaId"):
            if key in serialized and serialized[key] is None:
                del serialized[key]
        return serialized

    @classmethod
    def from_video_text_group(
        cls,
        annotation_group: List["VideoClassificationAnnotation"],
        frame_ranges_by_text: Dict[str, List[Dict[str, int]]],
        data: "GenericDataRowData",
    ) -> "NDVideoText":
        """Build one row from a group of video text annotations.

        Args:
            annotation_group: Annotations belonging to the same feature; the
                first element supplies ``name`` and ``feature_schema_id``.
            frame_ranges_by_text: Maps each text value to its frame ranges;
                one answer entry is created per item, in iteration order.
            data: Data row whose ``uid`` and ``global_key`` identify the
                target of the annotation.

        Returns:
            The assembled ``NDVideoText`` instance.
        """
        head = annotation_group[0]
        row_ref = DataRow(id=data.uid, global_key=data.global_key)
        answers = [
            NDVideoTextAnswer(value=text, frames=frame_ranges)
            for text, frame_ranges in frame_ranges_by_text.items()
        ]
        return cls(
            name=head.name,
            schema_id=head.feature_schema_id,
            data_row=row_ref,
            answer=answers,
        )
265+
266+
212267
class NDPromptTextSubclass(NDAnswer):
213268
answer: str
214269

@@ -517,6 +572,7 @@ def from_common(
517572
NDRadioSubclass.model_rebuild()
518573
NDRadio.model_rebuild()
519574
NDText.model_rebuild()
575+
NDVideoText.model_rebuild()
520576
NDPromptText.model_rebuild()
521577
NDTextSubclass.model_rebuild()
522578

libs/labelbox/src/labelbox/data/serialization/ndjson/label.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from .temporal import create_temporal_ndjson_classifications
3333
from labelbox.types import DocumentRectangle, DocumentEntity
34+
from ...annotation_types.classification.classification import Text
3435
from .classification import (
3536
NDChecklistSubclass,
3637
NDClassification,
@@ -39,6 +40,7 @@
3940
NDPromptClassificationType,
4041
NDPromptText,
4142
NDRadioSubclass,
43+
NDVideoText,
4244
)
4345
from .metric import NDConfusionMatrixMetric, NDMetricAnnotation, NDScalarMetric
4446
from .mmc import NDMessageTask
@@ -61,6 +63,7 @@
6163
NDRelationship,
6264
NDPromptText,
6365
NDMessageTask,
66+
NDVideoText,
6467
]
6568

6669

@@ -142,18 +145,43 @@ def _create_video_annotations(
142145
yield NDObject.from_common(annotation=annot, data=label.data)
143146

144147
for annotation_group in video_annotations.values():
145-
segment_frame_ranges = cls._get_segment_frame_ranges(
146-
annotation_group
147-
)
148148
if isinstance(annotation_group[0], VideoClassificationAnnotation):
149149
annotation = annotation_group[0]
150+
151+
if isinstance(annotation.value, Text):
152+
by_text = defaultdict(list)
153+
for ann in annotation_group:
154+
by_text[ann.value.answer].append(ann)
155+
156+
frame_ranges_by_text = {}
157+
for text_val, anns in sorted(
158+
by_text.items(),
159+
key=lambda x: min(a.frame for a in x[1]),
160+
):
161+
ranges = [
162+
{"start": s, "end": e}
163+
for s, e in cls._get_segment_frame_ranges(anns)
164+
]
165+
frame_ranges_by_text[text_val] = ranges
166+
167+
yield NDVideoText.from_video_text_group(
168+
annotation_group, frame_ranges_by_text, label.data
169+
)
170+
continue
171+
172+
segment_frame_ranges = cls._get_segment_frame_ranges(
173+
annotation_group
174+
)
150175
frames_data = []
151176
for frames in segment_frame_ranges:
152177
frames_data.append({"start": frames[0], "end": frames[-1]})
153178
annotation.extra.update({"frames": frames_data})
154179
yield NDClassification.from_common(annotation, label.data)
155180

156181
elif isinstance(annotation_group[0], VideoObjectAnnotation):
182+
segment_frame_ranges = cls._get_segment_frame_ranges(
183+
annotation_group
184+
)
157185
segments = []
158186
for start_frame, end_frame in segment_frame_ranges:
159187
segment = []

libs/labelbox/tests/data/serialization/ndjson/test_video.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,93 @@ def test_video_classification_global_subclassifications():
635635
assert res == [expected_first_annotation, expected_second_annotation]
636636

637637

638+
def test_video_classification_text_produces_ndjson_with_frames():
    """Video text annotations serialize to one row whose answer lists {value, frames} entries."""
    cat_text = "Looks like a hungry big cat"
    closer_text = "It's getting closer!"
    # (frame, segment_index, text) for every annotated keyframe.
    keyframes = [
        (9, 0, cat_text),
        (15, 0, cat_text),
        (40, 1, closer_text),
        (50, 1, closer_text),
    ]
    label = Label(
        data=GenericDataRowData(global_key="sample-video-text"),
        annotations=[
            VideoClassificationAnnotation(
                name="free_text",
                frame=frame,
                segment_index=segment_index,
                value=Text(answer=text),
            )
            for frame, segment_index, text in keyframes
        ],
    )

    rows = [
        row
        for row in NDJsonConverter.serialize([label])
        if row.get("name") == "free_text"
    ]
    # All four annotations share a feature, so they collapse into one row.
    assert len(rows) == 1

    row = rows[0]
    assert row["dataRow"] == {"globalKey": "sample-video-text"}
    assert "answer" in row
    answer = row["answer"]
    assert isinstance(answer, list)
    assert len(answer) == 2

    frames_by_value = {entry["value"]: entry["frames"] for entry in answer}
    assert cat_text in frames_by_value
    assert closer_text in frames_by_value
    assert frames_by_value[cat_text] == [{"start": 9, "end": 15}]
    assert frames_by_value[closer_text] == [{"start": 40, "end": 50}]
689+
690+
691+
def test_video_classification_text_single_text_across_frames():
    """A single text repeated over one segment yields one answer entry."""
    feature_name = "free_text_per_frame"
    label = Label(
        data=GenericDataRowData(global_key="sample-video-single-text"),
        annotations=[
            VideoClassificationAnnotation(
                name=feature_name,
                frame=frame,
                segment_index=0,
                value=Text(answer="sample text"),
            )
            for frame in (9, 15)
        ],
    )

    rows = [
        row
        for row in NDJsonConverter.serialize([label])
        if row.get("name") == feature_name
    ]
    assert len(rows) == 1

    row = rows[0]
    assert row["dataRow"] == {"globalKey": "sample-video-single-text"}
    answer = row["answer"]
    assert isinstance(answer, list)
    assert len(answer) == 1

    only_entry = answer[0]
    assert only_entry["value"] == "sample text"
    assert only_entry["frames"] == [{"start": 9, "end": 15}]
723+
724+
638725
def test_video_classification_nesting_bbox():
639726
bbox_annotation = [
640727
VideoObjectAnnotation(

0 commit comments

Comments
 (0)