
MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326


Open
wants to merge 2 commits into dev from bugfix/maisi-vae-autoencoder-fix-dtype

Conversation

johnzielke
Contributor

Description

The current MAISI VAE encoder only supports float32 and float16, selected via a parameter that has to be set manually. This PR instead infers the norm datatype from the input and should therefore enable the use of other datatypes.
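
A minimal sketch of the idea (illustrative only; the class name and the float32 intermediate are assumptions, not necessarily what the PR implements):

  import torch
  from torch import nn

  class GroupNormMatchInput(nn.GroupNorm):
      # Illustrative only: instead of a norm_float16 flag, cast the normalized
      # output back to whatever dtype the input tensor carries.
      def forward(self, input: torch.Tensor) -> torch.Tensor:
          target_dtype = input.dtype  # inferred from the input, no manual flag
          # computing the statistics in float32 for stability is an assumption here
          out = nn.functional.group_norm(
              input.float(),
              self.num_groups,
              self.weight.float() if self.weight is not None else None,
              self.bias.float() if self.bias is not None else None,
              self.eps,
          )
          return out.to(target_dtype)

With this pattern, converting the model via vae_model.to(torch.bfloat16) and feeding a bfloat16 input yields bfloat16 outputs from the norm, so the subsequent layers see a matching dtype.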

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@johnzielke johnzielke force-pushed the bugfix/maisi-vae-autoencoder-fix-dtype branch from 81865e4 to 8888a48 Compare February 4, 2025 15:55
@KumoLiu
Contributor

KumoLiu commented Apr 13, 2025

Hi @johnzielke, could you please help resolve the conflict? Then I can help trigger the blossom, thanks!

@KumoLiu
Contributor

KumoLiu commented Apr 21, 2025

/build

@KumoLiu
Contributor

KumoLiu commented Apr 21, 2025

Hi @dongyang0122, could you please check whether this change makes sense to you? Just want to confirm there is no other specific concern regarding this norm_float16 param.

        self.print_info = print_info
        self.save_mem = save_mem

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

        target_dtype = input.dtype
Contributor

What if the input is float32 but users want to convert the output to float16?

Contributor Author

This change only affects the group norm and makes the behavior in line with the rest of the model.
If I understand you correctly, to achieve what you want, the common pattern would be:

model = model.to(dtype=torch.bfloat16)
prediction = model(x.to(dtype=torch.bfloat16))

Or are you referring to something else?

Contributor Author

@johnzielke johnzielke Apr 21, 2025

Without this change, the parameter needs to be manually adjusted to produce a tensor that's compatible with the rest of the model. In addition, it could only be float32 or float16. Is there a reason one would want to have the GroupNorm in a different datatype than the rest of the model?

@KumoLiu
Contributor

KumoLiu commented Apr 22, 2025

/build

@johnzielke
Contributor Author

Hi, are there any updates on this PR?

@KumoLiu
Contributor

KumoLiu commented Apr 30, 2025

Hi, are there any updates on this PR?

Hello @johnzielke, sorry for the late response.

I discussed offline with @dongyang0122. The addition of norm_float16 is intended to prevent out-of-memory (OOM) issues during inference, as this has been a bottleneck for MAISI. The rationale for not using float16 previously was to avoid affecting the precision of other layers and thus prevent truncation errors. Therefore, I suggest that we retain this argument.
What do you think?

@johnzielke
Contributor Author

Thank you for the info. It is surprising to me that this deceptively simple operation would be a problem during inference compared to attention or other operations.
But setting that aside, it brings up some other questions:
I think the reason I created this PR is that the group norm returns either float16 or float32, and if that differs from the rest of the network, I recall it creates a problem with the convolutions later on because of the datatype mismatch (this is after converting the model to a different datatype such as bfloat16 using vae_model.to(torch.bfloat16)).

Without going into the details of why the group norm is so memory-intensive, my other suggestion is to make the norm_float16 parameter accept a dtype as well, so that it can use float16, bfloat16, or any of the other dtypes.

The current implementation makes it impossible, or at least hard, to use this model with any of the new high-performance datatypes introduced on recent Nvidia GPUs, which is very unfortunate in my opinion. This is how I stumbled on this problem, as the documentation does not mention that this use case is unsupported.

@KumoLiu
Contributor

KumoLiu commented Apr 30, 2025

my other suggestion is to make the norm_float16 parameter accept a dtype as well, so that you can have it use float16, bfloat16 or any of the other dtypes.

Yes, sure. That totally makes sense and makes it easier to support more use cases. Could you please modify it in this PR and keep it backward compatible with the current use case? Thanks in advance!

@johnzielke
Contributor Author

johnzielke commented Apr 30, 2025

Yes. I'm not going to change the naming then, which means the parameter name will be a bit confusing, but I think it keeps the amount of change and deprecation to a minimum. I suggest (see the sketch below):
norm_float16: bool | str | None
  • bool: current behavior
  • str: torch dtype to convert to
  • None: the behavior suggested initially in this PR, where the norm uses whatever dtype the input tensor has
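
A minimal sketch of how that parameter could be resolved into a dtype (the helper name resolve_norm_dtype is hypothetical, not part of the PR):

  import torch

  def resolve_norm_dtype(norm_float16: bool | str | None, input_dtype: torch.dtype) -> torch.dtype:
      # hypothetical helper illustrating the proposal above
      if norm_float16 is None:
          return input_dtype  # follow the input tensor's dtype
      if isinstance(norm_float16, bool):
          return torch.float16 if norm_float16 else torch.float32  # current behavior
      return getattr(torch, norm_float16)  # e.g. "bfloat16" -> torch.bfloat16

For example, resolve_norm_dtype("bfloat16", x.dtype) would return torch.bfloat16, while resolve_norm_dtype(True, x.dtype) keeps the current float16 behavior.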

@johnzielke
Contributor Author

For my education: Why does the current implementation require so much less memory than a simple:

  # naive group norm for comparison: reshape into groups, normalize, reshape back
  param_n, param_c, param_d, param_h, param_w = input.shape
  input_g = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)

  mean = input_g.mean([2, 3, 4, 5], keepdim=True)
  std = input_g.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
  return ((input_g - mean) / std).view(param_n, param_c, param_d, param_h, param_w)

The mean and std tensors should be quite small, and the resulting tensor is of the same size, so why is the loop implementation more memory-efficient?
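
For comparison, here is the kind of per-group, in-place loop I would expect to save some memory (a hedged sketch, not the actual MAISI implementation):

  import torch

  def group_norm_looped(input: torch.Tensor, num_groups: int, eps: float = 1e-5) -> torch.Tensor:
      # hedged sketch, not the MAISI code: normalize one group at a time with
      # in-place ops, so only one group's intermediates are alive at any moment
      out = input.clone()  # input shape: (n, c, d, h, w)
      for chunk in out.chunk(num_groups, dim=1):  # each chunk is a view into `out`
          mean = chunk.mean(dim=(1, 2, 3, 4), keepdim=True)
          std = chunk.var(dim=(1, 2, 3, 4), unbiased=False, keepdim=True).add_(eps).sqrt_()
          chunk.sub_(mean).div_(std)
      return out

This avoids materializing a full-size (input_g - mean) temporary all at once, but I am not sure that alone explains the difference.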
