@@ -454,23 +454,73 @@ def build_image_encoder_engine(self,
 
         return self.load_image_encoder_engine(engine_path, max_batch_size)
 
+
     def predict(self,
-            image: PIL.Image,
-            text: List[str],
-            text_encodings: Optional[OwlEncodeTextOutput],
-            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-            pad_square: bool = True,
-
-        ) -> OwlDecodeOutput:
+            image: Union[PIL.Image.Image, torch.Tensor],
+            text: List[str],
+            text_encodings: Optional[OwlEncodeTextOutput],
+            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+            pad_square: bool = True,
+        ) -> OwlDecodeOutput:
+
+        if isinstance(image, PIL.Image.Image):
+            image_tensor = self.image_preprocessor.preprocess_pil_image(image)
 
-        image_tensor = self.image_preprocessor.preprocess_pil_image(image)
+            rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+        elif isinstance(image, torch.Tensor):
+            image_tensor = self.image_preprocessor.preprocess_tensor_image(image)
 
+            rois = torch.tensor([[0, 0, image.shape[1], image.shape[0]]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+        else:
+            raise ValueError("Input image must be either a PIL Image or a torch.Tensor")
+
 
         if text_encodings is None:
             text_encodings = self.encode_text(text)
+
+        image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+
+        return self.decode(image_encodings, text_encodings, threshold)
+
+    # def predict(self,
+    #         image: PIL.Image,
+    #         text: List[str],
+    #         text_encodings: Optional[OwlEncodeTextOutput],
+    #         threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+    #         pad_square: bool = True,
+
+    # ) -> OwlDecodeOutput:
 
-        rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
+    #     image_tensor = self.image_preprocessor.preprocess_pil_image(image)
+    #     print(image_tensor)
+    #     if text_encodings is None:
+    #         text_encodings = self.encode_text(text)
 
-        image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+    #     rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
 
-        return self.decode(image_encodings, text_encodings, threshold)
+    #     image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
+    #     return self.decode(image_encodings, text_encodings, threshold)
+
+    # def predictTensor(self,
+    #         image: torch.Tensor,
+    #         text: List[str],
+    #         text_encodings: Optional[OwlEncodeTextOutput],
+    #         threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+    #         pad_square: bool = True,
+
+    # ) -> OwlDecodeOutput:
+
+    #     image_tensor = self.image_preprocessor.preprocess_tensor_image(image)
+
+    #     if text_encodings is None:
+    #         text_encodings = self.encode_text(text)
+    #     print(image_tensor)
+    #     #print(image.shape[1])
+    #     rois = torch.tensor([[0, 0, image.shape[1], image.shape[0]]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+    #     image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+
+    #     return self.decode(image_encodings, text_encodings, threshold)
+
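What follows is a minimal usage sketch of the updated predict(), showing both accepted input types. It is not part of the diff: the import path, the model string, and the stand-in inputs are assumptions for illustration, and the (H, W, C) tensor layout is inferred from the ROI computed above as [0, 0, image.shape[1], image.shape[0]].

# Usage sketch; nanoowl.owl_predictor.OwlPredictor and the
# "google/owlvit-base-patch32" model string are assumed names,
# and the inputs below are synthetic stand-ins.
import PIL.Image
import torch

from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor("google/owlvit-base-patch32")

text = ["an owl", "a glove"]
text_encodings = predictor.encode_text(text)  # encode prompts once, reuse per frame

# PIL path: behavior unchanged from before this commit.
pil_image = PIL.Image.new("RGB", (640, 480))  # stand-in for a real image
detections = predictor.predict(pil_image, text, text_encodings, threshold=0.1)

# Tensor path: new in this commit. The ROI [0, 0, shape[1], shape[0]] implies
# an (H, W, C) frame; dtype and value-range expectations depend on
# preprocess_tensor_image, which this diff does not show.
tensor_image = torch.zeros(480, 640, 3)  # stand-in HWC frame
detections = predictor.predict(tensor_image, text, text_encodings, threshold=0.1)

Dispatching on isinstance inside a single predict() keeps one public entry point, which is presumably why the separate predictTensor() variant is left commented out above rather than kept as a live API.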