From 1c848d3544e87464743fb18297ecf5ed58ec8b55 Mon Sep 17 00:00:00 2001 From: ALi Date: Mon, 30 May 2022 19:18:34 +0200 Subject: [PATCH] Upgrade deprecated torch.svd() is deprecated in favor of torch.linalg.svd() see https://pytorch.org/docs/stable/generated/torch.svd.html#torch-svd --- torch_optimizer/shampoo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_optimizer/shampoo.py b/torch_optimizer/shampoo.py index fc2706d..88e7765 100644 --- a/torch_optimizer/shampoo.py +++ b/torch_optimizer/shampoo.py @@ -8,7 +8,8 @@ def _matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor: # use CPU for svd for speed up device = matrix.device matrix = matrix.cpu() - u, s, v = torch.svd(matrix) + u, s, vh = torch.linalg.svd(matrix, full_matrices=False) + v = vh.mH return (u @ s.pow_(power).diag() @ v.t()).to(device)