MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326
Conversation
johnzielke force-pushed from 81865e4 to 8888a48 (Signed-off-by: John Zielke <[email protected]>)
Hi @johnzielke, could you please help resolve the conflict? Then I can help trigger the blossom, thanks!
/build
Hi @dongyang0122, could you please help check whether this change makes sense to you? Just want to confirm there is no other specific concern regarding this.
```python
        self.print_info = print_info
        self.save_mem = save_mem

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.print_info:
            logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

        target_dtype = input.dtype
```
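The idea behind the `target_dtype = input.dtype` line can be sketched as follows. This is a hypothetical, simplified illustration (not the PR's exact code, and `AutoCastGroupNorm` is an invented name): the norm infers its output dtype from the input instead of relying on a `norm_float16` flag.

```python
import torch
import torch.nn as nn


class AutoCastGroupNorm(nn.GroupNorm):
    """Hypothetical sketch: run the normalization in the parameters' dtype,
    then cast the result back to the input's dtype (inferred, not configured)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        target_dtype = input.dtype                        # inferred from the input
        out = super().forward(input.to(self.weight.dtype))
        return out.to(dtype=target_dtype)


norm = AutoCastGroupNorm(num_groups=4, num_channels=8)    # parameters stay float32
x = torch.randn(2, 8, 3, 3, 3, dtype=torch.bfloat16)
y = norm(x)                                               # output follows input dtype
```

With this pattern the caller picks the dtype simply by casting the input, which is what makes bfloat16 (and other dtypes) work without a dedicated flag.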
What if the input is float32 but users want to convert the output to float16?
This change only affects the group norm and makes its behavior consistent with the rest of the model.
If I understand you correctly, the common pattern to achieve what you want would be:

```python
model = model.to(dtype=torch.bfloat16)
prediction = model(x.to(dtype=torch.bfloat16))
```

Or are you referring to something else?
Without this change, the parameter needs to be manually adjusted to produce a tensor that's compatible with the rest of the model. In addition, it can only be float32 or float16. Is there a reason one would want the GroupNorm in a different dtype than the rest of the model?
/build
Hi, are there any updates on this PR?
Hello @johnzielke, sorry for the late response. I discussed offline with @dongyang0122. The addition of
Thank you for the info. It is surprising to me that this deceptively simple operation would be a problem during inference compared to attention or other operations. Without going into details on why the group norm is so memory-intensive, my other suggestion is to make the norm_float16 parameter accept a dtype as well, so that it can use float16, bfloat16, or any of the other dtypes. The current implementation makes it impossible (or at least hard) to use this model with any of the new high-performance dtypes introduced on recent Nvidia GPUs, which is very unfortunate in my opinion. This is how I stumbled on this problem, as the documentation does not mention that this use case is unsupported.
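The suggestion above (letting the parameter accept a `torch.dtype` in addition to a bool) could be resolved with a small helper along these lines. This is a hypothetical sketch, and `resolve_norm_dtype` is an invented name; the actual PR may structure it differently:

```python
import torch


def resolve_norm_dtype(norm_float16, input_dtype):
    """Hypothetical backward-compatible interpretation of norm_float16:
    - a torch.dtype: use it directly (proposed extension)
    - True: torch.float16 (legacy behavior)
    - False/None: follow the input dtype
    """
    if isinstance(norm_float16, torch.dtype):
        return norm_float16
    if norm_float16:
        return torch.float16
    return input_dtype
```

This keeps existing callers passing `True`/`False` working while opening the door to `norm_float16=torch.bfloat16` and future dtypes.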
Yes, sure. It totally makes sense and makes it easier to be compatible with more use cases. Could you please modify it in this PR and keep it backward compatible with the current use case? Thanks in advance!
Yes. I'm not going to change the naming then, which means the parameter name will be confusing, but I think it keeps the amount of change and deprecation to a minimum. I suggest:
For my education: why does the current implementation require so much less memory than a simple:
The mean and std tensors should be quite small, and the resulting tensor is of the same size, so why is the loop implementation more memory-efficient?
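To make the question concrete, here is one plausible shape a loop-based group norm can take. This is an assumed, simplified sketch (no affine parameters, `group_norm_chunked` is an invented name) of how normalizing one group at a time, in place, can keep peak memory down: a single-expression `(x - mean) / std` over the whole tensor materializes full-size intermediates, while the loop's only temporaries are per-group statistics.

```python
import torch


def group_norm_chunked(x: torch.Tensor, num_groups: int, eps: float = 1e-5) -> torch.Tensor:
    """Sketch of a memory-lean group norm: normalize group by group in place,
    so no full-size intermediate tensor is ever allocated."""
    g = x.view(x.shape[0], num_groups, -1)     # (N, G, C/G * spatial), a view of x
    for i in range(num_groups):
        chunk = g[:, i]
        mean = chunk.mean(dim=-1, keepdim=True)
        std = chunk.var(dim=-1, unbiased=False, keepdim=True).add_(eps).sqrt_()
        chunk.sub_(mean).div_(std)             # in-place: only per-group stats allocated
    return x


x = torch.randn(2, 8, 4, 4)
expected = torch.nn.functional.group_norm(x.clone(), num_groups=4)
out = group_norm_chunked(x.clone(), num_groups=4)
```

Whether this matches the actual MAISI implementation's rationale is exactly what the question above is asking; the savings here come purely from avoiding full-tensor temporaries, not from changing the math.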
Description
The current MAISI VAE encoder only supports float32 and float16, via a parameter that must be set manually. This PR instead infers the norm dtype from the input, and should therefore enable the use of other dtypes such as bfloat16.
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder.