Skip to content

Conversation

@kiryteo
Copy link

@kiryteo kiryteo commented May 23, 2025

What does this PR do?

  • Added maxpool_resize transform to obtain maxpool-like functionality when resizing inputs.

MONAI provides resize functionality via monai.transforms.Resized which uses torch.nn.functional.interpolate along with one of the available interpolation modes. None of the modes offer the maxpool-like functionality and thus we need a custom transform. No additional dependencies required. This is an optional transform to be used with config and does not introduce any breaking changes.

Before submitting

  • Did you make sure title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with pytest command?
  • Did you run pre-commit hooks with pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃

Copy link
Contributor

@benjijamorris benjijamorris left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! My one suggestion would be to add an explicit spatial_dims argument. I think this would take the guesswork out of trying to find the spatial dimensions and make the code simpler but I can be convinced otherwise!

raise TypeError(f"Input '{key}' must be a PyTorch tensor, got {type(x)}")

# Determine expected tensor dimensions and spatial size length
input_dims = x.dim()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Monai transforms expect C[Z]YX images - would it be reasonable to enforce this based on the spatial size and then just add the batch dimension right before pooling? I often have an explicit spatial_dims argument to help with this kind of check


# Normalize spatial_size to match expected number of spatial dimensions
try:
spatial_size = ensure_tuple_rep(self.spatial_size, expected_spatial_dims)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding an explicit spatial_dims argument here would let you find the spatial size once in the init


orig_size = x.shape[-expected_spatial_dims:]

# Replace non-positive spatial_size values with original dimensions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants