1414limitations under the License.
1515"""
1616import re
17+ import os
1718
1819from ...representation import CharacterRecognitionPrediction
1920from ...utils import UnsupportedPackage , extract_image_representations
@@ -121,7 +122,7 @@ def _initialize_pipeline(self, config):
121122 except ImportError as import_error :
122123 UnsupportedPackage ("openvino_genai" , import_error .msg ).raise_error (self .__class__ .__name__ )
123124
124- model_dir = config . get ( "_models" , [ None ])[ 0 ]
125+ model_dir = get_model_dir ( config )
125126 device = config .get ("_device" , "CPU" )
126127 pipeline = ov_genai .WhisperPipeline (str (model_dir ), device = device )
127128 return pipeline
@@ -169,7 +170,7 @@ def _initialize_pipeline(self, config):
169170 UnsupportedPackage ("optimum.intel.openvino" , import_error .msg ).raise_error (self .__class__ .__name__ )
170171
171172 device = config .get ("_device" , "CPU" )
172- model_dir = config . get ( "_models" , [ None ])[ 0 ]
173+ model_dir = get_model_dir ( config )
173174 ov_model = OVModelForSpeechSeq2Seq .from_pretrained (str (model_dir )).to (device )
174175 ov_processor = AutoProcessor .from_pretrained (str (model_dir ))
175176
@@ -184,3 +185,12 @@ def _get_predictions(self, data, identifiers, input_meta):
184185 sampling_rate = input_meta [0 ].get ("sample_rate" )
185186 sample = {"path" : identifiers [0 ], "array" : data [0 ], "sampling_rate" : sampling_rate }
186187 return self .pipeline (sample , return_timestamps = True )["text" ]
188+
189+
190+
191+ def get_model_dir (config ):
192+ model_path = config .get ("_models" , [None ])[0 ]
193+
194+ if os .path .isfile (model_path ):
195+ return os .path .dirname (model_path )
196+ return model_path
0 commit comments