Skip to content

Commit e3cdb4e

Browse files
authored
Merge pull request #591 from tkoyama010/feature/identity-matrix
Add identity matrix generation functionality to SimpleArray (#588)
2 parents f55469c + 3788eb5 commit e3cdb4e

File tree

3 files changed

+55
-2
lines changed

3 files changed

+55
-2
lines changed

cpp/modmesh/buffer/SimpleArray.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,26 @@ class SimpleArray
12481248
return data;
12491249
}
12501250

1251+
/**
1252+
* Create an identity matrix of size n x n.
1253+
*
1254+
* @param n Size of the square identity matrix
1255+
* @return SimpleArray representing an n x n identity matrix
1256+
*/
1257+
static SimpleArray eye(size_t n)
1258+
{
1259+
shape_type shape{n, n};
1260+
SimpleArray result(shape, static_cast<value_type>(0));
1261+
1262+
// Set diagonal elements to 1
1263+
for (size_t i = 0; i < n; ++i)
1264+
{
1265+
result(i, i) = static_cast<value_type>(1);
1266+
}
1267+
1268+
return result;
1269+
}
1270+
12511271
explicit operator bool() const noexcept { return bool(m_buffer) && bool(*m_buffer); }
12521272

12531273
size_t nbytes() const noexcept { return size() * ITEMSIZE; }

cpp/modmesh/buffer/pymod/wrap_SimpleArray.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapSimpleArray
315315
.def("mul", &wrapped_type::mul)
316316
.def("div", &wrapped_type::div)
317317
.def("matmul", &wrapped_type::matmul)
318+
.def_static("eye", &wrapped_type::eye, py::arg("n"), "Create an identity matrix of size n x n")
318319
// TODO: In-place operation should return reference to self to support function chaining
319320
.def("iadd", [](wrapped_type & self, wrapped_type const & other)
320321
{ self.iadd(other); })

tests/test_gemm.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import unittest
2828
import numpy as np
29+
from itertools import product
2930
import modmesh as mm
3031

3132

@@ -75,16 +76,47 @@ def test_identity_matrix(self):
7576
# 3x3 matrix
7677
a_data = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0],
7778
[7.0, 8.0, 9.0]], dtype=self.dtype)
78-
identity_data = np.eye(3, dtype=self.dtype)
7979

8080
a = self.SimpleArray(array=a_data)
81-
identity = self.SimpleArray(array=identity_data)
81+
identity = self.SimpleArray.eye(3)
8282

8383
result = a.matmul(identity)
8484

8585
self.assertEqual(list(result.shape), [3, 3])
8686
np.testing.assert_array_almost_equal(result.ndarray, a_data)
8787

88+
def test_eye_method(self):
89+
"""Test eye method creates correct identity matrices"""
90+
# Test cases: different sizes
91+
test_sizes = [1, 2, 3, 4, 5, 10]
92+
93+
for size in test_sizes:
94+
with self.subTest(size=size):
95+
# Create identity matrix using our eye method
96+
identity = self.SimpleArray.eye(size)
97+
98+
# Create expected identity matrix using NumPy
99+
expected = np.eye(size, dtype=self.dtype)
100+
101+
# Check shape
102+
self.assertEqual(list(identity.shape), [size, size])
103+
104+
# Check array values
105+
np.testing.assert_array_almost_equal(identity.ndarray,
106+
expected)
107+
108+
# Verify diagonal and off-diagonal elements explicitly
109+
# using product
110+
for i, j in product(range(size), repeat=2):
111+
if i == j:
112+
self.assertEqual(identity[i, j], 1.0,
113+
f"Diagonal element ({i},{j}) "
114+
f"should be 1.0")
115+
else:
116+
self.assertEqual(identity[i, j], 0.0,
117+
f"Off-diagonal element ({i},{j}) "
118+
f"should be 0.0")
119+
88120
def test_zero_matrix(self):
89121
"""Test multiplication with zero matrix"""
90122
a_data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=self.dtype)

0 commit comments

Comments
 (0)