|
| 1 | +-- Turn off StyLua to preserve matrix formatting |
| 2 | +-- stylua: ignore |
| 3 | +describe("Matrix", function() |
| 4 | + local matrix = require("math.matrix") |
| 5 | + local vector = require("math.vector") |
| 6 | + describe("constructors", function() |
| 7 | + it("zero", function() |
| 8 | + assert.same({h = 3, w = 2; 0, 0, 0, 0, 0, 0}, matrix.zero(3, 2)) |
| 9 | + end) |
| 10 | + it("from list, with height", function() |
| 11 | + assert.same({ |
| 12 | + h = 2, w = 3; |
| 13 | + 1, 2, 3; |
| 14 | + 4, 5, 6 |
| 15 | + }, matrix.with_height(2, { |
| 16 | + 1, 2, 3; |
| 17 | + 4, 5, 6 |
| 18 | + })) |
| 19 | + end) |
| 20 | + it("from list, with width", function() |
| 21 | + assert.same({ |
| 22 | + h = 2, w = 3; |
| 23 | + 1, 2, 3; |
| 24 | + 4, 5, 6 |
| 25 | + }, matrix.with_width(3, { |
| 26 | + 1, 2, 3; |
| 27 | + 4, 5, 6 |
| 28 | + })) |
| 29 | + end) |
| 30 | + it("square", function() |
| 31 | + assert.equal(matrix.with_height(2, { |
| 32 | + 1, 2; |
| 33 | + 3, 4 |
| 34 | + }), matrix.square{ |
| 35 | + 1, 2; |
| 36 | + 3, 4 |
| 37 | + }) |
| 38 | + end) |
| 39 | + it("identity", function() |
| 40 | + assert.equal(matrix.square{ |
| 41 | + 1, 0; |
| 42 | + 0, 1 |
| 43 | + }, matrix.identity(2)) |
| 44 | + end) |
| 45 | + it("diagonal", function() |
| 46 | + assert.equal(matrix.square{ |
| 47 | + 1, 0; |
| 48 | + 0, 2 |
| 49 | + }, matrix.diagonal{1, 2}) |
| 50 | + end) |
| 51 | + end) |
| 52 | + it("copy", function() |
| 53 | + local m = matrix.with_height(2, { |
| 54 | + 1, 2, 3; |
| 55 | + 4, 5, 6 |
| 56 | + }) |
| 57 | + local copy = m:copy() |
| 58 | + m:set(2, 3, 9) |
| 59 | + assert.equal(matrix.with_height(2, { |
| 60 | + 1, 2, 3; |
| 61 | + 4, 5, 6 |
| 62 | + }), copy) |
| 63 | + end) |
| 64 | + describe("equals", function() |
| 65 | + local m = matrix.with_height(2, { |
| 66 | + 1, 2, 3; |
| 67 | + 4, 5, 6 |
| 68 | + }) |
| 69 | + it("itself", function() |
| 70 | + assert(m:equals(m)) |
| 71 | + end) |
| 72 | + it("copy of itself", function() |
| 73 | + assert(m:equals(m:copy())) |
| 74 | + end) |
| 75 | + it("supports tolerance", function() |
| 76 | + assert(matrix.square{0}:equals(matrix.square{1e-9}, 1e-9)) |
| 77 | + end) |
| 78 | + it("different dimensions", function() |
| 79 | + assert(not m:equals(matrix.with_height(3, {unpack(m)}))) |
| 80 | + end) |
| 81 | + it("different values", function() |
| 82 | + assert(not m:equals(matrix.zero(2, 3))) |
| 83 | + end) |
| 84 | + end) |
| 85 | + describe("getters & setters", function() |
| 86 | + it("elements", function() |
| 87 | + local m = matrix.with_height(2, { |
| 88 | + 0, 2, 3; |
| 89 | + 4, 5, 9 |
| 90 | + }) |
| 91 | + assert.equal(0, m:get(1, 1)) |
| 92 | + assert.equal(9, m:get(2, 3)) |
| 93 | + m:set(1, 1, 1) |
| 94 | + m:set(2, 3, 6) |
| 95 | + assert.equal(1, m:get(1, 1)) |
| 96 | + assert.equal(6, m:get(2, 3)) |
| 97 | + assert.equal(matrix.with_height(2, { |
| 98 | + 1, 2, 3; |
| 99 | + 4, 5, 6 |
| 100 | + }), m) |
| 101 | + end) |
| 102 | + it("shape", function() |
| 103 | + local m = matrix.with_height(2, { |
| 104 | + 1, 2, 3; |
| 105 | + 4, 5, 6 |
| 106 | + }) |
| 107 | + m:set_height(2) |
| 108 | + assert.equal(matrix.with_height(2, {unpack(m)}), m) |
| 109 | + m:set_width(6) |
| 110 | + assert.equal(matrix.with_width(6, {unpack(m)}), m) |
| 111 | + end) |
| 112 | + end) |
| 113 | + it("transposition", function() |
| 114 | + assert.equal(matrix.square{ |
| 115 | + 1, 2; |
| 116 | + 3, 4 |
| 117 | + }, matrix.square{ |
| 118 | + 1, 3; |
| 119 | + 2, 4 |
| 120 | + }:transposed()) |
| 121 | + assert.equal(matrix.with_height(3, { |
| 122 | + 1, 4; |
| 123 | + 2, 5; |
| 124 | + 3, 6 |
| 125 | + }), matrix.with_height(2, { |
| 126 | + 1, 2, 3; |
| 127 | + 4, 5, 6 |
| 128 | + }):transposed()) |
| 129 | + end) |
| 130 | + it("negation", function() |
| 131 | + local m = matrix.square{1} |
| 132 | + m:negate() |
| 133 | + assert.equal(matrix.square{-1}, m) |
| 134 | + end) |
| 135 | + it("addition", function() |
| 136 | + local m = matrix.square{1, 2, 3, 4} |
| 137 | + m:add(matrix.square{4, 3, 2, 1}) |
| 138 | + assert.equal(matrix.square{5, 5, 5, 5}, m) |
| 139 | + end) |
| 140 | + it("subtraction", function() |
| 141 | + local m = matrix.square{5, 5, 5, 5} |
| 142 | + m:subtract(matrix.square{4, 3, 2, 1}) |
| 143 | + assert.equal(matrix.square{1, 2, 3, 4}, m) |
| 144 | + end) |
| 145 | + it("scalar multiplication", function() |
| 146 | + local m = matrix.square{1, 2, 3, 4} |
| 147 | + m:scale(2) |
| 148 | + assert.equal(matrix.square{2, 4, 6, 8}, m) |
| 149 | + end) |
| 150 | + it("matrix-vector multiplication", function() |
| 151 | + assert.equal(vector.new{1*1 + 2*2 + 3*3, 4*1 + 5*2 + 6*3}, |
| 152 | + matrix.with_height(2, { |
| 153 | + 1, 2, 3; |
| 154 | + 4, 5, 6 |
| 155 | + }):multiply_column_vector({1, 2, 3})) |
| 156 | + end) |
| 157 | + it("vector-matrix multiplication", function() |
| 158 | + assert.equal(vector.new{1*1 + 2*4, 1*2 + 2*5, 1*3 + 2*6}, |
| 159 | + matrix.with_height(2, { |
| 160 | + 1, 2, 3; |
| 161 | + 4, 5, 6 |
| 162 | + }):multiply_row_vector({1, 2})) |
| 163 | + end) |
| 164 | + describe("matrix multiplication", function() |
| 165 | + it("dot product", function() |
| 166 | + assert.equal(matrix.square{1*1 + 2*2 + 3*3}, matrix.with_width(3, {1, 2, 3}) |
| 167 | + :multiply_matrix(matrix.with_height(3, {1, 2, 3}))) |
| 168 | + end) |
| 169 | + it("outer product", function() |
| 170 | + assert.equal(matrix.square{ |
| 171 | + 1*1, 1*2, 1*3; |
| 172 | + 1*2, 2*2, 2*3; |
| 173 | + 1*3, 3*2, 3*3 |
| 174 | + }, matrix.with_height(3, {1, 2, 3}) |
| 175 | + :multiply_matrix(matrix.with_width(3, {1, 2, 3}))) |
| 176 | + end) |
| 177 | + local m = matrix.with_height(2, { |
| 178 | + 1, 2, 3; |
| 179 | + 4, 5, 6 |
| 180 | + }) |
| 181 | + it("2x3 * 3x2 -> 2x2", function() |
| 182 | + assert.equal(matrix.square{ |
| 183 | + 1*1 + 2*2 + 3*3, 1*4 + 2*5 + 3*6; |
| 184 | + 4*1 + 5*2 + 6*3, 4*4 + 5*5 + 6*6 |
| 185 | + }, m:multiply_matrix(m:transposed())) |
| 186 | + end) |
| 187 | + it("3x2 * 2x3 -> 3x3", function() |
| 188 | + assert.equal(matrix.square{ |
| 189 | + 1*1 + 4*4, 1*2 + 4*5, 1*3 + 4*6; |
| 190 | + 2*1 + 5*4, 2*2 + 5*5, 2*3 + 5*6; |
| 191 | + 3*1 + 6*4, 3*2 + 6*5, 3*3 + 6*6 |
| 192 | + }, m:transposed():multiply_matrix(m)) |
| 193 | + end) |
| 194 | + end) |
| 195 | + describe("operators", function() |
| 196 | + it("equals", function() |
| 197 | + assert(matrix.square{1} == matrix.square{1}) |
| 198 | + end) |
| 199 | + it("negation", function() |
| 200 | + assert.equal(matrix.square{-1}, -matrix.square{1}) |
| 201 | + end) |
| 202 | + it("subtraction", function() |
| 203 | + assert.equal(matrix.square{1}, matrix.square{3} - matrix.square{2}) |
| 204 | + end) |
| 205 | + it("addition", function() |
| 206 | + assert.equal(matrix.square{3}, matrix.square{1} + matrix.square{2}) |
| 207 | + end) |
| 208 | + describe("multiplication", function() |
| 209 | + it("scalar-matrix, matrix-scalar", function() |
| 210 | + assert.equal(matrix.square{2, 4, 6, 8}, 2*matrix.square{1, 2, 3, 4}) |
| 211 | + assert.equal(matrix.square{2, 4, 6, 8}, matrix.square{1, 2, 3, 4}*2) |
| 212 | + end) |
| 213 | + local m = matrix.with_height(2, { |
| 214 | + 1, 2, 3; |
| 215 | + 4, 5, 6 |
| 216 | + }) |
| 217 | + it("matrix-vector", function() |
| 218 | + local v = vector.new{1, 2, 3} |
| 219 | + assert.equal(m:multiply_column_vector(v), m*v) |
| 220 | + end) |
| 221 | + it("vector-matrix", function() |
| 222 | + local v = vector.new{1, 2} |
| 223 | + assert.equal(m:multiply_row_vector(v), v*m) |
| 224 | + end) |
| 225 | + it("matrix-matrix", function() |
| 226 | + local t = m:transposed() |
| 227 | + assert.equal(t:multiply_matrix(m), t*m) |
| 228 | + assert.equal(m:multiply_matrix(t), m*t) |
| 229 | + end) |
| 230 | + end) |
| 231 | + describe("exponentiation", function() |
| 232 | + it("rejects non-endomorphisms", function() |
| 233 | + assert.has_error(function() return matrix.zero(2, 3)^10 end) |
| 234 | + end) |
| 235 | + it("identity matrix for exponent 0", function() |
| 236 | + assert.equal(matrix.identity(2), matrix.zero(2)^0) |
| 237 | + end) |
| 238 | + it("equals repeated multiplication", function() |
| 239 | + local m = matrix.square{1, 2, 3, 4} |
| 240 | + assert.equal(m*m*m*m, m^4) |
| 241 | + end) |
| 242 | + it("inverts & multiplies", function() |
| 243 | + local m = matrix.square{1, 2, 3, 4} |
| 244 | + local inv = m:inverse() |
| 245 | + assert((inv*inv*inv):equals(m^-3, 1e-9)) |
| 246 | + end) |
| 247 | + end) |
| 248 | + end) |
| 249 | + describe("inversion", function() |
| 250 | + it("returns nothing for non-invertible matrices", function() |
| 251 | + assert.equal(nil, (matrix.with_width(2, {1, 2}):inverse())) |
| 252 | + assert.equal(nil, (matrix.with_height(2, {1, 2}):inverse())) |
| 253 | + assert.equal(nil, (matrix.square{ |
| 254 | + 1, 1; |
| 255 | + 0, 0, |
| 256 | + }:inverse())) |
| 257 | + end) |
| 258 | + it("works for a 2x2 matrix", function() |
| 259 | + assert((1/(1*4 - 2*3) * matrix.square{ |
| 260 | + 4, -2; |
| 261 | + -3, 1 |
| 262 | + }):equals(matrix.square{ |
| 263 | + 1, 2, |
| 264 | + 3, 4, |
| 265 | + }:inverse(), 1e-9)) |
| 266 | + end) |
| 267 | + it("works exactly for fractions", function() |
| 268 | + local fraction = require("math.fraction") |
| 269 | + local intfrac = fraction.from_number |
| 270 | + local m = matrix.square{ |
| 271 | + 4, -2; |
| 272 | + -3, 1 |
| 273 | + } |
| 274 | + m:scale(fraction.new(1, 1*4 - 2*3)) |
| 275 | + assert(m:equals(matrix.square{ |
| 276 | + intfrac(1), intfrac(2), |
| 277 | + intfrac(3), intfrac(4), |
| 278 | + }:inverse(intfrac(0)), intfrac(0))) |
| 279 | + end) |
| 280 | + it("works for random matrices", function() |
| 281 | + for n = 1, 10 do |
| 282 | + local id = matrix.identity(n) |
| 283 | + for _ = 1, 10 do |
| 284 | + local random_nums = {} |
| 285 | + for i = 1, n^2 do |
| 286 | + random_nums[i] = math.random(-1e7, 1e7) |
| 287 | + end |
| 288 | + local m = matrix.with_height(n, random_nums) |
| 289 | + -- It is highly probable that the matrix is invertible; |
| 290 | + -- we do not need to account for non-invertible matrices |
| 291 | + local inv = m:inverse(1e-6) |
| 292 | + assert(id:equals(m*inv, 1e-6)) |
| 293 | + assert(id:equals(inv*m, 1e-6)) |
| 294 | + end |
| 295 | + end |
| 296 | + end) |
| 297 | + end) |
| 298 | + describe("determinant", function() |
| 299 | + it("is zero for non-invertible matrices", function() |
| 300 | + assert.equal(0, matrix.with_width(2, {1, 2}):determinant()) |
| 301 | + assert.equal(0, matrix.with_height(2, {1, 2}):determinant()) |
| 302 | + assert.equal(0, matrix.square{ |
| 303 | + 1, 1; |
| 304 | + 0, 0, |
| 305 | + }:determinant()) |
| 306 | + end) |
| 307 | + it("2d", function() |
| 308 | + assert.near(1 * 4 - 2 * 3, matrix.square{ |
| 309 | + 1, 2; |
| 310 | + 3, 4 |
| 311 | + }:determinant(), 1e-9) |
| 312 | + end) |
| 313 | + it("3d", function() |
| 314 | + assert.near(60, matrix.square{ |
| 315 | + 1, 4, 3; |
| 316 | + 2, 5, 6; |
| 317 | + 9, 8, 7 |
| 318 | + }:determinant(), 1e-9) |
| 319 | + end) |
| 320 | + end) |
| 321 | +end) |
0 commit comments