@@ -32,7 +32,7 @@ class Matrix:
32
32
def zeros (cls , rows : int , cols : int ):
33
33
if rows < 1 and cols < 1 :
34
34
raise ValueError ("invalid rows or columns provided" )
35
- return [[0 ] * rows ] * cols
35
+ return Matrix ( [[0 for _ in range ( rows )] for _ in range ( cols )])
36
36
37
37
def __init__ (self , data : list [list [int | float ]]) -> None :
38
38
if data .__len__ ():
@@ -51,25 +51,54 @@ def size(self):
51
51
return (self .data [0 ].__len__ (), self .data .__len__ ())
52
52
53
53
def get_row (self , index : int ):
54
- return [ d [ index ] for d in self .data ]
54
+ return self .data [ index ]
55
55
56
56
def get_column (self , index : int ):
57
- return self .data [ index ]
57
+ return [ d [ index ] for d in self .data ]
58
58
59
59
def __str__ (self ) -> str :
60
60
if self .data .__len__ ():
61
- return "| |"
62
- return " \n " .join (
63
- [ "| " + "" . join ([ f" { i :^5d } " for i in row ]) + " |" for row in self . data ]
64
- )
61
+ return "\n " . join (
62
+ [ "| " + "" .join ([ f" { i :^5d } " for i in row ]) + " |" for row in self . data ]
63
+ )
64
+ return "| |"
65
65
66
66
def __mul__ (self , other : "Matrix" ):
67
67
"""
68
68
This method overloads python's default multiplication operation between
69
69
two Matrix objects so that we can easily perform `*` operation.
70
70
"""
71
- # TODO
72
- pass
71
+ (self_rows , self_cols ) = self .size
72
+ (other_rows , other_cols ) = other .size
73
+ if self_rows != other_cols :
74
+ raise ValueError ("Dimensions mismatch" )
75
+
76
+ # by for loop and matrix.zeros
77
+ # this part is easy to understand but is a bit more computationally expensive
78
+
79
+ # result = Matrix.zeros(other_rows, self_cols)
80
+ # for row in range(other_rows):
81
+ # for col in range(self_cols):
82
+ # result.data[row][col] = sum(
83
+ # [a * b for (a, b) in zip(self.get_row(row), other.get_column(col))]
84
+ # )
85
+ # return result
86
+
87
+ # by comprehension quicker
88
+ return Matrix (
89
+ [
90
+ [
91
+ sum (
92
+ [
93
+ a * b
94
+ for (a , b ) in zip (self .get_row (row ), other .get_column (col ))
95
+ ]
96
+ )
97
+ for col in range (self_cols )
98
+ ]
99
+ for row in range (other_rows )
100
+ ]
101
+ )
73
102
74
103
75
104
if __name__ == "__main__" :
@@ -83,7 +112,13 @@ def __mul__(self, other: "Matrix"):
83
112
# [5, 6],
84
113
# ]
85
114
# )
86
- m1 = Matrix ([[1 , 2 ], [3 , 4 ]])
87
- m2 = Matrix ([[2 , 3 ], [4 , 5 ]])
88
115
89
- print ("result is: \n " , m1 * m2 )
116
+ # if we uncomment lines below, we get Dimensions mismatch exception
117
+ # m1 = Matrix([[1, 2, 3], [4, 5, 6]])
118
+ # m2 = Matrix([[2, 3, 4], [5, 6, 7]])
119
+ # print("result is: \n", m1 * m2)
120
+
121
+ m1 = Matrix ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
122
+ m2 = Matrix ([[2 , 3 ], [4 , 5 ], [6 , 7 ]])
123
+ print ("m1 X m2 = :" , m1 * m2 , sep = "\n " )
124
+ print ("m2 X m1 = :" , m2 * m1 , sep = "\n " )
0 commit comments