{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "# Translating French to English with Pytorch" ] }, { "cell_type": "code", "execution_count": 53, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "%matplotlib inline\n", "import re, pickle, collections, bcolz, numpy as np, keras, sklearn, math, operator" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "from gensim.models import word2vec\n", "\n", "import torch, torch.nn as nn\n", "from torch.autograd import Variable\n", "from torch import optim\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "path='/data/datasets/fr-en-109-corpus/'\n", "dpath = 'data/translate/'" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "heading_collapsed": true }, "source": [ "## Prepare corpus" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "The French-English parallel corpus can be downloaded from http://www.statmt.org/wmt10/training-giga-fren.tar. It was created by Chris Callison-Burch, who crawled millions of web pages and then used 'a set of simple heuristics to transform French URLs onto English URLs (i.e. replacing \"fr\" with \"en\" and about 40 other hand-written rules), and assume that these documents are translations of each other'." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "fname=path+'giga-fren.release2.fixed'\n", "en_fname = fname+'.en'\n", "fr_fname = fname+'.fr'" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "To make this problem a little simpler so we can train our model more quickly, we'll just learn to translate questions that begin with 'Wh' (e.g. what, why, where which). Here are our regexps that filter the sentences we want." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "52331" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "re_eq = re.compile('^(Wh[^?.!]+\\?)')\n", "re_fq = re.compile('^([^?.!]+\\?)')\n", "\n", "lines = ((re_eq.search(eq), re_fq.search(fq)) \n", " for eq, fq in zip(open(en_fname), open(fr_fname)))\n", "\n", "qs = [(e.group(), f.group()) for e,f in lines if e and f]; len(qs)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[('What is light ?', 'Qu’est-ce que la lumière?'),\n", " ('Who are we?', 'Où sommes-nous?'),\n", " ('Where did we come from?', \"D'où venons-nous?\"),\n", " ('What would we do without it?', 'Que ferions-nous sans elle ?'),\n", " ('What is the absolute location (latitude and longitude) of Badger, Newfoundland and Labrador?',\n", " 'Quelle sont les coordonnées (latitude et longitude) de Badger, à Terre-Neuve-etLabrador?'),\n", " ('What is the major aboriginal group on Vancouver Island?',\n", " 'Quel est le groupe autochtone principal sur l’île de Vancouver?')]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qs[:6]" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Because it takes a while to load the data, we save the results to make it easier to load in later." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "pickle.dump(qs, open(dpath+'fr-en-qs.pkl', 'wb'))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "qs = pickle.load(open(dpath+'fr-en-qs.pkl', 'rb'))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "en_qs, fr_qs = zip(*qs)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Because we are translating at word level, we need to tokenize the text first. (Note that it is also possible to translate at character level, which doesn't require tokenizing.) There are many tokenizers available, but we found we got best results using these simple heuristics." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "re_apos = re.compile(r\"(\\w)'s\\b\") # make 's a separate word\n", "re_mw_punc = re.compile(r\"(\\w[’'])(\\w)\") # other ' in a word creates 2 words\n", "re_punc = re.compile(\"([\\\"().,;:/_?!—])\") # add spaces around punctuation\n", "re_mult_space = re.compile(r\" *\") # replace multiple spaces with just one\n", "\n", "def simple_toks(sent):\n", " sent = re_apos.sub(r\"\\1 's\", sent)\n", " sent = re_mw_punc.sub(r\"\\1 \\2\", sent)\n", " sent = re_punc.sub(r\" \\1 \", sent).replace('-', ' ')\n", " sent = re_mult_space.sub(' ', sent)\n", " return sent.lower().split()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[['qu’', 'est', 'ce', 'que', 'la', 'lumière', '?'],\n", " ['où', 'sommes', 'nous', '?'],\n", " [\"d'\", 'où', 'venons', 'nous', '?'],\n", " ['que', 'ferions', 'nous', 'sans', 'elle', '?']]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fr_qtoks = list(map(simple_toks, fr_qs)); fr_qtoks[:4]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[['what', 'is', 'light', '?'],\n", " ['who', 'are', 'we', '?'],\n", " ['where', 'did', 'we', 'come', 'from', '?'],\n", " ['what', 'would', 'we', 'do', 'without', 'it', '?']]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en_qtoks = list(map(simple_toks, en_qs)); en_qtoks[:4]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "['rachel', \"'s\", 'baby', 'is', 'cuter', 'than', 'other', \"'s\", '.']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "simple_toks(\"Rachel's baby is cuter than other's.\")" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Special tokens used to pad the end of sentences, and to mark the start of a sentence." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "PAD = 0; SOS = 1" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Enumerate the unique words (*vocab*) in the corpus, and also create the reverse map (word->index). Then use this mapping to encode every sentence as a list of int indices." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def toks2ids(sents):\n", " voc_cnt = collections.Counter(t for sent in sents for t in sent)\n", " vocab = sorted(voc_cnt, key=voc_cnt.get, reverse=True)\n", " vocab.insert(PAD, \"\")\n", " vocab.insert(SOS, \"\")\n", " w2id = {w:i for i,w in enumerate(vocab)}\n", " ids = [[w2id[t] for t in sent] for sent in sents]\n", " return ids, vocab, w2id, voc_cnt" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "fr_ids, fr_vocab, fr_w2id, fr_counts = toks2ids(fr_qtoks)\n", "en_ids, en_vocab, en_w2id, en_counts = toks2ids(en_qtoks)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "heading_collapsed": true }, "source": [ "## Word vectors" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Stanford's GloVe word vectors can be downloaded from https://nlp.stanford.edu/projects/glove/ (in the code below we have preprocessed them into a bcolz array). We use these because each individual word has a single word vector, which is what we need for translation. Word2vec, on the other hand, often uses multi-word phrases." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def load_glove(loc):\n", " return (bcolz.open(loc+'.dat')[:],\n", " pickle.load(open(loc+'_words.pkl','rb'), encoding='latin1'),\n", " pickle.load(open(loc+'_idx.pkl','rb'), encoding='latin1'))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "en_vecs, en_wv_word, en_wv_idx = load_glove('/data/datasets/nlp/glove/results/6B.100d')\n", "en_w2v = {w: en_vecs[en_wv_idx[w]] for w in en_wv_word}\n", "n_en_vec, dim_en_vec = en_vecs.shape" ] }, { "cell_type": "code", "execution_count": 87, "metadata": { "collapsed": false, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([-0.32306999, -0.87616003, 0.21977 , 0.25268 , 0.22976001,\n", " 0.73879999, -0.37954 , -0.35306999, -0.84368998, -1.11129999,\n", " -0.30265999, 0.33177999, -0.25113001, 0.30447999, -0.077491 ,\n", " -0.89815003, 0.092496 , -1.14069998, -0.58323997, 0.66869003,\n", " -0.23122001, -0.95854998, 0.28262001, -0.078848 , 0.75314999,\n", " 0.26583999, 0.34220001, -0.33949 , 0.95608002, 0.065641 ,\n", " 0.45747 , 0.39835 , 0.57964998, 0.39267001, -0.21851 ,\n", " 0.58794999, -0.55998999, 0.63367999, -0.043983 , -0.68730998,\n", " -0.37841001, 0.38025999, 0.61641002, -0.88269001, -0.12346 ,\n", " -0.37928 , -0.38317999, 0.23868001, 0.66850001, -0.43320999,\n", " -0.11065 , 0.081723 , 1.15690005, 0.78957999, -0.21223 ,\n", " -2.3211 , -0.67806 , 0.44560999, 0.65706998, 0.1045 ,\n", " 0.46217 , 0.19912 , 0.25802001, 0.057194 , 0.53443003,\n", " -0.43133 , -0.34311 , 0.59789002, -0.58416998, 0.068995 ,\n", " 0.23943999, -0.85180998, 0.30379 , -0.34176999, -0.25746 ,\n", " -0.031101 , -0.16285001, 0.45168999, -0.91627002, 0.64521003,\n", " 0.73281002, -0.22752 , 0.30226001, 0.044801 , -0.83740997,\n", " 0.55005997, -0.52506 , -1.73570001, 0.47510001, -0.70486999,\n", " 0.056939 , -0.71319997, 0.089623 , 0.41394001, -1.33630002,\n", " -0.61914998, -0.33089 , -0.52881002, 
0.16483 , -0.98878002], dtype=float32)" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en_w2v['king']" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "For French word vectors, we're using those from http://fauconnier.github.io/index.html" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "w2v_path='/data/datasets/nlp/frWac_non_lem_no_postag_no_phrase_200_skip_cut100.bin'\n", "fr_model = word2vec.Word2Vec.load_word2vec_format(w2v_path, binary=True)\n", "fr_voc = fr_model.vocab\n", "dim_fr_vec = 200" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "We need to map each word index in our vocabs to its word vector. Not every word in our vocabs will be in our word vectors, since our tokenization approach won't be identical to the one the word vector creators used - in these cases we simply create a random vector." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def create_emb(w2v, targ_vocab, dim_vec):\n", " vocab_size = len(targ_vocab)\n", " emb = np.zeros((vocab_size, dim_vec))\n", " found=0\n", "\n", " for i, word in enumerate(targ_vocab):\n", " try: emb[i] = w2v[word]; found+=1\n", " except KeyError: emb[i] = np.random.normal(scale=0.6, size=(dim_vec,))\n", "\n", " return emb, found" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "((19549, 100), 17201)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "en_embs, found = create_emb(en_w2v, en_vocab, dim_en_vec); en_embs.shape, found" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "((26709, 200), 21878)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fr_embs, found = create_emb(fr_model, fr_vocab, dim_fr_vec); fr_embs.shape, found" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "heading_collapsed": true }, "source": [ "## Prep data" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Each sentence has to be of equal length. Keras has a convenient function `pad_sequences` to truncate and/or pad each sentence as required - even though we're not using keras for the neural net, we can still use any functions from it we need!"
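] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "To see what the 'post' padding and truncation settings do, here's a quick check on a made-up toy batch (not the real data) before we apply `pad_sequences` below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "from keras.preprocessing.sequence import pad_sequences\n", "\n", "# toy batch of hypothetical ids: 'post' pads/truncates at the end of each sequence\n", "# -> array([[1, 2, 3, 0, 0, 0], [4, 5, 0, 0, 0, 0]])\n", "pad_sequences([[1, 2, 3], [4, 5]], 6, 'int64', 'post', 'post')"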
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "((52331, 30), (52331, 30), (19549, 100))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from keras.preprocessing.sequence import pad_sequences\n", "\n", "maxlen = 30\n", "en_padded = pad_sequences(en_ids, maxlen, 'int64', \"post\", \"post\")\n", "fr_padded = pad_sequences(fr_ids, maxlen, 'int64', \"post\", \"post\")\n", "en_padded.shape, fr_padded.shape, en_embs.shape" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "And of course we need to separate our training and test sets..." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "[(47097, 30), (5234, 30), (47097, 30), (5234, 30)]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn import model_selection\n", "fr_train, fr_test, en_train, en_test = model_selection.train_test_split(\n", " fr_padded, en_padded, test_size=0.1)\n", "\n", "[o.shape for o in (fr_train, fr_test, en_train, en_test)]" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Here's an example of a French and English sentence, after encoding and padding." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(array([ 272, 22, 3074, 9126, 5, 1600, 3, 2407, 5, 45, 4997,\n", " 11, 7, 33, 15, 10, 5, 3596, 2, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0]),\n", " array([ 4, 2817, 257, 3, 3925, 2725, 2107, 11, 7,\n", " 4, 8, 90, 11835, 2, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0]))" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fr_train[0], en_train[0]" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Basic encoder-decoder" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "def long_t(arr): return Variable(torch.LongTensor(arr)).cuda()" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "fr_emb_t = torch.FloatTensor(fr_embs).cuda()\n", "en_emb_t = torch.FloatTensor(en_embs).cuda()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "def create_emb(emb_mat, non_trainable=False):\n", " output_size, emb_size = emb_mat.size()\n", " emb = nn.Embedding(output_size, emb_size)\n", " emb.load_state_dict({'weight': emb_mat})\n", " if non_trainable:\n", " for param in emb.parameters(): \n", " param.requires_grad = False\n", " return emb, emb_size, output_size" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Turning a sequence into a representation can be done using an RNN (called the 'encoder'. 
This approach is useful because RNNs are able to keep track of state and memory, which is obviously important in forming a complete understanding of a sentence.\n", "* `bidirectional=True` passes the original sequence through an RNN, and the reversed sequence through a different RNN, and concatenates the results. This allows us to look both forwards and backwards.\n", "* We do this because in language, things that happen later often influence what came before (e.g. in Spanish, \"el chico, la chica\" means \"the boy, the girl\"; the word for \"the\" is determined by the gender of the noun, which comes after it)." ] }, { "cell_type": "code", "execution_count": 57, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "class EncoderRNN(nn.Module):\n", " def __init__(self, embs, hidden_size, n_layers=2):\n", " super(EncoderRNN, self).__init__()\n", " self.emb, emb_size, output_size = create_emb(embs, True)\n", " self.n_layers = n_layers\n", " self.hidden_size = hidden_size\n", " self.gru = nn.GRU(emb_size, hidden_size, batch_first=True, num_layers=n_layers)\n", "# ,bidirectional=True)\n", " \n", " def forward(self, input, hidden):\n", " return self.gru(self.emb(input), hidden)\n", "\n", " def initHidden(self, batch_size):\n", " return Variable(torch.zeros(self.n_layers, batch_size, self.hidden_size))" ] }, { "cell_type": "code", "execution_count": 58, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "def encode(inp, encoder):\n", " batch_size, input_length = inp.size()\n", " hidden = encoder.initHidden(batch_size).cuda()\n", " enc_outputs, hidden = encoder(inp, hidden)\n", " return long_t([SOS]*batch_size), enc_outputs, hidden " ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Finally, we arrive at a vector representation of the sequence which captures everything we need to translate it. We feed this vector into more RNNs (the 'decoder'), which try to generate the labels. After this, we make a classification for what each word is in the output sequence." ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "class DecoderRNN(nn.Module):\n", " def __init__(self, embs, hidden_size, n_layers=2):\n", " super(DecoderRNN, self).__init__()\n", " self.emb, emb_size, output_size = create_emb(embs)\n", " self.gru = nn.GRU(emb_size, hidden_size, batch_first=True, num_layers=n_layers)\n", " self.out = nn.Linear(hidden_size, output_size)\n", " \n", " def forward(self, inp, hidden):\n", " emb = self.emb(inp).unsqueeze(1)\n", " res, hidden = self.gru(emb, hidden)\n", " res = F.log_softmax(self.out(res[:,0]))\n", " return res, hidden" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "For a neural translation task, a plain encoding/decoding technique shows accuracy decay: larger input sequences result in less accuracy.\n", "\n", "This can be mitigated using an attentional model." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "heading_collapsed": true }, "source": [ "### Adding broadcasting to Pytorch" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "Using *broadcasting* makes a lot of numerical programming far simpler. 
Here's a couple of examples, using numpy:" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(array([1, 2, 3]), (3,))" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v=np.array([1,2,3]); v, v.shape" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(array([[1, 2, 3],\n", " [2, 4, 6],\n", " [3, 6, 9]]), (3, 3))" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m=np.array([v,v*2,v*3]); m, m.shape" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 2, 4, 6],\n", " [ 3, 6, 9],\n", " [ 4, 8, 12]])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m+v" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(array([[1],\n", " [2],\n", " [3]]), (3, 1))" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v1=np.expand_dims(v,-1); v1, v1.shape" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "collapsed": false, "deletable": true, "editable": true, "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 2, 3, 4],\n", " [ 4, 6, 8],\n", " [ 6, 9, 12]])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m+v1" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true, "hidden": true }, "source": [ "But Pytorch doesn't support broadcasting. So let's add it to the basic operators, and to a general tensor dot product:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def unit_prefix(x, n=1):\n", " for i in range(n): x = x.unsqueeze(0)\n", " return x\n", "\n", "def align(x, y, start_dim=2):\n", " xd, yd = x.dim(), y.dim()\n", " if xd > yd: y = unit_prefix(y, xd - yd)\n", " elif yd > xd: x = unit_prefix(x, yd - xd)\n", "\n", " xs, ys = list(x.size()), list(y.size())\n", " nd = len(ys)\n", " for i in range(start_dim, nd):\n", " td = nd-i-1\n", " if ys[td]==1: ys[td] = xs[td]\n", " elif xs[td]==1: xs[td] = ys[td]\n", " return x.expand(*xs), y.expand(*ys)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def aligned_op(x,y,f): return f(*align(x,y,0))\n", "\n", "def add(x, y): return aligned_op(x, y, operator.add)\n", "def sub(x, y): return aligned_op(x, y, operator.sub)\n", "def mul(x, y): return aligned_op(x, y, operator.mul)\n", "def div(x, y): return aligned_op(x, y, operator.truediv)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "collapsed": true, "deletable": true, "editable": true, "hidden": true }, "outputs": [], "source": [ "def dot(x, y):\n", " assert(1