Commit ead1d3c

Compatible with DS
1 parent 320b659 commit ead1d3c

2 files changed: +77 −11 lines changed


nanoowl/image_preprocessor.py

Lines changed: 16 additions & 0 deletions
@@ -72,4 +72,20 @@ def preprocess_pil_image(self, image: PIL.Image.Image):
         image = image.permute(2, 0, 1)[None, ...]
         image = image.to(self.mean.device)
         image = image.type(self.mean.dtype)
         return self.forward(image, inplace=True)
+
+    @torch.no_grad()
+    def preprocess_tensor_image(self, image: torch.Tensor) -> torch.Tensor:
+        # The input image tensor is expected in (H, W, C) layout
+        assert image.dim() == 3, "Input image tensor must have 3 dimensions (H, W, C)"
+        assert image.size(2) == 3, "Input image tensor must have 3 channels (RGB)"
+        # Permute to the (N, C, H, W) layout the model expects
+        image = image.permute(2, 0, 1)[None, ...]
+        # Move the tensor to the same device as self.mean
+        image = image.to(self.mean.device)
+
+        # Match the data type of self.mean
+        image = image.type(self.mean.dtype)
+
+        # Normalize via the preprocessor's forward pass
+        return self.forward(image, inplace=True)
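
For reference, a minimal sketch (not part of the commit) of how the new tensor path might be exercised. The class and module names (ImagePreprocessor in nanoowl.image_preprocessor) and the default constructor are assumptions based on this file's context:

    import torch

    # Assumed import path for the class that defines preprocess_tensor_image above
    from nanoowl.image_preprocessor import ImagePreprocessor

    preprocessor = ImagePreprocessor().eval()

    # A dummy 480x640 RGB frame in (H, W, C) uint8 layout, e.g. a camera frame
    # after cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) and torch.from_numpy(...)
    frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)

    # Asserts the (H, W, C) layout, permutes to (1, 3, 480, 640), moves and
    # casts the tensor to match self.mean, then normalizes in place
    tensor = preprocessor.preprocess_tensor_image(frame)
    print(tensor.shape)  # torch.Size([1, 3, 480, 640])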

nanoowl/owl_predictor.py

Lines changed: 61 additions & 11 deletions
@@ -454,23 +454,73 @@ def build_image_encoder_engine(self,
 
         return self.load_image_encoder_engine(engine_path, max_batch_size)
 
+
     def predict(self,
-            image: PIL.Image,
-            text: List[str],
-            text_encodings: Optional[OwlEncodeTextOutput],
-            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
-            pad_square: bool = True,
-
-        ) -> OwlDecodeOutput:
+            image: Union[PIL.Image.Image, torch.Tensor],
+            text: List[str],
+            text_encodings: Optional[OwlEncodeTextOutput],
+            threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+            pad_square: bool = True,
+        ) -> OwlDecodeOutput:
+
+        if isinstance(image, PIL.Image.Image):
+            image_tensor = self.image_preprocessor.preprocess_pil_image(image)
 
-        image_tensor = self.image_preprocessor.preprocess_pil_image(image)
+            rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+        elif isinstance(image, torch.Tensor):
+            image_tensor = self.image_preprocessor.preprocess_tensor_image(image)
 
+            rois = torch.tensor([[0, 0, image.shape[1], image.shape[0]]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+        else:
+            raise ValueError("Input image must be either a PIL Image or a torch.Tensor")
+
         if text_encodings is None:
             text_encodings = self.encode_text(text)
+
+        image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+
+        return self.decode(image_encodings, text_encodings, threshold)
+
+    # def predict(self,
+    #         image: PIL.Image,
+    #         text: List[str],
+    #         text_encodings: Optional[OwlEncodeTextOutput],
+    #         threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+    #         pad_square: bool = True,
+
+    #     ) -> OwlDecodeOutput:
 
-        rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
+    #     image_tensor = self.image_preprocessor.preprocess_pil_image(image)
+    #     print(image_tensor)
+    #     if text_encodings is None:
+    #         text_encodings = self.encode_text(text)
 
-        image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+    #     rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
 
-        return self.decode(image_encodings, text_encodings, threshold)
+    #     image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
 
+    #     return self.decode(image_encodings, text_encodings, threshold)
+
+    # def predictTensor(self,
+    #         image: torch.Tensor,
+    #         text: List[str],
+    #         text_encodings: Optional[OwlEncodeTextOutput],
+    #         threshold: Union[int, float, List[Union[int, float]]] = 0.1,
+    #         pad_square: bool = True,
+
+    #     ) -> OwlDecodeOutput:
+
+    #     image_tensor = self.image_preprocessor.preprocess_tensor_image(image)
+
+    #     if text_encodings is None:
+    #         text_encodings = self.encode_text(text)
+    #     print(image_tensor)
+    #     # print(image.shape[1])
+    #     rois = torch.tensor([[0, 0, image.shape[1], image.shape[0]]], dtype=image_tensor.dtype, device=image_tensor.device)
+
+    #     image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
+
+    #     return self.decode(image_encodings, text_encodings, threshold)
+
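
A minimal sketch (not part of the commit) of driving the unified predict() with both input types; it assumes OwlPredictor is the class defined in this file and that its defaults (model checkpoint, device) are usable as-is, and "image.jpg" is a placeholder path:

    import PIL.Image
    import torch

    from nanoowl.owl_predictor import OwlPredictor  # assumed import path

    predictor = OwlPredictor()  # may require a CUDA device, depending on defaults

    text = ["a person", "a dog"]
    # Encode once and reuse across calls; predict() only encodes when
    # text_encodings is None
    text_encodings = predictor.encode_text(text)

    # PIL path: goes through preprocess_pil_image; ROI is [0, 0, width, height]
    pil_out = predictor.predict(PIL.Image.open("image.jpg"), text, text_encodings)

    # Tensor path: an (H, W, C) RGB tensor; ROI is [0, 0, W, H], taken from
    # image.shape[1] (width) and image.shape[0] (height)
    frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
    tensor_out = predictor.predict(frame, text, text_encodings)

In either branch the ROI spans the full image, so encode_rois sees the whole frame and pad_square behaves as before.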
