{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Ref: [Autoencoding Variational Inference for Topic Models](https://openreview.net/pdf?id=BybtVK9lg). In _ICLR_. 2017." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "from keras import backend as K\n", "from keras.layers import Input, Dense, Lambda, Activation, Dropout, BatchNormalization, Layer\n", "from keras.models import Model\n", "from keras.optimizers import Adam\n", "from keras.datasets import reuters\n", "from keras.callbacks import EarlyStopping\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "V = 10922\n", "(x_train, _), (_, _) = reuters.load_data(start_char=None, oov_char=None, index_from=-1, num_words=V) # remove words having freq(q) <= 5\n", "word_index = reuters.get_word_index()\n", "index2word = {v-1: k for k, v in word_index.items()} # zero-origin word index\n", "x_train = np.array([np.bincount(doc, minlength=V) for doc in x_train])\n", "x_train = x_train[:8000, :]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "num_hidden = 100\n", "num_topic = 20\n", "batch_size = 100\n", "alpha = 1./20" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "mu1 = np.log(alpha) - 1/num_topic*num_topic*np.log(alpha)\n", "sigma1 = 1./alpha*(1-2./num_topic) + 1/(num_topic**2)*num_topic/alpha\n", "inv_sigma1 = 1./sigma1\n", "log_det_sigma = num_topic*np.log(sigma1)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = Input(batch_shape=(batch_size, V))\n", "h = Dense(num_hidden, activation='softplus')(x)\n", "h = Dense(num_hidden, activation='softplus')(h)\n", "z_mean = BatchNormalization()(Dense(num_topic)(h))\n", "z_log_var = BatchNormalization()(Dense(num_topic)(h))\n", "\n", "def sampling(args):\n", " z_mean, z_log_var = args\n", " epsilon = K.random_normal(shape=(batch_size, num_topic),\n", " mean=0., stddev=1.)\n", " return z_mean + K.exp(z_log_var / 2) * epsilon\n", "\n", "unnormalized_z = Lambda(sampling, output_shape=(num_topic,))([z_mean, z_log_var])\n", "\n", "theta = Activation('softmax')(unnormalized_z)\n", "theta = Dropout(0.5)(theta)\n", "doc = Dense(units=V)(theta)\n", "doc = BatchNormalization()(doc)\n", "doc = Activation('softmax')(doc)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Custom loss layer\n", "class CustomVariationalLayer(Layer):\n", " def __init__(self, **kwargs):\n", " self.is_placeholder = True\n", " super(CustomVariationalLayer, self).__init__(**kwargs)\n", "\n", " def vae_loss(self, x, inference_x):\n", " decoder_loss = K.sum(x * K.log(inference_x), axis=-1)\n", " encoder_loss = -0.5*(K.sum(inv_sigma1*K.exp(z_log_var) + K.square(z_mean)*inv_sigma1 - 1 - z_log_var, axis=-1) + log_det_sigma)\n", " return -K.mean(encoder_loss + decoder_loss)\n", "\n", " def call(self, inputs):\n", " x = inputs[0] \n", " inference_x = inputs[1]\n", " loss = self.vae_loss(x, inference_x)\n", " self.add_loss(loss, inputs=inputs)\n", " # We won't actually use the output.\n", " return x\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/Users/nzw/.pyenv/versions/miniconda3-latest/lib/python3.6/site-packages/ipykernel_launcher.py:3: UserWarning: Output \"custom_variational_layer_1\" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to \"custom_variational_layer_1\" during training.\n", " This is separate from the ipykernel package so we can avoid doing imports until\n" ] } ], "source": [ "y = CustomVariationalLayer()([x, doc])\n", "prodLDA = Model(x, y)\n", "prodLDA.compile(optimizer=Adam(lr=0.001, beta_1=0.99), loss=None)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 7200 samples, validate on 800 samples\n", "Epoch 1/20\n", "7200/7200 [==============================] - 16s - loss: 1294.8741 - val_loss: 1296.3189\n", "Epoch 2/20\n", "7200/7200 [==============================] - 13s - loss: 1236.1560 - val_loss: 1274.4536\n", "Epoch 3/20\n", "7200/7200 [==============================] - 13s - loss: 1212.5309 - val_loss: 1254.4617\n", "Epoch 4/20\n", "7200/7200 [==============================] - 14s - loss: 1192.9693 - val_loss: 1223.0534\n", "Epoch 5/20\n", "7200/7200 [==============================] - 16s - loss: 1177.8462 - val_loss: 1189.5439\n", "Epoch 6/20\n", "7200/7200 [==============================] - 17s - loss: 1163.8908 - val_loss: 1163.1716\n", "Epoch 7/20\n", "7200/7200 [==============================] - 14s - loss: 1149.6908 - val_loss: 1141.6182\n", "Epoch 8/20\n", "7200/7200 [==============================] - 15s - loss: 1137.0046 - val_loss: 1115.0331\n", "Epoch 9/20\n", "7200/7200 [==============================] - 15s - loss: 1123.7918 - val_loss: 1087.5322\n", "Epoch 10/20\n", "7200/7200 [==============================] - 15s - loss: 1113.6162 - val_loss: 1070.2166\n", "Epoch 11/20\n", "7200/7200 [==============================] - 16s - loss: 1101.6660 - val_loss: 1054.5499\n", "Epoch 12/20\n", "7200/7200 [==============================] - 16s - loss: 1091.5757 - val_loss: 1046.8468\n", "Epoch 13/20\n", "7200/7200 [==============================] - 15s - loss: 1084.2708 - val_loss: 1036.8501\n", "Epoch 14/20\n", "7200/7200 [==============================] - 14s - loss: 1073.5024 - val_loss: 1025.1848\n", "Epoch 15/20\n", "7200/7200 [==============================] - 13s - loss: 1066.2755 - val_loss: 1020.8630\n", "Epoch 16/20\n", "7200/7200 [==============================] - 13s - loss: 1058.9499 - val_loss: 1014.4392\n", "Epoch 17/20\n", "7200/7200 [==============================] - 13s - loss: 1051.1697 - val_loss: 1011.2585\n", "Epoch 18/20\n", "7200/7200 [==============================] - 13s - loss: 1043.8778 - val_loss: 1002.2133\n", "Epoch 19/20\n", "7200/7200 [==============================] - 13s - loss: 1038.1715 - val_loss: 998.5989\n", "Epoch 20/20\n", "7200/7200 [==============================] - 14s - loss: 1031.7404 - val_loss: 990.7487\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prodLDA.fit(x_train, verbose=1, batch_size=batch_size, validation_split=0.1, callbacks=[EarlyStopping(patience=3)], epochs=20)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "exp_beta = np.exp(prodLDA.get_weights()[-6]).T\n", "phi = (exp_beta/np.sum(exp_beta, axis=0)).T" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "topic: 0\n", "mln 0.000111143\n", "billion 0.000108751\n", "vs 0.00010815\n", "4 0.000106495\n", "2 0.000106403\n", "dlrs 0.000106268\n", "0 0.000105673\n", "1 0.00010521\n", "87 0.000104661\n", "tonnes 0.000103869\n", "\n", "topic: 1\n", "offices 9.75551e-05\n", "nogales 9.74515e-05\n", "guard 9.73854e-05\n", "automotive 9.7355e-05\n", "unpaid 9.72935e-05\n", "alarm 9.72468e-05\n", "kilometers 9.72122e-05\n", "dixon 9.71581e-05\n", "library 9.71567e-05\n", "independently 9.70571e-05\n", "\n", "topic: 2\n", "the 0.000103879\n", "of 0.000102854\n", "offer 0.000102746\n", "a 0.000102459\n", "dlrs 0.000101379\n", "pesos 0.000101333\n", "williams 0.000101298\n", "to 0.00010123\n", "share 0.000101195\n", "norcros 0.000101193\n", "\n", "topic: 3\n", "the 0.000106743\n", "trade 0.000104344\n", "to 0.000103626\n", "japan 0.000103166\n", "yeutter 0.000102792\n", "clayton 0.000102494\n", "states 0.000102354\n", "semiconductors 0.000102334\n", "united 0.000102321\n", "venice 0.000102175\n", "\n", "topic: 4\n", "vs 0.000121324\n", "shr 0.000114225\n", "cts 0.00011384\n", "net 0.000113774\n", "000 0.000113085\n", "mln 0.000112964\n", "loss 0.000110497\n", "revs 0.00010924\n", "shrs 0.000108045\n", "avg 0.00010794\n", "\n", "topic: 5\n", "the 0.000120084\n", "to 0.000113982\n", "of 0.000113214\n", "a 0.000110542\n", "in 0.000110401\n", "and 0.000109963\n", "said 0.000109562\n", "that 0.000104584\n", "for 0.000103823\n", "banks 0.000103126\n", "\n", "topic: 6\n", "twa 0.000100107\n", "usair 9.96932e-05\n", "idc 9.91386e-05\n", "twa's 9.91085e-05\n", "offer 9.88553e-05\n", "alvite 9.88493e-05\n", "usair's 9.87797e-05\n", "ecuador's 9.87569e-05\n", "said 9.85903e-05\n", "lawsuit 9.83429e-05\n", "\n", "topic: 7\n", "in 0.000106211\n", "pct 0.000105936\n", "0 0.000105215\n", "1 0.000104341\n", "rose 0.000104238\n", "2 0.000104187\n", "unadjusted 0.000103774\n", "87 0.000103751\n", "09 0.000103438\n", "billion 0.000103206\n", "\n", "topic: 8\n", "div 0.000107786\n", "qtly 0.000107582\n", "prior 0.000104525\n", "record 0.000104042\n", "pay 0.000103519\n", "juergen 0.000103384\n", "eckenfelder 0.000103248\n", "decades 0.000103185\n", "playing 0.000103119\n", "overhanging 0.000103105\n", "\n", "topic: 9\n", "the 0.000118233\n", "to 0.000110448\n", "of 0.000109871\n", "in 0.000108754\n", "said 0.000108514\n", "and 0.000107417\n", "a 0.000106782\n", "economists 0.000103912\n", "fed 0.000103817\n", "pct 0.000103388\n", "\n", "topic: 10\n", "shares 9.83596e-05\n", "gold 9.79488e-05\n", "it 9.76937e-05\n", "offer 9.72621e-05\n", "hillards 9.7235e-05\n", "ton 9.715e-05\n", "assistance 9.70506e-05\n", "filing 9.70396e-05\n", "rated 9.70243e-05\n", "debentures 9.69472e-05\n", "\n", "topic: 11\n", "vs 0.000123313\n", "mln 0.000119359\n", "dlrs 0.000117206\n", "cts 0.000116674\n", "net 0.000115827\n", "shr 0.000114957\n", "000 0.000113607\n", "loss 0.00011315\n", "oper 0.000111243\n", "1 0.000110801\n", "\n", "topic: 12\n", "undisclosed 0.000100513\n", "nogales 9.92817e-05\n", "bolivia's 9.90571e-05\n", "inc 9.90198e-05\n", "refine 9.87756e-05\n", "haq 9.87713e-05\n", "vulnerability 9.83347e-05\n", "eckenfelder 9.82606e-05\n", "unitary 9.8222e-05\n", "remittances 9.81787e-05\n", "\n", "topic: 13\n", "undisclosed 9.84391e-05\n", "covenants 9.83275e-05\n", "inc 9.82571e-05\n", "corp 9.81662e-05\n", "shrinking 9.81463e-05\n", "bolivia's 9.80812e-05\n", "completed 9.80501e-05\n", "nogales 9.8049e-05\n", "lieberman 9.7992e-05\n", "rawl 9.79763e-05\n", "\n", "topic: 14\n", "shr 0.00011173\n", "vs 0.000110093\n", "cts 
0.00010956\n", "net 0.000109266\n", "revs 0.000109084\n", "vulnerability 0.000101041\n", "avg 0.000101024\n", "lieberman 0.000100964\n", "note 0.00010096\n", "calculating 0.000100699\n", "\n", "topic: 15\n", "vs 0.000123786\n", "shr 0.000116089\n", "net 0.000115306\n", "cts 0.000114294\n", "loss 0.000112816\n", "000 0.000112405\n", "revs 0.000111282\n", "mln 0.000110697\n", "profit 0.000109111\n", "avg 0.000108816\n", "\n", "topic: 16\n", "the 0.000117573\n", "to 0.000111937\n", "of 0.000110009\n", "and 0.000108801\n", "said 0.000108758\n", "in 0.000108275\n", "a 0.000107725\n", "opec 0.000106638\n", "oil 0.000105008\n", "prices 0.000104015\n", "\n", "topic: 17\n", "the 0.000111432\n", "in 0.000106903\n", "to 0.000106766\n", "of 0.00010578\n", "said 0.000104663\n", "and 0.000104482\n", "a 0.000103509\n", "mln 0.000103303\n", "year 0.000102863\n", "pct 0.000102565\n", "\n", "topic: 18\n", "the 0.000115889\n", "to 0.00011093\n", "of 0.000108026\n", "rep 0.000106465\n", "a 0.000106093\n", "and 0.000105682\n", "subcommittee 0.000105528\n", "said 0.000105319\n", "trade 0.000105141\n", "bill 0.000105072\n", "\n", "topic: 19\n", "div 0.000106618\n", "qtly 0.000105465\n", "prior 0.000103936\n", "record 0.000103885\n", "decades 0.000103455\n", "eckenfelder 0.000103418\n", "matane 0.000103244\n", "harkin 0.00010324\n", "donohue's 0.000103181\n", "preferring 0.000103011\n", "\n" ] } ], "source": [ "# print the ten most probable words for each topic\n", "for k, phi_k in enumerate(phi):\n", "    print('topic: {}'.format(k))\n", "    for w in np.argsort(phi_k)[::-1][:10]:\n", "        print(index2word[w], phi_k[w])\n", "    print()\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 1 }