NVIDIAGH-439 added 2d SVD function
LeWerner42 authored and gdaviet committed Feb 4, 2025
1 parent e760065 commit 4e32db8
Showing 6 changed files with 267 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### Added

- Added support for vec4f grid construction in `wp.Volume.allocate_by_tiles()`
- Added 2D SVD `svd2` to support 2D simulations ([GH-436](https://github.com/NVIDIA/warp/issues/436)).

### Changed

6 changes: 6 additions & 0 deletions docs/modules/functions.rst
@@ -578,6 +578,12 @@ Vector Math
while the left and right basis vectors are returned in ``U`` and ``V``.


.. py:function:: svd2(A: Matrix[2,2,Float], U: Matrix[2,2,Float], sigma: Vector[2,Float], V: Matrix[2,2,Float]) -> None
Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
while the left and right basis vectors are returned in ``U`` and ``V``.
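
A minimal usage sketch (editorial, not part of this commit's docs; the kernel
name ``svd2_kernel`` is illustrative). Like ``svd3``, the builtin writes its
results through the output arguments and is called from inside a kernel::

    import warp as wp

    @wp.kernel
    def svd2_kernel(m: wp.array(dtype=wp.mat22),
                    u: wp.array(dtype=wp.mat22),
                    s: wp.array(dtype=wp.vec2),
                    v: wp.array(dtype=wp.mat22)):
        tid = wp.tid()
        U = wp.mat22()
        sigma = wp.vec2()
        V = wp.mat22()
        wp.svd2(m[tid], U, sigma, V)  # factor m[tid] as U * diag(sigma) * V^T
        u[tid] = U
        s[tid] = sigma
        v[tid] = V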


.. py:function:: qr3(A: Matrix[3,3,Float], Q: Matrix[3,3,Float], R: Matrix[3,3,Float]) -> None
Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
15 changes: 15 additions & 0 deletions warp/builtins.py
@@ -1132,6 +1132,21 @@ def matrix_transform_dispatch_func(input_types: Mapping[str, type], return_type:
while the left and right basis vectors are returned in ``U`` and ``V``.""",
)

add_builtin(
"svd2",
input_types={
"A": matrix(shape=(2, 2), dtype=Float),
"U": matrix(shape=(2, 2), dtype=Float),
"sigma": vector(length=2, dtype=Float),
"V": matrix(shape=(2, 2), dtype=Scalar),
},
value_type=None,
group="Vector Math",
export=False,
doc="""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
while the left and right basis vectors are returned in ``U`` and ``V``.""",
)
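
# Editorial note (not part of the original commit): value_type=None means
# svd2 returns no value; results are written through the U, sigma, and V
# arguments. export=False registers the builtin for kernel-scope use only,
# without generating a Python-scope (host) counterpart.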

add_builtin(
"qr3",
input_types={
116 changes: 116 additions & 0 deletions warp/native/svd.h
@@ -423,6 +423,62 @@ void _svd(// input A
);
}


template<typename Type>
inline CUDA_CALLABLE
void _svd_2(// input A
Type a11, Type a12,
Type a21, Type a22,
// output U
Type &u11, Type &u12,
Type &u21, Type &u22,
// output S
Type &s11, Type &s12,
Type &s21, Type &s22,
// output V
Type &v11, Type &v12,
Type &v21, Type &v22)
{
// Step 1: Compute ATA
Type ATA11 = a11 * a11 + a21 * a21;
Type ATA12 = a11 * a12 + a21 * a22;
Type ATA22 = a12 * a12 + a22 * a22;

// Step 2: Eigenanalysis
Type trace = ATA11 + ATA22;
Type det = ATA11 * ATA22 - ATA12 * ATA12;
    Type sqrt_term = sqrt(max(trace * trace - Type(4.0) * det, Type(0.0))); // exact value is (ATA11 - ATA22)^2 + 4*ATA12^2 >= 0; clamp guards against round-off
Type lambda1 = (trace + sqrt_term) * Type(0.5);
Type lambda2 = (trace - sqrt_term) * Type(0.5);
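
    // Editorial sketch of why this step is well posed: for B = A^T A = [[p, q], [q, r]],
    // trace(B) = p + r and det(B) = p*r - q*q, so the eigenvalues are
    //     lambda_{1,2} = (trace +/- sqrt(trace^2 - 4*det)) / 2,
    // and trace^2 - 4*det = (p - r)^2 + 4*q*q >= 0 in exact arithmetic (hence
    // the round-off clamp above). Since lambda1 >= lambda2 >= 0, the singular
    // values computed next are real and sorted in descending order.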

// Step 3: Singular values
Type sigma1 = sqrt(lambda1);
Type sigma2 = sqrt(lambda2);

    // Step 4: Eigenvectors of A^T A (columns of V)
    Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
    Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
    Type norm1 = sqrt(v1x * v1x + v1y * v1y);
    Type norm2 = sqrt(v2x * v2x + v2y * v2y);

    // If A^T A is already diagonal, the formulas above give zero vectors;
    // fall back to an orthonormal basis instead of dividing by zero.
    if (norm1 > Type(0.0)) { v11 = v1x / norm1; v21 = v1y / norm1; }
    else                   { v11 = Type(1.0);   v21 = Type(0.0);   }
    if (norm2 > Type(0.0)) { v12 = v2x / norm2; v22 = v2y / norm2; }
    else                   { v12 = -v21;        v22 = v11;         }

// Step 5: Compute U
Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);

u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
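
    // Editorial caveat: for (near-)rank-deficient A, the guarded inverse above
    // zeroes the corresponding column of U rather than completing an orthonormal
    // basis, so U is only guaranteed orthogonal when A has full rank.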

// Step 6: Set S
s11 = sigma1; s12 = Type(0.0);
s21 = Type(0.0); s22 = sigma2;
}


template<typename Type>
inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
Type s12, s13, s21, s23, s31, s32;
@@ -483,6 +539,66 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
adj_A = adj_A + (u_term + v_term + sigma_term);
}

template<typename Type>
inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
Type s12, s21;
_svd_2(A.data[0][0], A.data[0][1],
A.data[1][0], A.data[1][1],

U.data[0][0], U.data[0][1],
U.data[1][0], U.data[1][1],

sigma[0], s12,
s21, sigma[1],

V.data[0][0], V.data[0][1],
V.data[1][0], V.data[1][1]);
}
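
// Editorial usage sketch (not part of this commit): the factors satisfy
// A == U * diag(sigma) * V^T, which can be verified with the helpers in this
// header, e.g.
//
//     mat_t<2,2,Type> S(sigma[0], Type(0.0), Type(0.0), sigma[1]);
//     mat_t<2,2,Type> A_rec = mul(U, mul(S, transpose(V)));
//
// The unit test added in test_mat.py performs exactly this reconstruction check.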

template<typename Type>
inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
const mat_t<2,2,Type>& U,
const vec_t<2,Type>& sigma,
const mat_t<2,2,Type>& V,
mat_t<2,2,Type>& adj_A,
const mat_t<2,2,Type>& adj_U,
const vec_t<2,Type>& adj_sigma,
const mat_t<2,2,Type>& adj_V) {
Type s1_squared = sigma[0] * sigma[0];
Type s2_squared = sigma[1] * sigma[1];

    // 1 / (s2^2 - s1^2): sigma[0] >= sigma[1], so the difference is non-positive;
    // clamping it to at most -1e-6 keeps the expression finite when the
    // singular values (nearly) coincide.
    Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));

// Construct the matrix F for the adjoint
mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
-F01, 0.0);

// Create a matrix to handle the adjoint of the singular values (diagonal matrix)
mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
0.0, adj_sigma[1]);

// Matrix for handling singular values (diagonal matrix with sigma values)
mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
0.0, sigma[1]);

// Compute the transpose of U and V
mat_t<2,2,Type> UT = transpose(U);
mat_t<2,2,Type> VT = transpose(V);

// Compute the term for sigma (diagonal matrix of adjoint singular values)
mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));

// Compute the adjoint contributions for U (left singular vectors)
mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);

// Compute the adjoint contributions for V (right singular vectors)
mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));

// Combine the terms to compute the adjoint of A
adj_A = adj_A + (u_term + v_term + sigma_term);
}
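
// Editorial note: the code above is the 2x2 case of the standard reverse-mode
// rule for the SVD (cf. Townsend, "Differentiating the Singular Value
// Decomposition", 2016):
//
//     adj_A += U * ( (F o (U^T adj_U - adj_U^T U)) * S
//                  + S * (F o (V^T adj_V - adj_V^T V))
//                  + diag(adj_sigma) ) * V^T
//
// where o is the elementwise (Hadamard) product and F_ij = 1 / (s_j^2 - s_i^2)
// for i != j (zero on the diagonal), clamped so the expression stays finite
// when the singular values (nearly) coincide.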


template<typename Type>
inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {
8 changes: 8 additions & 0 deletions warp/stubs.py
@@ -647,6 +647,14 @@ def svd3(A: Matrix[3, 3, Float], U: Matrix[3, 3, Float], sigma: Vector[3, Float]
...


@over
def svd2(A: Matrix[2, 2, Float], U: Matrix[2, 2, Float], sigma: Vector[2, Float], V: Matrix[2, 2, Float]):
"""Compute the SVD of a 2x2 matrix ``A``. The singular values are returned in ``sigma``,
while the left and right basis vectors are returned in ``U`` and ``V``.
"""
...


@over
def qr3(A: Matrix[3, 3, Float], Q: Matrix[3, 3, Float], R: Matrix[3, 3, Float]):
"""Compute the QR decomposition of a 3x3 matrix ``A``. The orthogonal matrix is returned in ``Q``,
121 changes: 121 additions & 0 deletions warp/tests/test_mat.py
@@ -987,6 +987,124 @@ def check_mat_svd(
assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)


def test_svd_2D(test, device, dtype, register_kernels=False):
rng = np.random.default_rng(123)

tol = {
np.float16: 1.0e-3,
np.float32: 1.0e-6,
np.float64: 1.0e-6,
}.get(dtype, 0)

wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
vec2 = wp.types.vector(length=2, dtype=wptype)
mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)

def check_mat_svd2(
m2: wp.array(dtype=mat22),
Uout: wp.array(dtype=mat22),
sigmaout: wp.array(dtype=vec2),
Vout: wp.array(dtype=mat22),
outcomponents: wp.array(dtype=wptype),
):
U = mat22()
sigma = vec2()
V = mat22()

        wp.svd2(m2[0], U, sigma, V)

Uout[0] = U
sigmaout[0] = sigma
Vout[0] = V

# multiply outputs by 2 so we've got something to backpropagate:
idx = 0
for i in range(2):
for j in range(2):
outcomponents[idx] = wptype(2) * U[i, j]
idx = idx + 1

for i in range(2):
outcomponents[idx] = wptype(2) * sigma[i]
idx = idx + 1

for i in range(2):
for j in range(2):
outcomponents[idx] = wptype(2) * V[i, j]
idx = idx + 1

kernel = getkernel(check_mat_svd2, suffix=dtype.__name__)

output_select_kernel = get_select_kernel(wptype)

if register_kernels:
return

m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)

outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)

wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)

Uout_np = Uout.numpy()[0].astype(np.float64)
sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
Vout_np = Vout.numpy()[0].astype(np.float64)

assert_np_equal(
np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
)

if dtype == np.float16:
# Skip gradient check for float16 due to rounding errors
return

# Check gradients:
out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for idx in range(2 * 2 + 2 + 2 * 2):
tape = wp.Tape()
with tape:
wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
tape.backward(out)
m2grads = 1.0 * tape.gradients[m2].numpy()[0]

tape.zero()

dx = 0.0001
fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
for ii in range(2):
for jj in range(2):
m2test = 1.0 * m2.numpy()
m2test[0, ii, jj] += dx
wp.launch(
kernel,
dim=1,
inputs=[wp.array(m2test, dtype=mat22, device=device)],
outputs=[Uout, sigmaout, Vout, outcomponents],
device=device,
)
wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
plusval = out.numpy()[0]

m2test = 1.0 * m2.numpy()
m2test[0, ii, jj] -= dx
wp.launch(
kernel,
dim=1,
inputs=[wp.array(m2test, dtype=mat22, device=device)],
outputs=[Uout, sigmaout, Vout, outcomponents],
device=device,
)
wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
minusval = out.numpy()[0]

assert_np_equal((plusval - minusval) / (2 * dx), m2grads[ii, jj], tol=fdtol)
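

# Editorial sketch (hypothetical helper, not part of this commit): a pure-NumPy
# reference for the factor convention checked above. numpy.linalg.svd returns
# V^T, so it is transposed back to match wp.svd2's (U, sigma, V) outputs, for
# which A == U @ np.diag(sigma) @ V.T.
def _svd2_numpy_reference(A2):
    U, S, Vh = np.linalg.svd(A2)
    return U, S, Vh.T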


def test_qr(test, device, dtype, register_kernels=False):
rng = np.random.default_rng(123)

@@ -1826,6 +1944,9 @@ def test_tpl_ops_with_anon(self):
TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
)
add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
add_function_test_register_kernel(
TestMat, f"test_svd_2D{dtype.__name__}", test_svd_2D, devices=devices, dtype=dtype
)
add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
add_function_test_register_kernel(
