#!/usr/bin/env python from __future__ import division from __future__ import print_function import numpy as np from mpi4py import MPI from time import time #=============================================================================# my_N = 3000 my_M = 3000 #=============================================================================# NORTH = 0 SOUTH = 1 EAST = 2 WEST = 3 def pprint(string, comm=MPI.COMM_WORLD): if comm.rank == 0: print(string) if __name__ == "__main__": comm = MPI.COMM_WORLD mpi_rows = int(np.floor(np.sqrt(comm.size))) mpi_cols = comm.size // mpi_rows if mpi_rows*mpi_cols > comm.size: mpi_cols -= 1 if mpi_rows*mpi_cols > comm.size: mpi_rows -= 1 pprint("Creating a %d x %d processor grid..." % (mpi_rows, mpi_cols) ) ccomm = comm.Create_cart( (mpi_rows, mpi_cols), periods=(True, True), reorder=True) my_mpi_row, my_mpi_col = ccomm.Get_coords( ccomm.rank ) neigh = [0,0,0,0] neigh[NORTH], neigh[SOUTH] = ccomm.Shift(0, 1) neigh[EAST], neigh[WEST] = ccomm.Shift(1, 1) # Create matrices my_A = np.random.normal(size=(my_N, my_M)).astype(np.float32) my_B = np.random.normal(size=(my_N, my_M)).astype(np.float32) my_C = np.zeros_like(my_A) tile_A = my_A tile_B = my_B tile_A_ = np.empty_like(my_A) tile_B_ = np.empty_like(my_A) req = [None, None, None, None] t0 = time() for r in range(mpi_rows): req[EAST] = ccomm.Isend(tile_A , neigh[EAST]) req[WEST] = ccomm.Irecv(tile_A_, neigh[WEST]) req[SOUTH] = ccomm.Isend(tile_B , neigh[SOUTH]) req[NORTH] = ccomm.Irecv(tile_B_, neigh[NORTH]) #t0 = time() my_C += np.dot(tile_A, tile_B) #t1 = time() req[0].Waitall(req) #t2 = time() #print("Time computing %6.2f %6.2f" % (t1-t0, t2-t1)) comm.barrier() t_total = time()-t0 t0 = time() np.dot(tile_A, tile_B) t_serial = time()-t0 pprint(78*"=") pprint("Computed (serial) %d x %d x %d in %6.2f seconds" % (my_M, my_M, my_N, t_serial)) pprint(" ... expecting parallel computation to take %6.2f seconds" % (mpi_rows*mpi_rows*mpi_cols*t_serial / comm.size)) pprint("Computed (parallel) %d x %d x %d in %6.2f seconds" % (mpi_rows*my_M, mpi_rows*my_M, mpi_cols*my_N, t_total)) #print "[%d] (%d,%d): %s" % (comm.rank, my_mpi_row, my_mpi_col, neigh) comm.barrier()