Skip to content

Commit 3861f3c

Browse files
committed
Add matrices
1 parent a106bb9 commit 3861f3c

File tree

8 files changed

+830
-29
lines changed

8 files changed

+830
-29
lines changed

.github/workflows/code_checks.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ jobs:
2828
sudo luarocks install luacheck
2929
- name: lint
3030
run: |
31-
luacheck src
32-
luacheck .spec --config .spec/.luacheckrc
31+
luacheck .
3332
3433
test:
3534
runs-on: ubuntu-latest

.luacheckrc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
files[".spec/**/*"] = {
2+
read_globals = {
3+
"describe",
4+
"it",
5+
"before_each",
6+
"setup",
7+
"teardown",
8+
assert = {
9+
fields = {
10+
truthy = {},
11+
falsy = {},
12+
equal = {},
13+
same = {},
14+
has_error = {},
15+
near = {}
16+
}
17+
}
18+
}
19+
}

.spec/.luacheckrc

Lines changed: 0 additions & 16 deletions
This file was deleted.

.spec/math/abs_spec.lua

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
it("abs", function()
1+
describe("abs", function()
22
local abs = require("math.abs")
3-
assert.equal(0, abs(0))
4-
assert.equal(42, abs(42))
5-
assert.equal(42, abs(-42))
3+
it("works for numbers", function()
4+
assert.equal(0, abs(0))
5+
assert.equal(42, abs(42))
6+
assert.equal(42, abs(-42))
7+
assert.equal(math.huge, abs(math.huge))
8+
assert.equal(math.huge, abs(-math.huge))
9+
local nan = abs(0 / 0)
10+
assert(nan ~= nan)
11+
end)
12+
it("works for custom number types", function()
13+
local fraction = require("math.fraction")
14+
assert.equal(fraction.from_number(0), abs(fraction.from_number(0)))
15+
assert.equal(fraction.new(2, 3), abs(fraction.new(2, 3)))
16+
assert.equal(fraction.new(2, 3), abs(fraction.new(-2, 3)))
17+
end)
618
end)

.spec/math/matrix_spec.lua

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
-- Turn off StyLua to preserve matrix formatting
2+
-- stylua: ignore
3+
describe("Matrix", function()
4+
local matrix = require("math.matrix")
5+
local vector = require("math.vector")
6+
describe("constructors", function()
7+
it("zero", function()
8+
assert.same({h = 3, w = 2; 0, 0, 0, 0, 0, 0}, matrix.zero(3, 2))
9+
end)
10+
it("from list, with height", function()
11+
assert.same({
12+
h = 2, w = 3;
13+
1, 2, 3;
14+
4, 5, 6
15+
}, matrix.with_height(2, {
16+
1, 2, 3;
17+
4, 5, 6
18+
}))
19+
end)
20+
it("from list, with width", function()
21+
assert.same({
22+
h = 2, w = 3;
23+
1, 2, 3;
24+
4, 5, 6
25+
}, matrix.with_width(3, {
26+
1, 2, 3;
27+
4, 5, 6
28+
}))
29+
end)
30+
it("square", function()
31+
assert.equal(matrix.with_height(2, {
32+
1, 2;
33+
3, 4
34+
}), matrix.square{
35+
1, 2;
36+
3, 4
37+
})
38+
end)
39+
it("identity", function()
40+
assert.equal(matrix.square{
41+
1, 0;
42+
0, 1
43+
}, matrix.identity(2))
44+
end)
45+
it("diagonal", function()
46+
assert.equal(matrix.square{
47+
1, 0;
48+
0, 2
49+
}, matrix.diagonal{1, 2})
50+
end)
51+
end)
52+
it("copy", function()
53+
local m = matrix.with_height(2, {
54+
1, 2, 3;
55+
4, 5, 6
56+
})
57+
local copy = m:copy()
58+
m:set(2, 3, 9)
59+
assert.equal(matrix.with_height(2, {
60+
1, 2, 3;
61+
4, 5, 6
62+
}), copy)
63+
end)
64+
describe("equals", function()
65+
local m = matrix.with_height(2, {
66+
1, 2, 3;
67+
4, 5, 6
68+
})
69+
it("itself", function()
70+
assert(m:equals(m))
71+
end)
72+
it("copy of itself", function()
73+
assert(m:equals(m:copy()))
74+
end)
75+
it("supports tolerance", function()
76+
assert(matrix.square{0}:equals(matrix.square{1e-9}, 1e-9))
77+
end)
78+
it("different dimensions", function()
79+
assert(not m:equals(matrix.with_height(3, {unpack(m)})))
80+
end)
81+
it("different values", function()
82+
assert(not m:equals(matrix.zero(2, 3)))
83+
end)
84+
end)
85+
describe("getters & setters", function()
86+
it("elements", function()
87+
local m = matrix.with_height(2, {
88+
0, 2, 3;
89+
4, 5, 9
90+
})
91+
assert.equal(0, m:get(1, 1))
92+
assert.equal(9, m:get(2, 3))
93+
m:set(1, 1, 1)
94+
m:set(2, 3, 6)
95+
assert.equal(1, m:get(1, 1))
96+
assert.equal(6, m:get(2, 3))
97+
assert.equal(matrix.with_height(2, {
98+
1, 2, 3;
99+
4, 5, 6
100+
}), m)
101+
end)
102+
it("shape", function()
103+
local m = matrix.with_height(2, {
104+
1, 2, 3;
105+
4, 5, 6
106+
})
107+
m:set_height(2)
108+
assert.equal(matrix.with_height(2, {unpack(m)}), m)
109+
m:set_width(6)
110+
assert.equal(matrix.with_width(6, {unpack(m)}), m)
111+
end)
112+
end)
113+
it("transposition", function()
114+
assert.equal(matrix.square{
115+
1, 2;
116+
3, 4
117+
}, matrix.square{
118+
1, 3;
119+
2, 4
120+
}:transposed())
121+
assert.equal(matrix.with_height(3, {
122+
1, 4;
123+
2, 5;
124+
3, 6
125+
}), matrix.with_height(2, {
126+
1, 2, 3;
127+
4, 5, 6
128+
}):transposed())
129+
end)
130+
it("negation", function()
131+
local m = matrix.square{1}
132+
m:negate()
133+
assert.equal(matrix.square{-1}, m)
134+
end)
135+
it("addition", function()
136+
local m = matrix.square{1, 2, 3, 4}
137+
m:add(matrix.square{4, 3, 2, 1})
138+
assert.equal(matrix.square{5, 5, 5, 5}, m)
139+
end)
140+
it("subtraction", function()
141+
local m = matrix.square{5, 5, 5, 5}
142+
m:subtract(matrix.square{4, 3, 2, 1})
143+
assert.equal(matrix.square{1, 2, 3, 4}, m)
144+
end)
145+
it("scalar multiplication", function()
146+
local m = matrix.square{1, 2, 3, 4}
147+
m:scale(2)
148+
assert.equal(matrix.square{2, 4, 6, 8}, m)
149+
end)
150+
it("matrix-vector multiplication", function()
151+
assert.equal(vector.new{1*1 + 2*2 + 3*3, 4*1 + 5*2 + 6*3},
152+
matrix.with_height(2, {
153+
1, 2, 3;
154+
4, 5, 6
155+
}):multiply_column_vector({1, 2, 3}))
156+
end)
157+
it("vector-matrix multiplication", function()
158+
assert.equal(vector.new{1*1 + 2*4, 1*2 + 2*5, 1*3 + 2*6},
159+
matrix.with_height(2, {
160+
1, 2, 3;
161+
4, 5, 6
162+
}):multiply_row_vector({1, 2}))
163+
end)
164+
describe("matrix multiplication", function()
165+
it("dot product", function()
166+
assert.equal(matrix.square{1*1 + 2*2 + 3*3}, matrix.with_width(3, {1, 2, 3})
167+
:multiply_matrix(matrix.with_height(3, {1, 2, 3})))
168+
end)
169+
it("outer product", function()
170+
assert.equal(matrix.square{
171+
1*1, 1*2, 1*3;
172+
1*2, 2*2, 2*3;
173+
1*3, 3*2, 3*3
174+
}, matrix.with_height(3, {1, 2, 3})
175+
:multiply_matrix(matrix.with_width(3, {1, 2, 3})))
176+
end)
177+
local m = matrix.with_height(2, {
178+
1, 2, 3;
179+
4, 5, 6
180+
})
181+
it("2x3 * 3x2 -> 2x2", function()
182+
assert.equal(matrix.square{
183+
1*1 + 2*2 + 3*3, 1*4 + 2*5 + 3*6;
184+
4*1 + 5*2 + 6*3, 4*4 + 5*5 + 6*6
185+
}, m:multiply_matrix(m:transposed()))
186+
end)
187+
it("3x2 * 2x3 -> 3x3", function()
188+
assert.equal(matrix.square{
189+
1*1 + 4*4, 1*2 + 4*5, 1*3 + 4*6;
190+
2*1 + 5*4, 2*2 + 5*5, 2*3 + 5*6;
191+
3*1 + 6*4, 3*2 + 6*5, 3*3 + 6*6
192+
}, m:transposed():multiply_matrix(m))
193+
end)
194+
end)
195+
describe("operators", function()
196+
it("equals", function()
197+
assert(matrix.square{1} == matrix.square{1})
198+
end)
199+
it("negation", function()
200+
assert.equal(matrix.square{-1}, -matrix.square{1})
201+
end)
202+
it("subtraction", function()
203+
assert.equal(matrix.square{1}, matrix.square{3} - matrix.square{2})
204+
end)
205+
it("addition", function()
206+
assert.equal(matrix.square{3}, matrix.square{1} + matrix.square{2})
207+
end)
208+
describe("multiplication", function()
209+
it("scalar-matrix, matrix-scalar", function()
210+
assert.equal(matrix.square{2, 4, 6, 8}, 2*matrix.square{1, 2, 3, 4})
211+
assert.equal(matrix.square{2, 4, 6, 8}, matrix.square{1, 2, 3, 4}*2)
212+
end)
213+
local m = matrix.with_height(2, {
214+
1, 2, 3;
215+
4, 5, 6
216+
})
217+
it("matrix-vector", function()
218+
local v = vector.new{1, 2, 3}
219+
assert.equal(m:multiply_column_vector(v), m*v)
220+
end)
221+
it("vector-matrix", function()
222+
local v = vector.new{1, 2}
223+
assert.equal(m:multiply_row_vector(v), v*m)
224+
end)
225+
it("matrix-matrix", function()
226+
local t = m:transposed()
227+
assert.equal(t:multiply_matrix(m), t*m)
228+
assert.equal(m:multiply_matrix(t), m*t)
229+
end)
230+
end)
231+
describe("exponentiation", function()
232+
it("rejects non-endomorphisms", function()
233+
assert.has_error(function() return matrix.zero(2, 3)^10 end)
234+
end)
235+
it("identity matrix for exponent 0", function()
236+
assert.equal(matrix.identity(2), matrix.zero(2)^0)
237+
end)
238+
it("equals repeated multiplication", function()
239+
local m = matrix.square{1, 2, 3, 4}
240+
assert.equal(m*m*m*m, m^4)
241+
end)
242+
it("inverts & multiplies", function()
243+
local m = matrix.square{1, 2, 3, 4}
244+
local inv = m:inverse()
245+
assert((inv*inv*inv):equals(m^-3, 1e-9))
246+
end)
247+
end)
248+
end)
249+
describe("inversion", function()
250+
it("returns nothing for non-invertible matrices", function()
251+
assert.equal(nil, (matrix.with_width(2, {1, 2}):inverse()))
252+
assert.equal(nil, (matrix.with_height(2, {1, 2}):inverse()))
253+
assert.equal(nil, (matrix.square{
254+
1, 1;
255+
0, 0,
256+
}:inverse()))
257+
end)
258+
it("works for a 2x2 matrix", function()
259+
assert((1/(1*4 - 2*3) * matrix.square{
260+
4, -2;
261+
-3, 1
262+
}):equals(matrix.square{
263+
1, 2,
264+
3, 4,
265+
}:inverse(), 1e-9))
266+
end)
267+
it("works exactly for fractions", function()
268+
local fraction = require("math.fraction")
269+
local intfrac = fraction.from_number
270+
local m = matrix.square{
271+
4, -2;
272+
-3, 1
273+
}
274+
m:scale(fraction.new(1, 1*4 - 2*3))
275+
assert(m:equals(matrix.square{
276+
intfrac(1), intfrac(2),
277+
intfrac(3), intfrac(4),
278+
}:inverse(intfrac(0)), intfrac(0)))
279+
end)
280+
it("works for random matrices", function()
281+
for n = 1, 10 do
282+
local id = matrix.identity(n)
283+
for _ = 1, 10 do
284+
local random_nums = {}
285+
for i = 1, n^2 do
286+
random_nums[i] = math.random(-1e7, 1e7)
287+
end
288+
local m = matrix.with_height(n, random_nums)
289+
-- It is highly probable that the matrix is invertible;
290+
-- we do not need to account for non-invertible matrices
291+
local inv = m:inverse(1e-6)
292+
assert(id:equals(m*inv, 1e-6))
293+
assert(id:equals(inv*m, 1e-6))
294+
end
295+
end
296+
end)
297+
end)
298+
describe("determinant", function()
299+
it("is zero for non-invertible matrices", function()
300+
assert.equal(0, matrix.with_width(2, {1, 2}):determinant())
301+
assert.equal(0, matrix.with_height(2, {1, 2}):determinant())
302+
assert.equal(0, matrix.square{
303+
1, 1;
304+
0, 0,
305+
}:determinant())
306+
end)
307+
it("2d", function()
308+
assert.near(1 * 4 - 2 * 3, matrix.square{
309+
1, 2;
310+
3, 4
311+
}:determinant(), 1e-9)
312+
end)
313+
it("3d", function()
314+
assert.near(60, matrix.square{
315+
1, 4, 3;
316+
2, 5, 6;
317+
9, 8, 7
318+
}:determinant(), 1e-9)
319+
end)
320+
end)
321+
end)

0 commit comments

Comments
 (0)