{
 "metadata": {
  "name": "",
  "signature": "sha256:2b9c7c850052aff1fb14b68ffc3cec8d18bd1d7e35b271cbe9c08076e64c9928"
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import cPickle as pickle\n",
      "import msgpack\n",
      "\n",
      "import numpy as np"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Load vocabulary w/ word frequencies.\n",
      "# Each entry is used below as word -> (id, frequency).\n",
      "with open('wmt11.head.vocab', 'rb') as f:\n",
      "    vocab = msgpack.load(f)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Load requisite vector data\n",
      "# NOTE: unpickling can execute arbitrary code -- only load vector\n",
      "# files that come from a trusted source.\n",
      "with open('wmt11.head.vectors', 'rb') as f:\n",
      "    W = pickle.load(f)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 4
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Reverse lookup: row index in W -> word\n",
      "id2word = dict((word_id, word) for word, (word_id, _) in vocab.iteritems())"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 7
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Remove context word vectors first, so that only the rows we keep\n",
      "# get normalized\n",
      "W = W[:len(vocab), :]\n",
      "\n",
      "# Normalize word vectors to unit length (in place)\n",
      "for i, row in enumerate(W):\n",
      "    W[i, :] /= np.linalg.norm(row)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 10
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def most_similar(positive, negative, topn=10, freq_threshold=5):\n",
      "    \"\"\"Return up to `topn` words closest to an analogy query.\n",
      "\n",
      "    The query vector is the re-normalized mean of the vectors of the\n",
      "    `positive` words and the negated vectors of the `negative` words.\n",
      "    Every row of W is scored by its dot product with the query, which is\n",
      "    the cosine similarity because the rows of W were normalized above.\n",
      "\n",
      "    positive, negative -- sequences of words; each must be a key of\n",
      "                          `vocab` (an unknown word raises KeyError)\n",
      "    topn               -- maximum number of results to return\n",
      "    freq_threshold     -- minimum frequency a word needs to be returned\n",
      "\n",
      "    Returns a list of (word, similarity) pairs, best match first. The\n",
      "    query words themselves are never returned. Raises ValueError when\n",
      "    both word lists are empty.\n",
      "    \"\"\"\n",
      "    if not positive and not negative:\n",
      "        raise ValueError('need at least one positive or negative word')\n",
      "\n",
      "    # Build a \"mean\" vector for the given positive and negative terms\n",
      "    mean_vecs = [W[vocab[word][0]] for word in positive]\n",
      "    mean_vecs += [-1 * W[vocab[word][0]] for word in negative]\n",
      "\n",
      "    mean = np.array(mean_vecs).mean(axis=0)\n",
      "    mean /= np.linalg.norm(mean)\n",
      "\n",
      "    # Now calculate cosine similarity between this mean vector and all others\n",
      "    dists = np.dot(W, mean)\n",
      "\n",
      "    query_words = set(positive) | set(negative)\n",
      "    result = []\n",
      "    # Walk the full ranking (best first) and stop as soon as enough words\n",
      "    # have passed the filters; a fixed-size candidate window could come\n",
      "    # up short once rare words are dropped.\n",
      "    for i in np.argsort(dists)[::-1]:\n",
      "        if len(result) >= topn:\n",
      "            break\n",
      "        word = id2word[i]\n",
      "        # vocab values are (id, frequency) pairs: compare the frequency,\n",
      "        # not the whole entry\n",
      "        if vocab[word][1] >= freq_threshold and word not in query_words:\n",
      "            result.append((word, dists[i]))\n",
      "    return result"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 11
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "most_similar(['king', 'woman'], ['man'], topn=50)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 14,
       "text": [
        "[('queen', 0.69478105985944205),\n",
        " ('ace', 0.63016505472136219),\n",
        " ('trick', 0.62198680411172658),\n",
        " ('library', 0.61596180822198343),\n",
        " ('diamond', 0.61546379436428578),\n",
        " ('club', 0.60882108049620698),\n",
        " ('horse', 0.60577931043391597),\n",
        " ('ski', 0.5980682567370863),\n",
        " ('tennis', 0.59252997663757134),\n",
        " ('chef', 0.58578732345724127),\n",
        " ('museum', 0.58238877554666368),\n",
        " ('grandmother', 0.58148552464037506),\n",
        " ('diamonds', 0.58011253208849856),\n",
        " ('crown', 0.57997983286899146),\n",
        " ('seller', 0.57651369635738636),\n",
        " ('tip', 0.57446540473288965),\n",
        " ('oldest', 0.56849683935598727),\n",
        " ('holder', 0.56698181344304011),\n",
        " ('row', 0.56597681090513596),\n",
        " ('Museum', 0.56365025428845172),\n",
        " ('royal', 0.56291276425071346),\n",
        " ('Royal', 0.56191759370337424),\n",
        " ('farmer', 0.55962264699238262),\n",
        " ('Queen', 0.55947426321308247),\n",
        " ('colony', 0.55792198467607856),\n",
        " ('Maine', 0.55782081129452066),\n",
        " ('hat', 0.55772124209691432),\n",
        " ('dog', 0.5566658071093793),\n",
        " ('Valley', 0.55537812887550264),\n",
        " ('soccer', 0.55403076872031942),\n",
        " ('cinema', 0.55362014730217401),\n",
        " ('Latvia', 0.55191882205612497),\n",
        " ('hero', 0.55175383201232631),\n",
        " ('dancer', 0.55130889459560439),\n",
        " ('spade', 0.55042516938080366),\n",
        " ('Country', 0.54924256358808599),\n",
        " ('Yale', 0.54889249198494516),\n",
        " ('Rock', 0.54845215690150428),\n",
        " ('girlfriend', 0.54695823638084806),\n",
        " ('pool', 0.54691472405799435),\n",
        " ('neighbor', 0.54683901446532124),\n",
        " ('bars', 0.54670398577022916),\n",
        " ('bottle', 0.54548588559461253),\n",
        " ('pope', 0.54363205675334925),\n",
        " ('boyfriend', 0.54230260301709221),\n",
        " ('classic', 0.54154692864168852),\n",
        " ('interior', 0.54121481559609896),\n",
        " ('Buffalo', 0.54109694727311264),\n",
        " ('buyer', 0.5408952311440024),\n",
        " ('sheriff', 0.54027326795728747)]"
       ]
      }
     ],
     "prompt_number": 14
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "most_similar(['brought', 'seek'], ['bring'], topn=50)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 15,
       "text": [
        "[('sought', 0.80168320406931981),\n",
        " ('seeking', 0.73662888334926047),\n",
        " ('forced', 0.69273739435205384),\n",
        " ('attempted', 0.68510971171255386),\n",
        " ('tried', 0.67516210714164604),\n",
        " ('allowed', 0.65480577594618783),\n",
        " ('urged', 0.64988576500767947),\n",
        " ('managed', 0.64642237086872134),\n",
        " ('seeks', 0.64400589953863585),\n",
        " ('refused', 0.63527139994723147),\n",
        " ('intended', 0.63348152287487691),\n",
        " ('unable', 0.62647201702998501),\n",
        " ('demanded', 0.62625912225250269),\n",
        " ('prompted', 0.62515185408955964),\n",
        " ('threatened', 0.62393983356386451),\n",
        " ('determined', 0.62224734077632959),\n",
        " ('attempting', 0.6181202137691465),\n",
        " ('hoped', 0.61761471480150942),\n",
        " ('prepared', 0.61306593357078376),\n",
        " ('encouraged', 0.61228972301898099),\n",
        " ('requested', 0.60900154224998204),\n",
        " ('followed', 0.60838821784825281),\n",
        " ('helped', 0.60657759212041662),\n",
        " ('attempt', 0.60619642784126782),\n",
        " ('failed', 0.6045095128678335),\n",
        " ('led', 0.60300298435099864),\n",
        " ('opted', 0.59973601096877438),\n",
        " ('granted', 0.59786114263781998),\n",
        " ('initiated', 0.59435889304397749),\n",
        " ('chosen', 0.59148716010685476),\n",
        " ('faced', 0.58759291122220392),\n",
        " ('wanted', 0.58484842439765283),\n",
        " ('refusing', 0.58473430252602809),\n",
        " ('addressed', 0.58405417697747952),\n",
        " ('offered', 0.58351159606145453),\n",
        " ('asking', 0.58349246210416927),\n",
        " ('rejected', 0.58123114186229263),\n",
        " ('decided', 0.58080346748275247),\n",
        " ('pledged', 0.57872006649827579),\n",
        " ('pressed', 0.57805915462397062),\n",
        " ('ordered', 0.57777715792227946),\n",
        " ('received', 0.57658184971738702),\n",
        " ('designed', 0.57650776001629578),\n",
        " ('persuaded', 0.57468989479783739),\n",
        " ('urging', 0.57295698093987468),\n",
        " ('accepted', 0.57247322897335085),\n",
        " ('allowing', 0.57049684140481094),\n",
        " ('able', 0.56909602033660478),\n",
        " ('calling', 0.56550493868341545),\n",
        " ('required', 0.565493188089782)]"
       ]
      }
     ],
     "prompt_number": 15
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
   ],
   "metadata": {}
  }
 ]
}