Skip to content

Commit fc9cc1f

Browse files
committed
matrix multiplication example
1 parent 7d1b5e9 commit fc9cc1f

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

problem_solving/basic/matrix_multiplication.py

+47-12
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class Matrix:
3232
def zeros(cls, rows: int, cols: int):
3333
if rows < 1 and cols < 1:
3434
raise ValueError("invalid rows or columns provided")
35-
return [[0] * rows] * cols
35+
return Matrix([[0 for _ in range(rows)] for _ in range(cols)])
3636

3737
def __init__(self, data: list[list[int | float]]) -> None:
3838
if data.__len__():
@@ -51,25 +51,54 @@ def size(self):
5151
return (self.data[0].__len__(), self.data.__len__())
5252

5353
def get_row(self, index: int):
54-
return [d[index] for d in self.data]
54+
return self.data[index]
5555

5656
def get_column(self, index: int):
57-
return self.data[index]
57+
return [d[index] for d in self.data]
5858

5959
def __str__(self) -> str:
6060
if self.data.__len__():
61-
return "| |"
62-
return "\n".join(
63-
["| " + "".join([f"{i:^5d}" for i in row]) + " |" for row in self.data]
64-
)
61+
return "\n".join(
62+
["| " + "".join([f"{i:^5d}" for i in row]) + " |" for row in self.data]
63+
)
64+
return "| |"
6565

6666
def __mul__(self, other: "Matrix"):
6767
"""
6868
This method overloads python's default multiplication operation between
6969
two Matrix objects so that we can easily perform `*` operation.
7070
"""
71-
# TODO
72-
pass
71+
(self_rows, self_cols) = self.size
72+
(other_rows, other_cols) = other.size
73+
if self_rows != other_cols:
74+
raise ValueError("Dimensions mismatch")
75+
76+
# by for loop and matrix.zeros
77+
# this part is easy to understand but is a bit more computationally expensive
78+
79+
# result = Matrix.zeros(other_rows, self_cols)
80+
# for row in range(other_rows):
81+
# for col in range(self_cols):
82+
# result.data[row][col] = sum(
83+
# [a * b for (a, b) in zip(self.get_row(row), other.get_column(col))]
84+
# )
85+
# return result
86+
87+
# by comprehension quicker
88+
return Matrix(
89+
[
90+
[
91+
sum(
92+
[
93+
a * b
94+
for (a, b) in zip(self.get_row(row), other.get_column(col))
95+
]
96+
)
97+
for col in range(self_cols)
98+
]
99+
for row in range(other_rows)
100+
]
101+
)
73102

74103

75104
if __name__ == "__main__":
@@ -83,7 +112,13 @@ def __mul__(self, other: "Matrix"):
83112
# [5, 6],
84113
# ]
85114
# )
86-
m1 = Matrix([[1, 2], [3, 4]])
87-
m2 = Matrix([[2, 3], [4, 5]])
88115

89-
print("result is: \n", m1 * m2)
116+
# if we uncomment lines below, we get Dimensions mismatch exception
117+
# m1 = Matrix([[1, 2, 3], [4, 5, 6]])
118+
# m2 = Matrix([[2, 3, 4], [5, 6, 7]])
119+
# print("result is: \n", m1 * m2)
120+
121+
m1 = Matrix([[1, 2, 3], [4, 5, 6]])
122+
m2 = Matrix([[2, 3], [4, 5], [6, 7]])
123+
print("m1 X m2 = :", m1 * m2, sep="\n")
124+
print("m2 X m1 = :", m2 * m1, sep="\n")

0 commit comments

Comments
 (0)