Implementing a Matrix Class in Python

python
deep learning
machine learning
In this blog post I implement a Matrix class following Lesson 10 of the fastai course (Part 2) and add some additional functionality.
Author

Vishal Bakshi

Published

September 2, 2024

Background

In Lesson 10 of the fastai course (Part 2), Jeremy introduces the following Matrix class, which allows PyTorch- and NumPy-like indexing (m[i, j]) into a nested Python list:

class Matrix:
  def __init__(self, xs): self.xs = xs
  def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
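Why is the class needed at all? A plain nested list rejects this kind of indexing, because list indices must be integers or slices. A quick check of what happens without the class (the list values are just for illustration):

```python
lst = [[1, 2, 3], [4, 5, 6]]

try:
    lst[1, 2]  # Python passes the tuple (1, 2) to list.__getitem__
except TypeError as e:
    print(e)  # list indices must be integers or slices, not tuple

# Chained indexing works, and it is what Matrix.__getitem__ does for us
print(lst[1][2])  # 6
```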

In this notebook, I’ll expand this class with the following additions to further mimic NumPy arrays:

  • A __repr__ method, which returns a displayable/printable representation of the Matrix
  • A shape property, which returns the shape of the Matrix
  • A min method
  • A max method
import numpy as np

Expanding the Matrix Class

I’ll start by defining the Matrix class as done in Lesson 10 and instantiating a Matrix object:

class Matrix:
  def __init__(self, xs): self.xs = xs
  def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
lst1 = [
    [1, 2, 3],
    [4, 5, 6]
]
m = Matrix(lst1)

I’ll illustrate the __getitem__ method:

m[1,2]
6
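Indexing with m[1,2] works because Python packs the comma-separated indices into a single tuple before calling __getitem__. A small sketch with a hypothetical Spy class that just returns its argument makes this visible. It also shows a side effect of this two-line __getitem__ (my own observation, not from the lesson): slices pass straight through, which works within a row but does not select a column.

```python
class Spy:
    # Hypothetical helper: returns whatever object __getitem__ receives
    def __getitem__(self, idxs): return idxs

s = Spy()
print(s[1, 2])   # (1, 2)
print(s[0, 1:])  # (0, slice(1, None, None))

class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]

m = Matrix([[1, 2, 3], [4, 5, 6]])
print(m[0, 1:])  # [2, 3]: slicing within a single row works
print(m[:, 0])   # [1, 2, 3]: xs[:][0] is the first row, not the first column
```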

Adding a __repr__ Method

Currently, when I display the Matrix object, it only shows the default object representation:

m
<__main__.Matrix at 0x7ce8aeae24d0>

From the Python docs on __repr__:

Called by the repr() built-in function to compute the “official” string representation of an object. If at all possible, this should look like a valid Python expression that could be used to recreate an object with the same value (given an appropriate environment). If this is not possible, a string of the form <…some useful description…> should be returned. The return value must be a string object.
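Taken literally, the docs' advice can be satisfied with a one-liner. The sketch below is my own minimal alternative, not the Lesson 10 code: it returns an expression that recreates the object, and since the class defines no __str__, print() falls back to the same __repr__.

```python
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
    # Minimal __repr__: a valid expression that recreates the object
    def __repr__(self): return f"Matrix({self.xs!r})"

m = Matrix([[1, 2, 3], [4, 5, 6]])
print(repr(m))  # Matrix([[1, 2, 3], [4, 5, 6]])
# print() calls str(), which uses __repr__ when no __str__ is defined
print(m)        # Matrix([[1, 2, 3], [4, 5, 6]])
```

It is correct but puts every row on one line, which is why a prettier, aligned version is worth the extra code.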

NumPy has a prettier object representation:

arr = np.array(lst1)
arr
array([[1, 2, 3],
       [4, 5, 6]])

I’ll add a __repr__ method that produces a similar representation (thanks Claude and ChatGPT!):

class Matrix:
  def __init__(self, xs): self.xs = xs
  def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
  def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]

        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]

        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )

        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)

        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"

That’s a lot more code in the class definition, but the output looks much prettier:

m = Matrix(lst1)
m
Matrix([[1, 2, 3],
        [4, 5, 6]])
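Because the newline and padding spaces fall between list elements, where Python ignores whitespace, the pretty output is still a valid Python expression, in the spirit of the docs quoted above. A quick round-trip check using the same __repr__ body:

```python
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]
    def __repr__(self):
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        matrix_str = ',\n        '.join(formatted_rows)
        return f"Matrix([{matrix_str}])"

m = Matrix([[1, 2, 3], [4, 5, 6]])
# eval() parses the multi-line string back into an equivalent Matrix
m2 = eval(repr(m))
print(m2.xs == m.xs)  # True
```

This holds for numbers. Since the method calls str() rather than repr() on each element, a Matrix of strings would print without quotes (Matrix([[a]]) for [['a']]) and would not round-trip.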

I’ll walk through each line in this __repr__ method. We start with a nested list comprehension, which maintains the nested list structure of self.xs but replaces the values with their string representations:

str_matrix = [[str(elem) for elem in row] for row in m.xs]
str_matrix
[['1', '2', '3'], ['4', '5', '6']]

To determine the “width” of each column, the following list comprehension:

  • Transposes str_matrix with zip(*str_matrix), so that each col is a tuple holding one column’s values
  • Calculates the number of characters in each of those values with map(len, col)
  • Takes the largest of those lengths for each column with max()
col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
col_widths
[1, 1, 1]
list(zip(*str_matrix)) # transposed matrix
[('1', '4'), ('2', '5'), ('3', '6')]
[list(map(len, col)) for col in zip(*str_matrix)] # width of each value in each column
[[1, 1], [1, 1], [1, 1]]
[max(map(len, col)) for col in zip(*str_matrix)] # maximum width of each column value
[1, 1, 1]
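With single-digit values every width is 1, which hides what the comprehension is for. The same line on a small made-up list with wider values, plus one caveat about zip:

```python
str_matrix = [[str(elem) for elem in row] for row in [[1, 200], [30, 4]]]
# zip(*str_matrix) yields the columns ('1', '30') and ('200', '4')
col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
print(col_widths)  # [2, 3]

# Caveat: zip stops at the shortest row, so a ragged list silently loses values
print(list(zip(*[[1, 2, 3], [4, 5]])))  # [(1, 4), (2, 5)]
```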

The next line is a nested generator expression. The outer expression iterates over each row in str_matrix, and for each row the inner expression inside the .join call does the following:

  • zips together the row and the col_widths list and iterates over the pairs
  • for each elem, width in zip(row, col_widths), calls elem.rjust(width), which right-justifies elem by padding it on the left with spaces up to the given width
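A few quick calls show how str.rjust behaves, including the edge cases that matter for column alignment:

```python
print(repr('7'.rjust(3)))       # '  7': padded on the left
print(repr('30'.rjust(2)))      # '30': already wide enough, unchanged
print(repr('1234'.rjust(2)))    # '1234': never truncated
print(repr('7'.rjust(3, '0')))  # '007': optional fill character
```

Since each width is the maximum length in its column, no value is ever longer than its width, so the no-truncation case never arises in __repr__. str.ljust and str.center work the same way for left and centered alignment.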

formatted_rows is a generator (hence the list() call below) that yields a string representation of each row in the Matrix:

formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
list(formatted_rows)
['[1, 2, 3]', '[4, 5, 6]']
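One consequence of formatted_rows being a generator: it can only be consumed once. A second list() call on the same object returns an empty list, which is why the next step has to recreate formatted_rows before joining it:

```python
str_matrix = [['1', '2', '3'], ['4', '5', '6']]
col_widths = [1, 1, 1]
formatted_rows = (
    '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
    for row in str_matrix
)
first = list(formatted_rows)
second = list(formatted_rows)  # the generator is already exhausted
print(first)   # ['[1, 2, 3]', '[4, 5, 6]']
print(second)  # []
```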

Building up the formatted_rows logic line by line:

[row for row in str_matrix]
[['1', '2', '3'], ['4', '5', '6']]
[list(((elem, width) for elem, width in zip(row, col_widths))) for row in str_matrix]
[[('1', 1), ('2', 1), ('3', 1)], [('4', 1), ('5', 1), ('6', 1)]]
[list((elem.rjust(width) for elem, width in zip(row, col_widths))) for row in str_matrix]
[['1', '2', '3'], ['4', '5', '6']]
[', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) for row in str_matrix]
['1, 2, 3', '4, 5, 6']
['[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']' for row in str_matrix]
['[1, 2, 3]', '[4, 5, 6]']

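The final line joins the rows with a comma, a newline, and indentation so that the Matrix prints prettily. Note there are 8 spaces after the newline character \n to match the 8 characters in the string Matrix([. I recreate formatted_rows first because calling list() on it above exhausted the generator: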

formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )

matrix_str = ',\n        '.join(formatted_rows)
print(f"Matrix([{matrix_str}])")
Matrix([[1, 2, 3],
        [4, 5, 6]])
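The 8 spaces are hard-coded, so a subclass with a longer name would misalign its rows. A possible tweak, sketched here as my own assumption rather than anything from the lesson, computes the indent from the class name instead (IntMatrix is a hypothetical subclass for illustration):

```python
class Matrix:
    def __init__(self, xs): self.xs = xs
    def __repr__(self):
        str_matrix = [[str(elem) for elem in row] for row in self.xs]
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )
        # Hypothetical tweak: derive the prefix, and therefore the indent, from the class name
        prefix = f"{type(self).__name__}(["
        matrix_str = (',\n' + ' ' * len(prefix)).join(formatted_rows)
        return f"{prefix}{matrix_str}])"

class IntMatrix(Matrix): pass  # hypothetical subclass for illustration

print(Matrix([[1, 2], [3, 4]]))     # rows indented by 8 spaces
print(IntMatrix([[1, 2], [3, 4]]))  # rows indented by 11 spaces
```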

Testing out the __repr__ method on a Matrix with a different size and different values (note the right-adjustment of the values):

lst2 = [[1, 2], [2, 30], [4, 50]]
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])

Pretty! And pretty complicated.

Adding a shape Property

One of the most useful properties in PyTorch and NumPy is the shape of an array or tensor. I constantly use it throughout my coding to make sure I’m dealing with the appropriately shaped tensors and arrays.

np.array(lst2).shape
(3, 2)

I’ll implement a shape property for my Matrix:

class Matrix:
  def __init__(self, xs): self.xs = xs
  def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]

  @property
  def shape(self): return len(self.xs), len(self.xs[0])

  def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]

        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]

        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )

        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)

        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])
m.shape
(3, 2)

Explaining my one-liner: assuming the Matrix is rectangular (the class doesn’t enforce this), the number of rows is the len of the list xs and the number of columns is the len of any one of its rows, such as the first.

len(m.xs), len(m.xs[0])
(3, 2)
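Since shape only looks at the first row, a ragged list would report a misleading shape. One option, a hypothetical check that is not part of the class above, is to validate the rows when the Matrix is created:

```python
class Matrix:
    def __init__(self, xs):
        # Hypothetical check: every row must have the same length as the first
        if any(len(row) != len(xs[0]) for row in xs):
            raise ValueError("all rows must have the same length")
        self.xs = xs

    @property
    def shape(self): return len(self.xs), len(self.xs[0])

print(Matrix([[1, 2], [2, 30], [4, 50]]).shape)  # (3, 2)
try:
    Matrix([[1, 2, 3], [4, 5]])
except ValueError as e:
    print(e)  # all rows must have the same length
```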

That’s it for the shape property! Simple.

Adding min and max Methods

The last pair of methods I’ll implement are min and max methods similar to NumPy’s.

arr = np.array(lst2)
arr
array([[ 1,  2],
       [ 2, 30],
       [ 4, 50]])
arr.min()
1
arr.min(axis=0)
array([1, 2])
arr.min(axis=1)
array([1, 2, 4])
arr.max()
50
arr.max(axis=0)
array([ 4, 50])
arr.max(axis=1)
array([ 2, 30, 50])
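In NumPy, axis names the dimension that gets collapsed: axis=0 reduces over the rows and returns one value per column, while axis=1 reduces over the columns and returns one value per row. The same reductions in plain Python on the lst2 values, plus a pitfall worth knowing before implementing them:

```python
lst2 = [[1, 2], [2, 30], [4, 50]]

# axis=0: collapse the rows -> one result per column (transpose with zip first)
print([min(col) for col in zip(*lst2)])  # [1, 2]
# axis=1: collapse the columns -> one result per row
print([max(row) for row in lst2])        # [2, 30, 50]

# Pitfall: min() on a list of rows compares whole rows lexicographically
print(min([[1, 50], [2, 3]]))  # [1, 50], not the column minima [1, 3]
```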
from itertools import chain

class Matrix:
  def __init__(self, xs): self.xs = xs
  def __getitem__(self, idxs): return self.xs[idxs[0]][idxs[1]]

  @property
  def shape(self): return len(self.xs), len(self.xs[0])

  def min(self, axis=None):
    if axis is None: return min(chain(*self.xs))                # min of all values
    elif axis == 0: return [min(col) for col in zip(*self.xs)]  # min of each column
    elif axis == 1: return [min(row) for row in self.xs]        # min of each row
    else: raise IndexError("Matrix only has two axes")

  def max(self, axis=None):
    if axis is None: return max(chain(*self.xs))                # max of all values
    elif axis == 0: return [max(col) for col in zip(*self.xs)]  # max of each column
    elif axis == 1: return [max(row) for row in self.xs]        # max of each row
    else: raise IndexError("Matrix only has two axes")

  def __repr__(self):
        # Convert each element to string
        str_matrix = [[str(elem) for elem in row] for row in self.xs]

        # Compute max column widths
        col_widths = [max(map(len, col)) for col in zip(*str_matrix)]

        # Format each row with proper padding
        formatted_rows = (
            '[' + ', '.join(elem.rjust(width) for elem, width in zip(row, col_widths)) + ']'
            for row in str_matrix
        )

        # Join rows with newline and proper indentation
        matrix_str = ',\n        '.join(formatted_rows)

        # Add the class name and wrapping brackets
        return f"Matrix([{matrix_str}])"
m = Matrix(lst2)
m
Matrix([[1,  2],
        [2, 30],
        [4, 50]])

When no axis is specified, min returns the minimum value of the flattened Matrix (flattened using chain(*self.xs)).

m.min()
1

When axis=0, min reduces over the rows, as in NumPy, and returns the minimum value in each column. I first transpose the list with zip(*self.xs) and then take the min of each column. Note that min(self.xs) alone is not enough: Python compares lists lexicographically, so it returns the row that sorts first (here [1, 2]), which equals the column minima of lst2 only by coincidence.

m.min(axis=0)
[1, 2]

When axis=1, min reduces over the columns and returns the minimum value in each row:

m.min(axis=1)
[1, 2, 4]

Finally, if the axis is some other value, I raise an IndexError:

m.min(axis=3)
IndexError: Matrix only has two axes

max operates similarly:

m.max(), m.max(axis=0), m.max(axis=1)
(50, [4, 50], [2, 30, 50])
m.max(axis=3)
IndexError: Matrix only has two axes
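Since min and max share the same structure, one possible refactor is a single reduction helper that takes the function to apply. The sketch below is my own variation: _reduce is a hypothetical name, and it also accepts negative axes the way NumPy does for 2-D arrays (-1 is the last axis, -2 the first).

```python
from itertools import chain

class Matrix:
    def __init__(self, xs): self.xs = xs

    # Hypothetical helper shared by min and max; fn is the built-in min or max
    def _reduce(self, fn, axis=None):
        if axis is None: return fn(chain(*self.xs))  # reduce over all values
        if axis not in (0, 1, -1, -2): raise IndexError("Matrix only has two axes")
        axis %= 2  # map -2 -> 0 and -1 -> 1
        if axis == 0: return [fn(col) for col in zip(*self.xs)]  # one value per column
        return [fn(row) for row in self.xs]                      # one value per row

    def min(self, axis=None): return self._reduce(min, axis)
    def max(self, axis=None): return self._reduce(max, axis)

m = Matrix([[1, 2], [2, 30], [4, 50]])
print(m.min(), m.min(axis=0), m.max(axis=-1))  # 1 [1, 2] [2, 30, 50]
```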

That’s it! Another relatively simple implementation.

Final Thoughts

Creating my own Matrix class was fun and educational. I didn’t expect building the string representation to be so involved, and I haven’t even added the kind of functionality NumPy and PyTorch have for “summarizing” very long arrays and tensors instead of listing out all of their values.

I also learned a few things along the way:

  • Using zip to transpose a list
  • Using the @property decorator
  • Right-justifying strings with rjust
  • Raising the built-in IndexError when the user passes an incompatible axis argument

I’m assuming that we’ll build on this Matrix class later in Part 2, so it was helpful to get a head start and begin thinking now about how this class should function.

I hope you enjoyed this blog post! Follow me on Twitter @vishal_learner.