In [None]:
from pathlib import Path
import numpy as np

# Parse the ASCII board into a 0/1 int8 array: each whitespace-separated token
# of the file is one row, "#" marks a live cell (1) and any other character a
# dead cell (0). str.split() with no argument also drops blank lines; note that
# a row containing spaces would be split into several rows.
grid = np.array(
    [
        [char == "#" for char in row_text]
        for row_text in Path("game_of_life.txt").read_text().split()
    ],
    dtype=np.int8,
)

In [None]:
import matplotlib.pyplot as plt

# Initial board. With the reversed gray colormap, larger values (live cells)
# are drawn darker. The trailing ";" suppresses the repr of the last call.
plt.imshow(grid, cmap="gray_r")
plt.title("Generation 0");

In [None]:
def update(grid: np.ndarray) -> np.ndarray:
    """Advance a Game of Life board by one generation (pure-Python reference).

    Parameters
    ----------
    grid : np.ndarray
        2-D array of cells; presumably 0 (dead) / 1 (alive), as produced by the
        loading cell.

    Returns
    -------
    np.ndarray
        New int8 array of the same shape holding the next generation. Cells
        outside the board are treated as dead on every edge.
    """
    n, m = grid.shape
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clamp the window start at 0: a slice such as grid[-1:2] wraps to the
        # last row and comes back empty, which would miscount every cell in the
        # top row / left column. Slice ends past the last index clamp by themselves.
        r0 = max(row - 1, 0)
        for col in range(m):
            c0 = max(col - 1, 0)
            # Sum of the (clipped) 3x3 window minus the cell itself.
            live_neighbors = (
                np.sum(grid[r0 : row + 2, c0 : col + 2]) - grid[row, col]
            )
            if live_neighbors < 2 or live_neighbors > 3:
                # Under- or over-population: dead next generation.
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                # Exactly three neighbors bring a dead cell to life.
                next_grid[row, col] = 1
            else:
                # 2 neighbors, or 3 around a live cell: state is kept.
                next_grid[row, col] = grid[row, col]

    return next_grid


# The board after one step of the pure-Python implementation.
plt.imshow(update(grid), cmap="gray_r")
plt.title("Generation 1");

In [None]:
from matplotlib import ticker

fig, ax = plt.subplots(1, 4, figsize=(20, 5))

# Compute each generation once from the previous one; nesting
# update(update(update(grid))) per panel recomputed the earlier generations.
generations = [grid]
while len(generations) < len(ax):
    generations.append(update(generations[-1]))

for step, (ax_, board) in enumerate(zip(ax, generations)):
    ax_.imshow(board, cmap="gray_r")
    ax_.set_title(f"Generation {step}")
    # Hide ticks and tick labels on both axes.
    ax_.xaxis.set_major_locator(ticker.NullLocator())
    ax_.yaxis.set_major_locator(ticker.NullLocator())

# Figure.set_tight_layout is deprecated since Matplotlib 3.6; tight_layout()
# applies the same adjustment and works on old and new versions.
fig.tight_layout()


In [None]:
# Baseline timing: the pure-Python double loop, calling np.sum once per cell.
%timeit update(grid)

In [None]:
import numba


@numba.jit(nopython=True)
def update_numba(grid: np.ndarray) -> np.ndarray:
    """Numba-compiled version of the slice-and-np.sum Game of Life step.

    Parameters
    ----------
    grid : np.ndarray
        2-D array of cells; presumably 0 (dead) / 1 (alive).

    Returns
    -------
    np.ndarray
        New int8 array of the same shape holding the next generation. Cells
        outside the board are treated as dead on every edge.
    """
    n, m = grid.shape
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clamp the window start at 0: grid[-1:2] wraps to the last row and
        # yields an empty slice, which would miscount top-row/left-column cells.
        r0 = max(row - 1, 0)
        for col in range(m):
            c0 = max(col - 1, 0)
            # Sum of the (clipped) 3x3 window minus the cell itself.
            live_neighbors = (
                np.sum(grid[r0 : row + 2, c0 : col + 2]) - grid[row, col]
            )
            # Index with [row, col] rather than [row][col], which builds an
            # intermediate row view on every access.
            if live_neighbors < 2 or live_neighbors > 3:
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                next_grid[row, col] = 1
            else:
                next_grid[row, col] = grid[row, col]

    return next_grid

In [None]:
# Visual check of the Numba version against the Generation 1 plot above.
# This call also triggers update_numba's JIT compilation.
plt.imshow(update_numba(grid), cmap="gray_r")
plt.title("Generation 1 (update_numba)");

In [None]:
# update_numba was already called by the plotting cell above, so its lazy JIT
# compilation is not part of this measurement.
%timeit update_numba(grid)

In [None]:
%load_ext Cython

In [None]:
%%cython -a
import numpy as np

def update_cython(grid):
    """Untyped Cython build of the slice-and-np.sum Game of Life step.

    grid: 2-D array of cells, presumably 0 (dead) / 1 (alive).
    Returns a new int8 NumPy array of the same shape with the next generation;
    cells outside the board are treated as dead on every edge.
    """
    n, m = grid.shape
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clamp the window start at 0: grid[-1:2] wraps to the last row and
        # yields an empty slice, which would miscount top-row/left-column cells.
        r0 = max(row - 1, 0)
        for col in range(m):
            c0 = max(col - 1, 0)
            live_neighbors = np.sum(grid[r0:row+2, c0:col+2]) - grid[row, col]
            if live_neighbors < 2 or live_neighbors > 3:
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                next_grid[row, col] = 1
            else:
                next_grid[row, col] = grid[row, col]

    return next_grid

In [None]:
# Use the same colormap as the other plots so the boards are directly comparable.
plt.imshow(update_cython(grid), cmap="gray_r")
plt.title("Generation 1 (update_cython)");

In [None]:
# update_cython has no type declarations; it still calls np.sum on a slice for
# every cell.
%timeit update_cython(grid)

In [None]:
%%cython -a
import numpy as np

def update_cython2(grid):
    """Cython Game of Life step with C-int loop variables; still uses np.sum.

    grid: 2-D array of cells, presumably 0 (dead) / 1 (alive).
    Returns a new int8 NumPy array of the same shape with the next generation;
    cells outside the board are treated as dead on every edge.
    """
    cdef int n, m, row, col, r0, c0, live_neighbors
    n, m = grid.shape
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clamp the window start at 0: grid[-1:2] wraps to the last row and
        # yields an empty slice, which would miscount top-row/left-column cells.
        r0 = max(row - 1, 0)
        for col in range(m):
            c0 = max(col - 1, 0)
            live_neighbors = np.sum(grid[r0:row+2, c0:col+2]) - grid[row, col]
            if live_neighbors < 2 or live_neighbors > 3:
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                next_grid[row, col] = 1
            else:
                next_grid[row, col] = grid[row, col]

    return next_grid

In [None]:
# update_cython2 declares the loop variables and neighbor count as C ints
# (cdef), but still calls np.sum on a slice for every cell.
%timeit update_cython2(grid)

In [None]:
%%cython -a
import numpy as np
from cython import boundscheck, wraparound

@boundscheck(False)
@wraparound(False)
def update_cython3(signed char[:, :] grid):
    """Game of Life step over typed memoryviews, without per-cell np.sum.

    grid: 2-D buffer of signed char (np.int8), presumably 0 (dead) / 1 (alive).
    Returns a new int8 NumPy array of the same shape with the next generation;
    cells outside the board are treated as dead on every edge.
    """
    cdef int n, m, row, col, r, c, r_lo, r_hi, c_lo, c_hi, live_neighbors
    cdef signed char[:, :] next_grid

    n = grid.shape[0]
    m = grid.shape[1]
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clip the 3x3 window to the board as half-open [lo, hi) ranges.
        # Bounds checking and wraparound are disabled above, so indexing
        # grid[row - 1, ...] at row 0 or grid[row + 1, ...] at the last row
        # would read memory outside the buffer.
        r_lo = row - 1 if row > 0 else 0
        r_hi = row + 2 if row < n - 1 else n
        for col in range(m):
            c_lo = col - 1 if col > 0 else 0
            c_hi = col + 2 if col < m - 1 else m

            live_neighbors = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    live_neighbors += grid[r, c]
            # The window includes the cell itself; it is not its own neighbor.
            live_neighbors -= grid[row, col]

            if live_neighbors < 2 or live_neighbors > 3:
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                next_grid[row, col] = 1
            else:
                next_grid[row, col] = grid[row, col]

    # Hand back an ndarray (like the other update_* versions), not the raw memoryview.
    return np.asarray(next_grid)

In [None]:
# Visual check of update_cython3, the implementation defined in the previous
# cell, using the same colormap as the other plots.
plt.imshow(update_cython3(grid), cmap="gray_r")
plt.title("Generation 1 (update_cython3)");

In [None]:
# update_cython3 reads the board through typed memoryviews with bounds checking
# and wraparound disabled, and adds neighbors directly instead of calling np.sum.
%timeit update_cython3(grid)

In [None]:
import numba

@numba.jit(nopython=True)
def update_numba(grid: np.ndarray) -> np.ndarray:
    """Numba Game of Life step that sums neighbors directly instead of via np.sum.

    NOTE: this deliberately re-binds ``update_numba``; from here on the name
    refers to this version, and its first call triggers a new JIT compilation.

    Parameters
    ----------
    grid : np.ndarray
        2-D array of cells; presumably 0 (dead) / 1 (alive).

    Returns
    -------
    np.ndarray
        New int8 array of the same shape holding the next generation. Cells
        outside the board are treated as dead on every edge.
    """
    n, m = grid.shape
    next_grid = np.zeros((n, m), dtype=np.int8)

    for row in range(n):
        # Clip the 3x3 window to the board as half-open [lo, hi) ranges.
        # Numba does not bounds-check by default, so grid[row + 1, ...] on the
        # last row would read past the array, and grid[row - 1, ...] on row 0
        # would silently wrap to the last row.
        r_lo = max(row - 1, 0)
        r_hi = min(row + 2, n)
        for col in range(m):
            c_lo = max(col - 1, 0)
            c_hi = min(col + 2, m)

            live_neighbors = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    live_neighbors += grid[r, c]
            # The window includes the cell itself; it is not its own neighbor.
            live_neighbors -= grid[row, col]

            if live_neighbors < 2 or live_neighbors > 3:
                # Fewer than 2 or more than 3 live neighbors: dead next generation.
                next_grid[row, col] = 0
            elif live_neighbors == 3 and grid[row, col] == 0:
                # Exactly 3 live neighbors around a dead cell: it becomes alive.
                next_grid[row, col] = 1
            else:
                # 2 live neighbors (any cell) or 3 around a live cell: keep state.
                next_grid[row, col] = grid[row, col]

    return next_grid

In [None]:
# @numba.jit compiles lazily on the first call, and the update_numba defined in
# the previous cell has not been called yet. Warm it up once so the first
# %timeit run does not include compilation time and skew the mean/std.
update_numba(grid)
%timeit update_numba(grid)