-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathmatrix_functions_types.py
148 lines (103 loc) · 6.39 KB
/
matrix_functions_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from dataclasses import dataclass, field
import torch
from commons import AbstractDataclass
@dataclass(init=False)
class MatrixFunctionConfig(AbstractDataclass):
    """Common ancestor of all matrix function configuration dataclasses.

    Declares no fields of its own. Because of `init=False`, the dataclass decorator does not generate an
    `__init__` for this class; whatever `AbstractDataclass` provides is inherited instead.
    """
@dataclass(init=False)
class EigendecompositionConfig(MatrixFunctionConfig):
    """Common ancestor of the eigenvalue decomposition configurations.

    Declares no fields of its own. Because of `init=False`, the dataclass decorator does not generate an
    `__init__` for this class.
    """
@dataclass(kw_only=True)
class EighEigendecompositionConfig(EigendecompositionConfig):
    """Configuration for eigendecomposition with torch.linalg.eigh.

    All fields are keyword-only.

    Attributes:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a CuSOLVER failure. (Default: True)
        eigendecomposition_offload_device (torch.device | str): Device to offload the eigendecomposition to. If the
            value is the empty string, no offloading is performed. (Default: "")
    """

    retry_double_precision: bool = True
    eigendecomposition_offload_device: torch.device | str = ""

    def __post_init__(self) -> None:
        """Turn a non-empty offload device specification into a torch.device."""
        # The empty string means "no offloading" and is stored unchanged.
        if self.eigendecomposition_offload_device == "":
            return

        # Building the torch.device right away verifies early that the given string is a valid device string.
        offload_device = torch.device(self.eigendecomposition_offload_device)
        self.eigendecomposition_offload_device = offload_device
# Default eigendecomposition configuration: an EighEigendecompositionConfig with every field left at its default.
# This is a single module-level instance created at import time.
DefaultEigendecompositionConfig = EighEigendecompositionConfig()
@dataclass(kw_only=True)
class QREigendecompositionConfig(EigendecompositionConfig):
    """Configuration for eigenvalue decomposition via QR algorithm.

    Convergence of the QR algorithm is determined from the estimated eigenvalues Q^T A Q =: B, where Q holds the
    last computed eigenvectors and A is the current Kronecker factor.
    The convergence criterion based on the estimated eigenvalues is defined as ||B - diag(B)||_F <= tolerance * ||B||_F.
    The tolerance hyperparameter should therefore be in the interval [0.0, 1.0].
    Note that if the criterion is already satisfied given the initial eigenvectors_estimate, the QR iterations will be skipped.

    This convergence criterion can be motivated by considering A' = Q diag(B) Q^T as an approximation of A.
    We have ||A - A'||_F = ||A - Q diag(B) Q^T||_F = ||Q^T A Q - diag(B)||_F = ||B - diag(B)||_F.
    Moreover, we have ||B||_F = ||Q^T A Q||_F = ||A||_F.
    Hence, the two relative errors are also equivalent: ||A - A'||_F / ||A||_F = ||B - diag(B)||_F / ||B||_F.

    All fields are keyword-only.

    Attributes:
        max_iterations (int): The maximum number of iterations to perform. (Default: 1)
        tolerance (float): The tolerance for determining convergence in terms of the norm of the off-diagonal elements
            of the eigenvalue estimate. (Default: 0.01)
        eigenvectors_estimate (Tensor): The current estimate of the eigenvectors. Cannot be set at initialization and
            has no default, so it must be assigned before it is read. It is excluded from `repr()` and from
            equality comparison.

    Raises:
        ValueError: If tolerance is not in the interval [0.0, 1.0].
    """

    max_iterations: int = 1
    tolerance: float = 0.01
    # This field is unset on a freshly constructed config. If it took part in the generated __repr__/__eq__, printing
    # or comparing such a config would raise AttributeError, and comparing configs holding tensors would try to
    # reduce an elementwise tensor comparison to a bool. Hence repr=False and compare=False.
    eigenvectors_estimate: torch.Tensor = field(
        init=False, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        """Validate that the tolerance lies in [0.0, 1.0]."""
        # Written as a negated chained comparison so that NaN is rejected as well.
        if not (0.0 <= self.tolerance <= 1.0):
            raise ValueError(
                f"Invalid tolerance value: {self.tolerance}. Must be in the interval [0.0, 1.0]."
            )
@dataclass(init=False)
class RootInvConfig(MatrixFunctionConfig):
    """Common ancestor of the matrix root inverse method configurations.

    Declares no fields of its own. Because of `init=False`, the dataclass decorator does not generate an
    `__init__` for this class.
    """
@dataclass(kw_only=True)
class EigenConfig(RootInvConfig, EighEigendecompositionConfig):
    """Configuration for matrix root inverse via an eigendecomposition.

    Extends EighEigendecompositionConfig, so its fields and its `__post_init__` device conversion are inherited.
    All fields are keyword-only.

    Attributes:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a CuSOLVER failure. (Default: True)
        eigendecomposition_offload_device (torch.device | str): Device to offload the eigendecomposition to. If the
            value is the empty string, no offloading is performed. (Default: "")
        exponent_multiplier (float): Number to be multiplied to the numerator of the inverse root, i.e., eta where the
            exponent is -eta / (2 * p). (Default: 1.0)
        enhance_stability (bool): Whether to enhance the stability of the root inverse computation through
            mathematically identical, but numerically more stable conditioning. (Default: False)
    """

    exponent_multiplier: float = 1.0
    enhance_stability: bool = False
# Default root inverse configuration: an EigenConfig with every field left at its default.
# This is a single module-level instance created at import time.
DefaultEigenConfig = EigenConfig()
@dataclass(kw_only=True)
class CoupledNewtonConfig(RootInvConfig):
    """Configuration for matrix root inverse via coupled Newton method.

    All fields are keyword-only.

    Attributes:
        max_iterations (int): Upper bound on the number of coupled Newton iterations. (Default: 100)
        tolerance (float): Tolerance used when computing the root inverse with the coupled Newton iteration.
            (Default: 1e-6)
    """

    max_iterations: int = 100
    tolerance: float = 1e-6
@dataclass(kw_only=True)
class CoupledHigherOrderConfig(RootInvConfig):
    """Configuration for matrix root inverse via coupled higher-order method.

    All fields are keyword-only. `rel_epsilon` and `abs_epsilon` have no default and must always be passed.

    Attributes:
        rel_epsilon (float): Relative epsilon for the coupled higher-order method. Adds epsilon * lambda_max * I to
            the matrix before taking the matrix root, where lambda_max is an upper bound on the maximum eigenvalue.
            (No default)
        abs_epsilon (float): Absolute epsilon for the coupled higher-order method. Adds epsilon * I to the matrix
            before taking the matrix root. When both "abs_epsilon" and "rel_epsilon" are specified,
            max(rel_epsilon * lambda_max, abs_epsilon) * I is added to the matrix. (No default)
        max_iterations (int): Upper bound on the number of iterations of the coupled higher-order method.
            (Default: 100)
        tolerance (float): Tolerance used when computing the root inverse with the coupled higher-order method.
            (Default: 1e-8)
        order (int): Order of the method. Order must be >= 2. Higher order methods accelerate convergence (fewer
            iterations), but can take more matmuls per iteration. order=2 represents Newton's method. (Default: 3)
        disable_tf32 (bool): Whether to disable tf32 matmuls internally. Keeping this True is highly recommended,
            since tf32 is numerically challenging here. (Default: True)
    """

    rel_epsilon: float
    abs_epsilon: float
    max_iterations: int = 100
    tolerance: float = 1e-8
    order: int = 3
    disable_tf32: bool = True