{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#!/usr/bin/env python\n", "import import_ipynb" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "importing Jupyter notebook from q3_word2vec.ipynb\n", "importing Jupyter notebook from q1_softmax.ipynb\n", "importing Jupyter notebook from q2_gradcheck.ipynb\n", "importing Jupyter notebook from q2_sigmoid.ipynb\n", "importing Jupyter notebook from q3_sgd.ipynb\n", "iter_ 100: 18.604024\n", "iter_ 200: 18.779389\n", "iter_ 300: 18.968864\n", "iter_ 400: 19.209857\n", "iter_ 500: 19.156670\n", "iter_ 600: 19.243421\n", "iter_ 700: 19.440181\n", "iter_ 800: 19.520359\n", "iter_ 900: 19.588859\n", "iter_ 1000: 19.882518\n", "iter_ 1100: 20.161624\n", "iter_ 1200: 20.121685\n", "iter_ 1300: 20.228471\n", "iter_ 1400: 20.474703\n", "iter_ 1500: 20.487733\n", "iter_ 1600: 20.621488\n", "iter_ 1700: 20.671889\n", "iter_ 1800: 20.673450\n", "iter_ 1900: 20.947727\n", "iter_ 2000: 21.009787\n", "iter_ 2100: 21.031895\n", "iter_ 2200: 21.075417\n", "iter_ 2300: 21.037911\n", "iter_ 2400: 21.063337\n", "iter_ 2500: 21.159010\n", "iter_ 2600: 21.220136\n", "iter_ 2700: 21.306704\n", "iter_ 2800: 21.317022\n", "iter_ 2900: 21.313567\n", "iter_ 3000: 21.280537\n", "iter_ 3100: 21.394261\n", "iter_ 3200: 21.222326\n", "iter_ 3300: 21.103933\n", "iter_ 3400: 21.026450\n", "iter_ 3500: 20.940565\n", "iter_ 3600: 20.778982\n", "iter_ 3700: 20.867870\n", "iter_ 3800: 20.790209\n", "iter_ 3900: 20.781813\n", "iter_ 4000: 20.573691\n", "iter_ 4100: 20.434695\n", "iter_ 4200: 20.454322\n", "iter_ 4300: 20.281818\n", "iter_ 4400: 20.213335\n", "iter_ 4500: 20.006197\n", "iter_ 4600: 19.938834\n", "iter_ 4700: 19.816390\n", "iter_ 4800: 19.534211\n", "iter_ 4900: 19.492875\n", "iter_ 5000: 19.327795\n", "saved!\n", "iter_ 5100: 19.069202\n", "iter_ 5200: 18.893172\n", "iter_ 5300: 18.771077\n", "iter_ 5400: 18.903853\n", "iter_ 5500: 18.943519\n", "iter_ 5600: 18.835504\n", "iter_ 5700: 18.641866\n", "iter_ 5800: 18.568050\n", "iter_ 5900: 18.406251\n", "iter_ 6000: 18.303576\n", "iter_ 6100: 18.161562\n", "iter_ 6200: 18.011232\n", "iter_ 6300: 17.931368\n", "iter_ 6400: 17.969555\n", "iter_ 6500: 17.838386\n", "iter_ 6600: 17.639902\n", "iter_ 6700: 17.505160\n", "iter_ 6800: 17.334591\n", "iter_ 6900: 17.245499\n", "iter_ 7000: 17.072789\n", "iter_ 7100: 17.000454\n", "iter_ 7200: 16.926262\n", "iter_ 7300: 16.776886\n", "iter_ 7400: 16.724920\n", "iter_ 7500: 16.555643\n", "iter_ 7600: 16.497449\n", "iter_ 7700: 16.367106\n", "iter_ 7800: 16.263621\n", "iter_ 7900: 16.167857\n", "iter_ 8000: 16.075852\n", "iter_ 8100: 15.863419\n", "iter_ 8200: 15.704432\n", "iter_ 8300: 15.516366\n", "iter_ 8400: 15.415402\n", "iter_ 8500: 15.320215\n", "iter_ 8600: 15.224080\n", "iter_ 8700: 15.137174\n", "iter_ 8800: 14.932750\n", "iter_ 8900: 14.888707\n", "iter_ 9000: 14.842695\n", "iter_ 9100: 14.664533\n", "iter_ 9200: 14.700342\n", "iter_ 9300: 14.598106\n", "iter_ 9400: 14.470129\n", "iter_ 9500: 14.346605\n", "iter_ 9600: 14.271724\n", "iter_ 9700: 14.161178\n", "iter_ 9800: 14.101970\n", "iter_ 9900: 14.041289\n", "iter_ 10000: 13.905749\n", "saved!\n", "iter_ 10100: 13.781994\n", "iter_ 10200: 13.761556\n", "iter_ 10300: 13.604726\n", "iter_ 10400: 13.544350\n", "iter_ 10500: 13.491717\n", "iter_ 10600: 13.401767\n", "iter_ 10700: 13.305987\n", "iter_ 10800: 13.281872\n", "iter_ 10900: 13.251243\n", "iter_ 11000: 13.149489\n", "iter_ 11100: 13.071243\n", "iter_ 11200: 12.996503\n", "iter_ 11300: 13.014213\n", "iter_ 11400: 12.944917\n", "iter_ 11500: 12.917141\n", "iter_ 11600: 12.835549\n", "iter_ 11700: 12.826709\n", "iter_ 11800: 12.747148\n", "iter_ 11900: 12.641782\n", "iter_ 12000: 12.665167\n", "iter_ 12100: 12.646362\n", "iter_ 12200: 12.637761\n", "iter_ 12300: 12.515719\n", "iter_ 12400: 12.576099\n", "iter_ 12500: 12.521075\n", "iter_ 12600: 12.418966\n", "iter_ 12700: 12.381063\n", "iter_ 12800: 12.381649\n", "iter_ 12900: 12.322723\n", "iter_ 13000: 12.291331\n", "iter_ 13100: 12.231207\n", "iter_ 13200: 12.221519\n", "iter_ 13300: 12.115059\n", "iter_ 13400: 12.065030\n", "iter_ 13500: 12.064713\n", "iter_ 13600: 12.021912\n", "iter_ 13700: 11.971597\n", "iter_ 13800: 11.848290\n", "iter_ 13900: 11.786099\n", "iter_ 14000: 11.744749\n", "iter_ 14100: 11.701911\n", "iter_ 14200: 11.687953\n", "iter_ 14300: 11.626524\n", "iter_ 14400: 11.649125\n", "iter_ 14500: 11.621814\n", "iter_ 14600: 11.579251\n", "iter_ 14700: 11.540343\n", "iter_ 14800: 11.483450\n", "iter_ 14900: 11.383964\n", "iter_ 15000: 11.345420\n", "saved!\n", "iter_ 15100: 11.189060\n", "iter_ 15200: 11.210092\n", "iter_ 15300: 11.219047\n", "iter_ 15400: 11.203139\n", "iter_ 15500: 11.158248\n", "iter_ 15600: 11.123628\n", "iter_ 15700: 11.099226\n", "iter_ 15800: 11.064870\n", "iter_ 15900: 11.084253\n", "iter_ 16000: 11.015952\n", "iter_ 16100: 11.004468\n", "iter_ 16200: 11.015049\n", "iter_ 16300: 11.023876\n", "iter_ 16400: 11.010123\n", "iter_ 16500: 10.983821\n", "iter_ 16600: 10.938413\n", "iter_ 16700: 10.894742\n", "iter_ 16800: 10.754350\n", "iter_ 16900: 10.664717\n", "iter_ 17000: 10.623251\n", "iter_ 17100: 10.596035\n", "iter_ 17200: 10.617019\n", "iter_ 17300: 10.721184\n", "iter_ 17400: 10.698315\n", "iter_ 17500: 10.758545\n", "iter_ 17600: 10.730561\n", "iter_ 17700: 10.756100\n", "iter_ 17800: 10.756223\n", "iter_ 17900: 10.729578\n", "iter_ 18000: 10.713750\n", "iter_ 18100: 10.733265\n", "iter_ 18200: 10.717193\n", "iter_ 18300: 10.734548\n", "iter_ 18400: 10.626955\n", "iter_ 18500: 10.573120\n", "iter_ 18600: 10.573155\n", "iter_ 18700: 10.548022\n", "iter_ 18800: 10.497992\n", "iter_ 18900: 10.464389\n", "iter_ 19000: 10.467851\n", "iter_ 19100: 10.451004\n", "iter_ 19200: 10.416570\n", "iter_ 19300: 10.369180\n", "iter_ 19400: 10.380386\n", "iter_ 19500: 10.334510\n", "iter_ 19600: 10.426575\n", "iter_ 19700: 10.402202\n", "iter_ 19800: 10.345839\n", "iter_ 19900: 10.414973\n", "iter_ 20000: 10.414057\n", "saved!\n", "iter_ 20100: 10.424682\n", "iter_ 20200: 10.356544\n", "iter_ 20300: 10.423831\n", "iter_ 20400: 10.387770\n", "iter_ 20500: 10.362605\n", "iter_ 20600: 10.376197\n", "iter_ 20700: 10.331289\n", "iter_ 20800: 10.380293\n", "iter_ 20900: 10.337848\n", "iter_ 21000: 10.340521\n", "iter_ 21100: 10.305953\n", "iter_ 21200: 10.317894\n", "iter_ 21300: 10.343518\n", "iter_ 21400: 10.314110\n", "iter_ 21500: 10.271524\n", "iter_ 21600: 10.238044\n", "iter_ 21700: 10.205756\n", "iter_ 21800: 10.176787\n", "iter_ 21900: 10.084072\n", "iter_ 22000: 10.106953\n", "iter_ 22100: 10.053215\n", "iter_ 22200: 10.069332\n", "iter_ 22300: 10.052441\n", "iter_ 22400: 10.012923\n", "iter_ 22500: 10.038104\n", "iter_ 22600: 9.993996\n", "iter_ 22700: 10.004109\n", "iter_ 22800: 10.025915\n", "iter_ 22900: 10.036278\n", "iter_ 23000: 10.014603\n", "iter_ 23100: 9.960675\n", "iter_ 23200: 10.003456\n", "iter_ 23300: 10.055802\n", "iter_ 23400: 9.988262\n", "iter_ 23500: 10.001681\n", "iter_ 23600: 9.990759\n", "iter_ 23700: 9.984735\n", "iter_ 23800: 9.981500\n", "iter_ 23900: 9.894283\n", "iter_ 24000: 9.851653\n", "iter_ 24100: 9.867821\n", "iter_ 24200: 9.866621\n", "iter_ 24300: 9.886243\n", "iter_ 24400: 9.932614\n", "iter_ 24500: 9.943179\n", "iter_ 24600: 9.952576\n", "iter_ 24700: 9.955699\n", "iter_ 24800: 9.905475\n", "iter_ 24900: 9.834542\n", "iter_ 25000: 9.880161\n", "saved!\n", "iter_ 25100: 9.860923\n", "iter_ 25200: 9.863371\n", "iter_ 25300: 9.910072\n", "iter_ 25400: 9.902151\n", "iter_ 25500: 9.936475\n", "iter_ 25600: 9.936556\n", "iter_ 25700: 9.952919\n", "iter_ 25800: 9.977249\n", "iter_ 25900: 9.991944\n", "iter_ 26000: 10.043130\n", "iter_ 26100: 10.079640\n", "iter_ 26200: 10.011304\n", "iter_ 26300: 9.939285\n", "iter_ 26400: 9.895720\n", "iter_ 26500: 9.885912\n", "iter_ 26600: 9.877268\n", "iter_ 26700: 9.855755\n", "iter_ 26800: 9.849249\n", "iter_ 26900: 9.819325\n", "iter_ 27000: 9.791931\n", "iter_ 27100: 9.820158\n", "iter_ 27200: 9.768645\n", "iter_ 27300: 9.825911\n", "iter_ 27400: 9.794431\n", "iter_ 27500: 9.857503\n", "iter_ 27600: 9.816648\n", "iter_ 27700: 9.823728\n", "iter_ 27800: 9.821667\n", "iter_ 27900: 9.862476\n", "iter_ 28000: 9.881795\n", "iter_ 28100: 9.883599\n", "iter_ 28200: 9.912319\n", "iter_ 28300: 9.941006\n", "iter_ 28400: 9.917078\n", "iter_ 28500: 9.914267\n", "iter_ 28600: 9.860607\n", "iter_ 28700: 9.905190\n", "iter_ 28800: 9.939964\n", "iter_ 28900: 9.940590\n", "iter_ 29000: 9.893362\n", "iter_ 29100: 9.916780\n", "iter_ 29200: 9.838282\n", "iter_ 29300: 9.834859\n", "iter_ 29400: 9.831548\n", "iter_ 29500: 9.790570\n", "iter_ 29600: 9.788810\n", "iter_ 29700: 9.746504\n", "iter_ 29800: 9.798515\n", "iter_ 29900: 9.782321\n", "iter_ 30000: 9.704538\n", "saved!\n", "iter_ 30100: 9.729074\n", "iter_ 30200: 9.768284\n", "iter_ 30300: 9.784937\n", "iter_ 30400: 9.780768\n", "iter_ 30500: 9.828559\n", "iter_ 30600: 9.873279\n", "iter_ 30700: 9.866925\n", "iter_ 30800: 9.874105\n", "iter_ 30900: 9.875424\n", "iter_ 31000: 9.853438\n", "iter_ 31100: 9.834655\n", "iter_ 31200: 9.827070\n", "iter_ 31300: 9.808807\n", "iter_ 31400: 9.763422\n", "iter_ 31500: 9.808970\n", "iter_ 31600: 9.875959\n", "iter_ 31700: 9.873590\n", "iter_ 31800: 9.918868\n", "iter_ 31900: 9.970307\n", "iter_ 32000: 10.021178\n", "iter_ 32100: 9.993948\n", "iter_ 32200: 9.943728\n", "iter_ 32300: 9.862252\n", "iter_ 32400: 9.772822\n", "iter_ 32500: 9.730033\n", "iter_ 32600: 9.708618\n", "iter_ 32700: 9.697634\n", "iter_ 32800: 9.719430\n", "iter_ 32900: 9.686365\n", "iter_ 33000: 9.641960\n", "iter_ 33100: 9.700967\n", "iter_ 33200: 9.686835\n", "iter_ 33300: 9.655312\n", "iter_ 33400: 9.677657\n", "iter_ 33500: 9.650491\n", "iter_ 33600: 9.666483\n", "iter_ 33700: 9.671914\n", "iter_ 33800: 9.658939\n", "iter_ 33900: 9.579504\n", "iter_ 34000: 9.546839\n", "iter_ 34100: 9.506833\n", "iter_ 34200: 9.497960\n", "iter_ 34300: 9.504882\n", "iter_ 34400: 9.526258\n", "iter_ 34500: 9.543517\n", "iter_ 34600: 9.579199\n", "iter_ 34700: 9.570666\n", "iter_ 34800: 9.567154\n", "iter_ 34900: 9.540781\n", "iter_ 35000: 9.605205\n", "saved!\n", "iter_ 35100: 9.664605\n", "iter_ 35200: 9.681135\n", "iter_ 35300: 9.614818\n", "iter_ 35400: 9.566828\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "iter_ 35500: 9.610944\n", "iter_ 35600: 9.663755\n", "iter_ 35700: 9.722588\n", "iter_ 35800: 9.722484\n", "iter_ 35900: 9.677905\n", "iter_ 36000: 9.681488\n", "iter_ 36100: 9.697359\n", "iter_ 36200: 9.675975\n", "iter_ 36300: 9.632207\n", "iter_ 36400: 9.572983\n", "iter_ 36500: 9.546814\n", "iter_ 36600: 9.563374\n", "iter_ 36700: 9.569601\n", "iter_ 36800: 9.610571\n", "iter_ 36900: 9.583856\n", "iter_ 37000: 9.579905\n", "iter_ 37100: 9.555246\n", "iter_ 37200: 9.574555\n", "iter_ 37300: 9.529198\n", "iter_ 37400: 9.501708\n", "iter_ 37500: 9.510385\n", "iter_ 37600: 9.576810\n", "iter_ 37700: 9.520415\n", "iter_ 37800: 9.561922\n", "iter_ 37900: 9.574543\n", "iter_ 38000: 9.605944\n", "iter_ 38100: 9.620448\n", "iter_ 38200: 9.662221\n", "iter_ 38300: 9.625025\n", "iter_ 38400: 9.581447\n", "iter_ 38500: 9.615949\n", "iter_ 38600: 9.606902\n", "iter_ 38700: 9.663988\n", "iter_ 38800: 9.608830\n", "iter_ 38900: 9.631785\n", "iter_ 39000: 9.638900\n", "iter_ 39100: 9.590407\n", "iter_ 39200: 9.596133\n", "iter_ 39300: 9.532854\n", "iter_ 39400: 9.522097\n", "iter_ 39500: 9.497443\n", "iter_ 39600: 9.458520\n", "iter_ 39700: 9.447899\n", "iter_ 39800: 9.428893\n", "iter_ 39900: 9.406359\n", "iter_ 40000: 9.386926\n", "saved!\n", "sanity check: cost at convergence should be around or below 10\n", "training took 9083 seconds\n" ] } ], "source": [ "import random\n", "import numpy as np\n", "from utils.treebank import StanfordSentiment\n", "import matplotlib\n", "matplotlib.use('agg')\n", "import matplotlib.pyplot as plt\n", "import time\n", "\n", "from q3_word2vec import *\n", "from q3_sgd import *\n", "\n", "# Reset the random seed to make sure that everyone gets the same results\n", "random.seed(314)\n", "dataset = StanfordSentiment()\n", "tokens = dataset.tokens()\n", "nWords = len(tokens)\n", "\n", "# We are going to train 10-dimensional vectors for this assignment\n", "dimVectors = 10\n", "\n", "# Context size\n", "C = 5\n", "\n", "# Reset the random seed to make sure that everyone gets the same results\n", "random.seed(31415)\n", "np.random.seed(9265)\n", "\n", "startTime=time.time()\n", "wordVectors = np.concatenate(\n", " ((np.random.rand(nWords, dimVectors) - 0.5) /\n", " dimVectors, np.zeros((nWords, dimVectors))),\n", " axis=0)\n", "\n", "wordVectors = sgd(\n", " lambda vec: word2vec_sgd_wrapper(skipgram, tokens, vec, dataset, C,\n", " negSamplingCostAndGradient),\n", " wordVectors, 0.3, 40000, None, True, PRINT_EVERY=100)\n", "# Note that normalization is not called here. This is not a bug,\n", "# normalizing during training loses the notion of length.\n", "\n", "print(\"sanity check: cost at convergence should be around or below 10\")\n", "print(\"training took %d seconds\" % (time.time() - startTime))\n", "\n", "# concatenate the input and output word vectors\n", "# 这里将U,V合并,后面会进行奇异值分解\n", "wordVectors = np.concatenate(\n", " (wordVectors[:nWords,:], wordVectors[nWords:,:]),\n", " axis=0)\n", "\n", "visualizeWords = [\n", " \"the\", \"a\", \"an\", \",\", \".\", \"?\", \"!\", \"``\", \"''\", \"--\",\n", " \"good\", \"great\", \"cool\", \"brilliant\", \"wonderful\", \"well\", \"amazing\",\n", " \"worth\", \"sweet\", \"enjoyable\", \"boring\", \"bad\", \"waste\", \"dumb\",\n", " \"annoying\"]\n", " \n", "visualizeIdx = [tokens[word] for word in visualizeWords]\n", "visualizeVecs = wordVectors[visualizeIdx, :]\n", "\n", "# PCA,采用SVD来实现的,PCA很重要的一点中心化,均值为0\n", "temp = (visualizeVecs - np.mean(visualizeVecs, axis=0))\n", "covariance = 1.0 / len(visualizeIdx) * temp.T.dot(temp)\n", "# SVD的左奇异矩阵恰好就是X.dot(X.T)的特征向量组成的矩阵,而这个矩阵的特征向量恰好就是PCA的主成分\n", "U,S,V = np.linalg.svd(covariance)\n", "coord = temp.dot(U[:,0:2])\n", "\n", "for i in range(len(visualizeWords)):\n", " plt.text(coord[i,0], coord[i,1], visualizeWords[i],\n", " bbox=dict(facecolor='green', alpha=0.1))\n", "\n", "plt.xlim((np.min(coord[:,0]), np.max(coord[:,0])))\n", "plt.ylim((np.min(coord[:,1]), np.max(coord[:,1])))\n", "\n", "plt.savefig('q3_word_vectors.png')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }