Implementing a Matrix Class in Python
In this notebook, I implement a Matrix class following Lesson 10 of the fastai course (Part 2) and add some additional functionality.
import numpy as np
Background
In Lesson 10 of the fastai course (Part 2) Jeremy introduces the following Matrix class to allow for PyTorch- and NumPy-like indexing into a Python list:
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
In this notebook, I’ll expand this class with the following additions to further mimic NumPy arrays:
- A __repr__ method which returns a displayable/printable representation of the Matrix
- A shape property which returns the shape of the Matrix
- A min method
- A max method
Expanding the Matrix Class
I’ll start by defining the Matrix class as done in Lesson 10 and instantiating a Matrix object:
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
lst1 = [
    [1, 2, 3],
    [4, 5, 6]
]
m = Matrix(lst1)
I’ll illustrate the __getitem__ method:
m[1,2]
6
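To see why m[1,2] works, it helps to look at what Python actually hands to __getitem__. Here is a minimal sketch, where Spy is a hypothetical helper class (not part of the lesson) that simply returns whatever index object it receives:

```python
class Spy:
    # Hypothetical helper: returns the raw index object it was given
    def __getitem__(self, idxs): return idxs

spy = Spy()
single = spy[1]      # a bare int
pair = spy[1, 2]     # the comma builds a tuple: (1, 2)

class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]

m = Matrix([[1, 2, 3], [4, 5, 6]])
value = m[1, 2]      # equivalent to m.xs[1][2]
print(single, pair, value)
```

Because the class unpacks idxs[0] and idxs[1], indexing with a single integer such as m[1] would raise a TypeError in this version.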
Adding a __repr__ Method
Currently when I print out the Matrix object, it just shows a default object representation. From the Python docs:
Called by the repr() built-in function to compute the “official” string representation of an object. If at all possible, this should look like a valid Python expression that could be used to recreate an object with the same value (given an appropriate environment). If this is not possible, a string of the form <…some useful description…> should be returned. The return value must be a string object.
m
<__main__.Matrix at 0x7ce8aeae24d0>
NumPy has a prettier object representation:
arr = np.array(lst1)
arr
array([[1, 2, 3],
       [4, 5, 6]])
I’ll add a __repr__ method that does that (thanks Claude and ChatGPT!):
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
    def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)
        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"
That looks much prettier, albeit with a lot more code in the class definition.
m = Matrix(lst1)
m
Matrix([[1, 2, 3],
        [4, 5, 6]])
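The Python docs quoted above suggest that a repr should, where possible, look like an expression that recreates the object. As a rough check of that property, the sketch below evaluates the string produced by the same __repr__ logic and compares the result. The compact class and the explicit namespace passed to eval are my own framing for illustration:

```python
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __repr__(self):
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        formatted_rows = (
            '[' + ', '.join(e.rjust(w) for e, w in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        return "Matrix([" + ',\n        '.join(formatted_rows) + "])"

m = Matrix([[1, 2, 3], [4, 5, 6]])
text = repr(m)
# Newlines and padding are legal inside brackets, so the string
# parses as an ordinary constructor call.
rebuilt = eval(text, {"Matrix": Matrix})
print(rebuilt.xs == m.xs)
```

This round trip holds for integers. Since the method uses str on each element, values such as strings would be printed without quotes and would not evaluate back to the same Matrix.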
I’ll walk through each line in this __repr__ method. We start with a nested list comprehension which maintains the nested list structure of self.xs but replaces the values with strings:
str_matrix = [[str(elem) for elem in row] for row in m.xs]
str_matrix
[['1', '2', '3'], ['4', '5', '6']]
To determine the “width” of each column, the following list comprehension does three things:
- Transposes str_matrix with zip(*str_matrix)
- Calculates the number of characters in each value of a column with map(len, col)
- Takes the largest of those lengths for each column with max()
col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
col_widths
[1, 1, 1]
list(zip(*str_matrix)) # transposed matrix
[('1', '4'), ('2', '5'), ('3', '6')]
[list(map(len, col)) for col in zip(*str_matrix)] # width of each value in each column
[[1, 1], [1, 1], [1, 1]]
[max(map(len, col)) for col in zip(*str_matrix)] # maximum width of each column value
[1, 1, 1]
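With single-digit values every width is 1, which hides what the comprehension is doing. The same three steps are easier to follow on values of mixed width. The numbers in this sketch are made up for illustration:

```python
rows = [[1, 2], [2, 30], [4, 500]]   # made-up values for illustration
str_rows = [[str(elem) for elem in row] for row in rows]

columns = list(zip(*str_rows))                          # transpose: one tuple per column
lengths = [[len(s) for s in col] for col in columns]    # characters per value
col_widths = [max(col_lengths) for col_lengths in lengths]

print(columns)
print(lengths)
print(col_widths)
```

Here the second column needs a width of 3 because of '500', while the first column still only needs 1.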
The next line is a generator expression with a second generator expression nested inside the .join call. We iterate over each row in str_matrix and then do the following:
- zip together the row and the col_widths list and iterate over it
- for each elem, width in zip(row, col_widths) we call elem.rjust(width), which right-justifies elem inside a string of the given width
formatted_rows is a generator of strings, one for each row in the Matrix:
formatted_rows = (
    '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
    for row in str_matrix
)
list(formatted_rows)
['[1, 2, 3]', '[4, 5, 6]']
Building up the formatted_rows logic line by line:
[row for row in str_matrix]
[['1', '2', '3'], ['4', '5', '6']]
[list(((elem, width) for elem, width in zip(row, col_widths))) for row in str_matrix]
[[('1', 1), ('2', 1), ('3', 1)], [('4', 1), ('5', 1), ('6', 1)]]
[list((elem.rjust(width) for elem, width in zip(row, col_widths))) for row in str_matrix]
[['1', '2', '3'], ['4', '5', '6']]
[', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) for row in str_matrix]
['1, 2, 3', '4, 5, 6']
['[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']' for row in str_matrix]
['[1, 2, 3]', '[4, 5, 6]']
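Since every width above is 1, rjust has no visible effect in these outputs. A small sketch with assumed column widths shows the padding it adds when a value is narrower than its column:

```python
col_widths = [1, 3]            # assumed widths for a two-column matrix
row = ['4', '50']

padded = [elem.rjust(width) for elem, width in zip(row, col_widths)]
line = '[' + ', '.join(padded) + ']'
untouched = '500'.rjust(2)     # rjust never truncates a longer string

print(padded)
print(line)
print(untouched)
```

The value '50' is padded on the left with one space to fill a width of 3, which is what keeps the columns aligned across rows.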
The final line adds a newline between rows and indents each subsequent row so that it prints prettily. Note there are 8 spaces after the newline character \n, matching the 8 characters in the string Matrix([:
formatted_rows = (
    '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
    for row in str_matrix
)
matrix_str = ',\n        '.join(formatted_rows)
print(f"Matrix([{matrix_str}])")
Matrix([[1, 2, 3],
        [4, 5, 6]])
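One way to avoid counting spaces by hand is to derive the indent from the prefix itself. This is just a sketch of an alternative, not what the class above does. The names prefix and indent are my own:

```python
prefix = "Matrix(["
indent = ' ' * len(prefix)     # 8 spaces, one per character of the prefix

rows = ['[1, 2, 3]', '[4, 5, 6]']
text = prefix + (',\n' + indent).join(rows) + '])'
print(text)
```

If the class were renamed, the indentation would follow automatically.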
Testing out the __repr__ method on a Matrix with a different size and different values (note the right-adjustment of the values):
lst2 = [[1, 2], [2, 30], [4, 50]]
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])
Pretty! And pretty complicated.
Adding a shape Property
One of the most useful properties in PyTorch and NumPy is the shape of an array or tensor. I constantly use it throughout my coding to make sure I’m dealing with the appropriately shaped tensors and arrays.
np.array(lst2).shape
(3, 2)
I’ll implement a shape property for my Matrix:
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
    @property
    def shape(self): return len(self.xs), len(self.xs[0])
    def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)
        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])
m.shape
(3, 2)
Explaining my one-liner: I’m assuming the Matrix is always rectangular, so the number of rows is the len of the list xs and the number of columns is the len of any one of the rows (I use the first).
len(m.xs), len(m.xs[0])
(3, 2)
That’s it for the shape property! Simple.
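To make the effect of the @property decorator concrete, here is a small sketch using a trimmed-down copy of the class. It shows that shape is read without parentheses, is recomputed on each access, and cannot be assigned to because no setter is defined:

```python
class Matrix:
    def __init__(self, xs): self.xs = xs
    @property
    def shape(self): return len(self.xs), len(self.xs[0])

m = Matrix([[1, 2], [2, 30], [4, 50]])
shape_before = m.shape       # attribute-style access, no parentheses
m.xs.append([8, 70])         # mutate the underlying list
shape_after = m.shape        # recomputed from self.xs on each access

try:
    m.shape = (2, 2)         # no setter defined, so this raises
    assignable = True
except AttributeError:
    assignable = False

print(shape_before, shape_after, assignable)
```

Note that len(self.xs[0]) raises an IndexError for an empty list of rows, so this version assumes the Matrix has at least one row.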
Adding min and max Methods
The last pair of methods I’ll implement are a min and a max method similar to NumPy.
arr = np.array(lst2)
arr
array([[ 1,  2],
       [ 2, 30],
       [ 4, 50]])
arr.min()
1
arr.min(axis=0)
array([1, 2])
arr.min(axis=1)
array([1, 2, 4])
arr.max()
50
arr.max(axis=0)
array([ 4, 50])
arr.max(axis=1)
array([ 2, 30, 50])
from itertools import chain
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
    @property
    def shape(self): return len(self.xs), len(self.xs[0])
    def min(self, axis=None):
        if axis is None: return min(chain(*self.xs))
        # axis=0 reduces over the rows: one value per column
        elif axis == 0: return [min(col) for col in zip(*self.xs)]
        # axis=1 reduces over the columns: one value per row
        elif axis == 1: return [min(row) for row in self.xs]
        else: raise IndexError("Matrix only has two axes")
    def max(self, axis=None):
        if axis is None: return max(chain(*self.xs))
        elif axis == 0: return [max(col) for col in zip(*self.xs)]
        elif axis == 1: return [max(row) for row in self.xs]
        else: raise IndexError("Matrix only has two axes")
    def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)
        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])
When no axis is specified, min returns the minimum value of the flattened Matrix (flattened using chain(*self.xs)).
m.min()
1
When axis=0, min collapses the rows and returns the minimum value in each column, matching NumPy. I transpose the list with zip(*self.xs) and take the min of each column. Applying the built-in min directly to self.xs would not do this: it compares whole rows lexicographically and returns a single row, which equals the column-wise minimum for lst2 only because every column of lst2 increases from top to bottom.
m.min(axis=0)
[1, 2]
When axis=1, min collapses the columns and returns the minimum value in each row:
m.min(axis=1)
[1, 2, 4]
Finally, if the axis is some other value, I raise an IndexError:
m.min(axis=3)
IndexError: Matrix only has two axes
max operates similarly:
m.max(), m.max(axis=0), m.max(axis=1)
(50, [4, 50], [2, 30, 50])
m.max(axis=3)
IndexError: Matrix only has two axes
That’s it! Another relatively simple implementation.
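The values in lst2 happen to increase down every column, so they are not a demanding test of axis handling. Here is a minimal sketch for checking axis semantics on a matrix whose columns are not sorted, using a hypothetical reduce_axis helper on plain nested lists. It follows NumPy's convention that axis=0 yields one value per column and axis=1 one value per row, and it also shows what the built-in min returns when applied to the whole nested list:

```python
def reduce_axis(rows, fn, axis=None):
    # Hypothetical helper, not part of the Matrix class
    if axis is None: return fn(x for row in rows for x in row)
    elif axis == 0: return [fn(col) for col in zip(*rows)]   # one value per column
    elif axis == 1: return [fn(row) for row in rows]         # one value per row
    else: raise IndexError("Matrix only has two axes")

rows = [[3, 9], [1, 7], [5, 2]]   # made-up values, columns not sorted

col_mins = reduce_axis(rows, min, axis=0)
row_mins = reduce_axis(rows, min, axis=1)
overall = reduce_axis(rows, min)
whole_row = min(rows)             # compares rows lexicographically, returns one row

print(col_mins, row_mins, overall, whole_row)
```

On this data the column-wise minimum is [1, 2], while min(rows) returns the row [1, 7], so the two only agree when the data happens to be ordered favorably.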
Final Thoughts
Creating my own Matrix class was fun and educational. I didn’t expect the object representation as a string to be so involved—and I haven’t even added functionality like NumPy or PyTorch where they “summarize” very long arrays and tensors instead of listing out all of the values.
I also learned a few things along the way:
- Using zip to transpose a list
- Using the @property decorator
- Justifying strings with rjust
- Raising the built-in IndexError for situations when the user passes an incompatible axis argument
I’m assuming that we’ll be building on this Matrix class further down the road in Part 2, so it was helpful to get a jump on that and start thinking about how this class would function right now.
I hope you enjoyed this blog post! Follow me on Twitter @vishal_learner.