from matplotlib import pyplot as plt
import numpy as np
import torch
from IPython.display import HTML, display
def set_default(figsize=(10, 10), dpi=100):
    """Apply the dark plotting theme and the default figure geometry.

    Activates the 'dark_background' and 'bmh' style sheets, then forces black
    axes and figure backgrounds and sets the default figure size and DPI.

    Args:
        figsize: default figure size stored in the 'figure' rc group.
        dpi: default figure resolution stored in the 'figure' rc group.
    """
    plt.style.use(['dark_background', 'bmh'])
    black = 'k'
    plt.rc('axes', facecolor=black)
    # One rc call sets all three 'figure' keys at once.
    plt.rc('figure', facecolor=black, figsize=figsize, dpi=dpi)
def plot_data(X, y, d=0, auto=False, zoom=1):
    """Scatter 2-D points coloured by ``y`` on a square canvas with hidden axes.

    Args:
        X: object exposing ``.cpu()``/``.numpy()``; columns 0 and 1 of the
            resulting array are used as the x and y coordinates.
        y: object exposing ``.cpu()``; passed as the scatter colour values and
            mapped through the Spectral colormap.
        d: unused by this function.
        auto: only when this is exactly ``True`` is 'equal' axis scaling
            requested after the limits have been set.
        zoom: multiplier applied to the base limits (-1.1, 1.1, -1.1, 1.1).
    """
    points = X.cpu().numpy()
    labels = y.cpu()
    plt.scatter(points[:, 0], points[:, 1], c=labels, s=20, cmap=plt.cm.Spectral)

    plt.axis('square')
    limits = np.array((-1.1, 1.1, -1.1, 1.1)) * zoom
    plt.axis(limits)
    if auto is True:
        plt.axis('equal')
    plt.axis('off')

    # Thin dark guide lines through the origin, drawn beneath the points.
    guide = dict(color='.15', lw=1, zorder=0)
    plt.axvline(0, ymin=0, **guide)
    plt.axhline(0, xmin=0, **guide)
def plot_model(X, y, model):
    """Draw the model's predicted-class regions and overlay the data points.

    The model is moved to the CPU (``model.cpu()`` is called on the object
    passed in), evaluated on a regular grid covering [-1.1, 1.1) in both
    directions with step 0.01, and the arg-max over axis 1 of its output is
    drawn as filled contours. ``plot_data(X, y)`` is then called on top.

    Args:
        X: points forwarded to ``plot_data``.
        y: labels forwarded to ``plot_data``.
        model: callable with a ``.cpu()`` method that maps a float tensor of
            grid points (one 2-D point per row) to per-row scores.
    """
    model.cpu()

    ticks = np.arange(-1.1, 1.1, 0.01)
    xx, yy = np.meshgrid(ticks, ticks)
    # One (x, y) pair per row.
    grid = np.vstack((xx.reshape(-1), yy.reshape(-1))).T

    with torch.no_grad():
        inputs = torch.from_numpy(grid).float()
        scores = model(inputs).detach()

    labels = np.argmax(scores, axis=1).reshape(xx.shape)
    plt.contourf(xx, yy, labels, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y)
# Colour lookup image, read once when this module is imported (the relative
# path is resolved against the current working directory). show_scatterplot
# indexes it with pairs of integer indices to obtain one colour per point.
zieger = plt.imread('res/ziegler.png')
def show_scatterplot(X, colors, title='', axis=True):
    """Scatter 2-D points using colours looked up in the ``zieger`` image.

    Args:
        X: object exposing ``.numpy()``; columns 0 and 1 of the resulting
            array are the x and y coordinates.
        colors: 2-column index array; column 0 and column 1 are used together
            to index the first two axes of the module-level ``zieger`` image.
        title: text placed above the plot.
        axis: when truthy, thin guide lines through the origin are drawn.
    """
    rgb = zieger[colors[:, 0], colors[:, 1]]
    points = X.numpy()

    plt.axis('equal')
    plt.scatter(points[:, 0], points[:, 1], c=rgb, s=30)
    plt.title(title)
    plt.axis('off')

    if not axis:
        return

    # Thin dark guide lines through the origin, drawn beneath the points.
    guide = dict(color='.15', lw=1, zorder=0)
    plt.axvline(0, ymin=0, **guide)
    plt.axhline(0, xmin=0, **guide)
def plot_bases(bases, plotting=True, width=0.04):
    """Convert ``bases`` to (start, offset) form in place and optionally draw it.

    NOTE: ``bases`` is modified in place. Rows from index 2 onwards have the
    first two rows subtracted from them, so after the call rows 0-1 hold the
    arrow start points and rows 2-3 hold the arrow offsets. With
    ``plotting=False`` this mutation is the only effect of the call.

    Args:
        bases: mutable array/tensor whose rows 0-3 are unpacked as the
            arguments of ``plt.arrow``.
        plotting: when truthy, draw a red arrow (row 0 + row 2) and a green
            arrow (row 1 + row 3) and store the returned artists on
            ``plot_bases.a`` and ``plot_bases.b``.
        width: arrow width forwarded to ``plt.arrow``.
    """
    bases[2:] -= bases[:2]
    if not plotting:
        return

    arrow_style = dict(width=width, zorder=10, alpha=1., length_includes_head=True)
    plot_bases.a = plt.arrow(*bases[0], *bases[2], color='r', **arrow_style)
    plot_bases.b = plt.arrow(*bases[1], *bases[3], color='g', **arrow_style)
# Function attributes that hold the arrows most recently drawn by plot_bases;
# they stay None until plot_bases is called with plotting enabled.
plot_bases.a = None
plot_bases.b = None
def show_mat(mat, vect, prod, threshold=-1):
    """Show a matrix, a vector and a product side by side as colour maps.

    Three subplots sharing the y axis are created with width ratios 5:1:1.
    The matrix and the vector use colour limits (-1, 1); the product uses
    (``threshold``, 1).

    Args:
        mat: tensor shown in the left panel; ``size(0)`` and ``size(1)`` are
            reported in its title.
        vect: tensor shown in the middle panel; its ``numel()`` is reported.
        prod: tensor shown in the right panel; its ``numel()`` is reported.
        threshold: lower colour limit of the product panel.
    """
    fig, axes = plt.subplots(
        1, 3, sharex=False, sharey=True,
        gridspec_kw={'width_ratios': [5, 1, 1]},
    )
    ax_mat, ax_vect, ax_prod = axes

    # Colour maps
    img_mat = ax_mat.matshow(mat.numpy(), clim=(-1, 1))
    ax_vect.matshow(vect.numpy(), clim=(-1, 1))
    img_prod = ax_prod.matshow(prod.numpy(), clim=(threshold, 1))

    # Titles
    rows, cols = mat.size(0), mat.size(1)
    ax_mat.set_title(f'A: {rows} \u00D7 {cols}')
    ax_vect.set_title(f'a^(i): {vect.numel()}')
    ax_prod.set_title(f'p: {prod.numel()}')

    # The two narrow panels get no x ticks.
    for narrow_ax in (ax_vect, ax_prod):
        narrow_ax.set_xticks(tuple())

    # The matrix colour bar is attached to the vector panel (same limits).
    fig.colorbar(img_mat, ax=ax_vect)
    fig.colorbar(img_prod, ax=ax_prod)

    # Fix the shared y axis so the longer of prod / vect is fully visible.
    tallest = max(len(prod), len(vect))
    ax_mat.set_ylim(bottom=tallest - 0.5)
# Named hex colours available to users of this module.
colors = {
    'aqua': '#8dd3c7',
    'yellow': '#ffffb3',
    'lavender': '#bebada',
    'red': '#fb8072',
    'blue': '#80b1d3',
    'orange': '#fdb462',
    'green': '#b3de69',
    'pink': '#fccde5',
    'grey': '#d9d9d9',
    'violet': '#bc80bd',
    'unk1': '#ccebc5',
    'unk2': '#ffed6f',
}
def _cstr(s, color='black'):
if s == ' ':
return f' '
else:
return f'{s} '
# print html
def _print_color(t):
    """Display a sequence of (text, colour) pairs as one HTML fragment.

    Each pair is converted with ``_cstr``; the pieces are concatenated and
    shown through IPython's ``display(HTML(...))``.

    Args:
        t: iterable of ``(text, colour)`` pairs.
    """
    fragments = (_cstr(token, color=shade) for token, shade in t)
    display(HTML(''.join(fragments)))
# get appropriate color for value
def _get_clr(value):
colors = ('#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8',
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e')
value = int((value * 100) / 5)
if value == len(colors): value -= 1 # fixing bugs...
return colors[value]
def _visualise_values(output_values, result_list):
    """Display each entry of ``result_list`` coloured by its matching value.

    For every position ``i`` in ``output_values`` the pair
    ``(result_list[i], _get_clr(output_values[i]))`` is built, and the full
    list of pairs is handed to ``_print_color``.

    Args:
        output_values: sized, indexable sequence of values passed to
            ``_get_clr``.
        result_list: indexable sequence of the items to display; it must be at
            least as long as ``output_values``.
    """
    # enumerate (not zip) so a too-short result_list still raises IndexError.
    text_colours = [
        (result_list[position], _get_clr(value))
        for position, value in enumerate(output_values)
    ]
    _print_color(text_colours)
def print_colourbar():
    """Display a 20-step colour legend running from -2.50 to 2.50.

    Each of the 20 evenly spaced values is formatted with two decimals and
    paired with the colour ``_get_clr`` returns for that value rescaled by
    ``(x + 2.5) / 5``; the pairs are shown with ``_print_color``.
    """
    to_print = []
    for x in torch.linspace(-2.5, 2.5, 20):
        label = f'{x:.2f}'
        shade = _get_clr((x + 2.5) / 5)
        to_print.append((label, shade))
    _print_color(to_print)
# Let's only focus on the last time step for now
# First, the cell state (Long term memory)
def plot_state(data, state, b, decoder):
    """Display the decoded sequence once per state unit, coloured by activation.

    ``decoder`` is applied to ``data[b, :, :].numpy()`` to obtain the items to
    show. For every index ``s`` along dimension 2 of ``state``, the sigmoid of
    ``state[:, b, s]`` is computed and its trailing ``len(decoded)`` entries
    are passed, together with the decoded items, to ``_visualise_values``.

    Args:
        data: tensor indexed as ``data[b, :, :]``.
        state: tensor indexed as ``state[:, b, s]``; ``len(state)`` is treated
            as the padded sequence length.
        b: index selecting one entry along dimension 0 of ``data`` and
            dimension 1 of ``state``.
        decoder: callable turning the selected array into a sized sequence.
    """
    actual_data = decoder(data[b, :, :].numpy())
    tokens = list(actual_data)
    # Skip the leading entries of the state that exceed the decoded length.
    start = len(state) - len(actual_data)

    for unit in range(state.size(2)):
        activations = torch.sigmoid(state[:, b, unit])
        _visualise_values(activations[start:], tokens)