11
22from typing import Type , Dict , Optional , List , Tuple , Any , Union
3- from pydantic import BaseModel , confloat
3+ from pydantic import BaseModel , confloat , Field
4+ from label_studio_sdk .label_interface .objects import PredictionValue
5+ from typing import Union , List
46
5- from label_studio_sdk .objects import PredictionValue
7+
8+ # one or multiple predictions per task
9+ SingleTaskPredictions = Union [List [PredictionValue ], PredictionValue ]
610
711
812class ModelResponse (BaseModel ):
913 """
1014 """
1115 model_version : Optional [str ] = None
12- predictions : List [PredictionValue ]
16+ predictions : List [SingleTaskPredictions ]
1317
1418 def has_model_version (self ) -> bool :
1519 return bool (self .model_version )
@@ -18,21 +22,16 @@ def update_predictions_version(self) -> None:
1822 """
1923 """
2024 for prediction in self .predictions :
21- if not prediction .model_version :
22- prediction .model_version = self .model_version
25+ if isinstance (prediction , PredictionValue ):
26+ prediction = [prediction ]
27+ for p in prediction :
28+ if not p .model_version :
29+ p .model_version = self .model_version
2330
2431 def set_version (self , version : str ) -> None :
2532 """
2633 """
2734 self .model_version = version
2835 # Set the version for each prediction
2936 self .update_predictions_version ()
30-
31- def serialize (self ):
32- """
33- """
34- return {
35- "model_version" : self .model_version ,
36- "predictions" : [ p .serialize () for p in self .predictions ]
37- }
3837
0 commit comments