Propose fix perceptual loss sqrt nan #8414
base: dev
Review comment: This file should go into an appropriate subdirectory in the …
Reply: Roger.
@@ -0,0 +1,65 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
import torch.optim as optim
from parameterized import parameterized

from monai.losses.perceptual import normalize_tensor
from monai.utils import set_determinism


class TestNormalizeTensorStability(unittest.TestCase):
    def setUp(self):
        set_determinism(seed=0)
        self.addCleanup(set_determinism, None)

    def tearDown(self):
        set_determinism(None)

    @parameterized.expand([["e-3", 1e-3], ["e-6", 1e-6], ["e-9", 1e-9], ["e-12", 1e-12]])  # Small values
    def test_normalize_tensor_stability(self, name, scale):
        """Test that small values don't produce NaNs + are handled gracefully."""
        # Create tensor
        x = torch.zeros(2, 3, 10, 10, requires_grad=True)

        optimizer = optim.Adam([x], lr=0.01)
Review comment: I don't think the optimizer is needed for this test?
Reply: Not needed, will remove.
        x_scaled = x * scale
Review comment: Since …
Reply: Will add a comment.
Review comment: I don't understand the point of this test with regard to the next one; instead of a zeros tensor, couldn't it be a random one which is then multiplied by a really small number?
        normalized = normalize_tensor(x_scaled)

        # Compute to force backward pass
        loss = normalized.sum()

        # this is where it failed before
        loss.backward()

        # Check for NaNs in gradients
        self.assertFalse(torch.isnan(x.grad).any(), f"NaN gradients detected with scale {scale:.10e}")

    def test_normalize_tensor_zero_input(self):
        """Test that normalize_tensor handles zero inputs gracefully."""
        # Create tensor with zeros
        x = torch.zeros(2, 3, 4, 4, requires_grad=True)

        normalized = normalize_tensor(x)
        loss = normalized.sum()
        loss.backward()

        # Check for NaNs in gradients
        self.assertFalse(torch.isnan(x.grad).any(), "NaN gradients detected with zero input")


if __name__ == "__main__":
    unittest.main()
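
For reference, the failure mode these tests guard against can be reproduced standalone. The snippet below is a hedged sketch that assumes an LPIPS-style normalization with eps added only to the denominator (a hypothetical stand-in named normalize_no_eps_in_sqrt, not necessarily the literal MONAI implementation):

# Minimal reproduction of the sqrt-NaN issue (hypothetical sketch, not the
# actual monai.losses.perceptual.normalize_tensor). With an all-zero input the
# forward pass stays finite, but d(sqrt(s))/ds is infinite at s = 0, so the
# chain rule produces inf * 0 = NaN during backpropagation.
import torch

def normalize_no_eps_in_sqrt(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))  # 0 for zero input
    return x / (norm_factor + eps)

x = torch.zeros(2, 3, 4, 4, requires_grad=True)
normalize_no_eps_in_sqrt(x).sum().backward()
print(torch.isnan(x.grad).any())  # tensor(True): NaN gradients despite a finite forward pass

The forward result is finite, but the backward pass yields NaN gradients, which is exactly what the assertions above check for.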
Review comment: Do we want to remove eps from the denominator? As proposed, eps will contribute twice to the final result.
Reply: Agreed. Will remove.
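
To make the concern concrete, here is a hedged sketch of the two eps placements under discussion (illustrative function bodies, not the literal MONAI code): with eps both under the sqrt and in the denominator it contributes twice to the result, whereas keeping it only under the sqrt already makes the denominator strictly positive and keeps gradients finite for zero input.

# Illustrative comparison of eps placements (hypothetical sketch).
import torch

def normalize_eps_twice(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
    return x / (norm_factor + eps)  # eps contributes a second time here

def normalize_eps_in_sqrt_only(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
    return x / norm_factor  # sqrt(eps) > 0, so no division by zero and no NaN gradient

x = torch.zeros(2, 3, 4, 4, requires_grad=True)
normalize_eps_in_sqrt_only(x).sum().backward()
print(torch.isnan(x.grad).any())  # tensor(False): gradients stay finite for zero input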