Issue:
In the `predict_masks` method of the `MaskDecoder` class, this change improves how tensor batch dimensions are handled. Details:
- Conditional check:
  - A new check, `if image_embeddings.shape[0] != tokens.shape[0]:`, has been added to verify that the batch dimensions agree before `torch.repeat_interleave` is applied.
- Usage of `torch.repeat_interleave`:
  - When the check fails, `torch.repeat_interleave` expands the `image_embeddings` tensor along the batch dimension so that its batch size aligns with `tokens`.
- Ensuring consistency:
  - Because of this check, `torch.repeat_interleave` is applied only when necessary, which keeps tensor handling within `predict_masks` consistent. In the original implementation, `torch.repeat_interleave` is applied unconditionally, so an `image_embeddings` tensor whose batch size already matches `tokens` could be expanded anyway and end up with a batch size that no longer matches.
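A minimal sketch of the guarded expansion, assuming PyTorch. The helper name `expand_image_embeddings` and the tensor sizes are hypothetical, chosen for illustration rather than taken from `MaskDecoder`. Like the check described in the issue, it assumes `image_embeddings` either has batch size 1 (one image shared by every prompt) or already matches `tokens`.

```python
import torch


def expand_image_embeddings(image_embeddings: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the conditional described in the issue.
    # Assumes image_embeddings has batch size 1 or already matches tokens.shape[0].
    if image_embeddings.shape[0] != tokens.shape[0]:
        # Copy the embedding once per prompt along the batch dimension (dim 0).
        return torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # Batch sizes already agree, so the tensor is returned unchanged.
    return image_embeddings


# Illustrative sizes only (not SAM's real ones): 3 prompts, 8 channels, a 4x4 grid.
num_prompts = 3
tokens = torch.zeros(num_prompts, 7, 8)          # (batch, num_tokens, channels)
shared = torch.randn(1, 8, 4, 4)                 # one embedding for all prompts
per_prompt = torch.randn(num_prompts, 8, 4, 4)   # already one embedding per prompt

print(expand_image_embeddings(shared, tokens).shape)      # torch.Size([3, 8, 4, 4])
print(expand_image_embeddings(per_prompt, tokens).shape)  # torch.Size([3, 8, 4, 4])

# Unconditional expansion over-expands a batch that already matched:
print(torch.repeat_interleave(per_prompt, tokens.shape[0], dim=0).shape)  # torch.Size([9, 8, 4, 4])
```

The last line shows why the guard matters: applied unconditionally, `repeat_interleave` turns a batch of 3 that already matched `tokens` into a batch of 9.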