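Reference: Akash Srivastava and Charles Sutton. [Autoencoding Variational Inference for Topic Models](https://openreview.net/pdf?id=BybtVK9lg). In _ICLR_, 2017.

This notebook implements the ProdLDA model from that paper in Keras and trains it on bag-of-words vectors of the Reuters newswire dataset.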

In [1]:
from keras import backend as K
from keras.layers import Input, Dense, Lambda, Activation, Dropout, BatchNormalization, Layer
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import reuters
from keras.callbacks import EarlyStopping
import numpy as np

Using TensorFlow backend.


In [2]:
# Vocabulary size. NOTE(review): intended to drop words whose corpus frequency
# is <= 5 -- confirm V against the Reuters word counts.
V = 10922
# Number of newswires kept for training. NOTE(review): presumably chosen so that
# both parts of the later 90/10 validation split are multiples of the fixed
# batch size -- confirm before changing.
num_docs = 8000


def to_bag_of_words(docs, vocab_size):
    """Convert word-id sequences into a dense document-term count matrix.

    Args:
        docs: sequence of documents, each an iterable of int word ids in
            ``[0, vocab_size)``.
        vocab_size: number of columns of the result.

    Returns:
        ``float32`` array of shape ``(len(docs), vocab_size)`` whose entry
        ``[d, w]`` counts how often word ``w`` occurs in document ``d``.

    Raises:
        ValueError: if a document contains a word id >= ``vocab_size``
            (its bincount row is then longer than ``vocab_size``).
    """
    counts = np.zeros((len(docs), vocab_size), dtype=np.float32)
    for row, doc in enumerate(docs):
        # An explicit integer dtype keeps empty documents valid bincount input.
        counts[row] = np.bincount(np.asarray(doc, dtype=np.int64), minlength=vocab_size)
    return counts


# start_char=None / oov_char=None: no start marker and no OOV placeholder tokens;
# index_from=-1 shifts word ids down by one so they start at 0.
(x_train, _), (_, _) = reuters.load_data(start_char=None, oov_char=None,
                                         index_from=-1, num_words=V)
word_index = reuters.get_word_index()
index2word = {v - 1: k for k, v in word_index.items()}  # zero-origin word index
# Trim *before* densifying and store float32 (what Keras trains on anyway).
# This avoids materialising an int64 matrix for every training document.
x_train = to_bag_of_words(x_train[:num_docs], V)

In [3]:
# Model and training hyperparameters.
num_topic = 20     # K: number of latent topics
num_hidden = 100   # width of each encoder hidden layer
batch_size = 100   # mini-batch size; the encoder Input uses a fixed batch_shape built from it
alpha = 1.0 / 20   # concentration of the symmetric Dirichlet prior

In [4]:
# Laplace approximation of a symmetric Dirichlet(alpha) prior in the softmax
# basis, as in the referenced paper:
#   mu1_k     = log(alpha_k) - (1/K) * sum_i log(alpha_i)
#   Sigma1_kk = (1/alpha_k) * (1 - 2/K) + (1/K^2) * sum_i (1/alpha_i)
# With all alpha_k equal, mu1 is mathematically 0 and every diagonal entry of
# Sigma1 is the same scalar sigma1.
# Float literals keep every division true division (1/K would be 0 under Python 2).
mu1 = np.log(alpha) - (1.0 / num_topic) * num_topic * np.log(alpha)
sigma1 = (1.0 / alpha) * (1.0 - 2.0 / num_topic) + (1.0 / num_topic ** 2) * num_topic / alpha
inv_sigma1 = 1.0 / sigma1
# log|Sigma1| for a diagonal covariance with K equal entries.
log_det_sigma = num_topic * np.log(sigma1)

In [5]:
# Inference network: two softplus hidden layers, then a batch-normalized
# linear head for each of the posterior mean and log-variance.
x = Input(batch_shape=(batch_size, V))
h = x
for _hidden_layer in range(2):
    h = Dense(num_hidden, activation='softplus')(h)
z_mean = BatchNormalization()(Dense(num_topic)(h))
z_log_var = BatchNormalization()(Dense(num_topic)(h))

def sampling(args):
    """Draw a reparameterized Gaussian sample.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors with the same shape.

    Returns:
        ``z_mean + exp(z_log_var / 2) * epsilon`` where ``epsilon`` is standard
        normal noise with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from the tensor itself rather than from the globals
    # (batch_size, num_topic), so it always matches the input it is added to.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var / 2) * epsilon

# Reparameterized sample of the unnormalized topic proportions.
unnormalized_z = Lambda(sampling, output_shape=(num_topic,))([z_mean, z_log_var])

# Topic proportions on the simplex (softmax), with dropout applied to theta.
theta = Dropout(0.5)(Activation('softmax')(unnormalized_z))

# Decoder: linear map to V word logits, batch-normalized, then a softmax
# giving a distribution over the vocabulary.
doc = Activation('softmax')(BatchNormalization()(Dense(units=V)(theta)))

In [6]:
# Custom loss layer
# Custom loss layer
class CustomVariationalLayer(Layer):
    """Loss-only layer that attaches the negative ELBO to the model.

    It is called on ``[x, inference_x]``: the input word counts and the
    decoder's per-word probabilities. It registers the loss via ``add_loss``
    and returns ``x`` unchanged; the output itself is not used.

    NOTE(review): the KL term reads the module-level tensors ``z_mean`` /
    ``z_log_var`` and the prior constants ``mu1``, ``inv_sigma1`` and
    ``log_det_sigma`` instead of receiving them as layer inputs.
    """

    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, inference_x):
        """Return the batch mean of ``KL(q || prior) - log p(x | theta)``."""
        # Multinomial log-likelihood of the observed counts. The epsilon keeps
        # log() finite when the V-way softmax underflows to 0 for some word;
        # otherwise 0 * log(0) would produce NaN.
        decoder_loss = K.sum(x * K.log(inference_x + K.epsilon()), axis=-1)
        # Negative KL divergence between the diagonal Gaussian posterior
        # N(z_mean, exp(z_log_var)) and the Laplace-approximated prior
        # N(mu1, sigma1 * I). Centering on mu1 keeps the term correct when
        # the prior mean is non-zero.
        encoder_loss = -0.5 * (
            K.sum(inv_sigma1 * K.exp(z_log_var)
                  + K.square(z_mean - mu1) * inv_sigma1
                  - 1
                  - z_log_var, axis=-1)
            + log_det_sigma)
        return -K.mean(encoder_loss + decoder_loss)

    def call(self, inputs):
        """Register the loss for ``inputs = [x, inference_x]`` and return ``x``."""
        x = inputs[0]
        inference_x = inputs[1]
        loss = self.vae_loss(x, inference_x)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output.
        return x


In [7]:
# The loss lives inside the graph (added by CustomVariationalLayer through
# add_loss), so compile() is given loss=None.
y = CustomVariationalLayer()([x, doc])
prodLDA = Model(inputs=x, outputs=y)
prodLDA.compile(loss=None,
                optimizer=Adam(lr=0.001, beta_1=0.99))

 This is separate from the ipykernel package so we can avoid doing imports until


In [8]:
# Train on the first 90% of documents; stop once the loss on the held-out
# last 10% has not improved for 3 epochs.
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
prodLDA.fit(x_train,
            batch_size=batch_size,
            epochs=20,
            validation_split=0.1,
            callbacks=[early_stopping],
            verbose=1)

Train on 7200 samples, validate on 800 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20




In [9]:
# The decoder projection is the only Dense layer with V output units. Look it up
# by that property instead of by position in get_weights(): a hard-coded index
# silently picks the wrong array if any layer after it gains or loses weights.
decoder_dense = [layer for layer in prodLDA.layers
                 if isinstance(layer, Dense) and layer.units == V][-1]
# Its kernel is (num_topic, V): one row of unnormalized word weights per topic.
exp_beta = np.exp(decoder_dense.get_weights()[0]).T  # (V, num_topic)
# Normalize every topic column over the vocabulary, then transpose so that
# phi[k] is topic k's distribution over the V words.
phi = (exp_beta / np.sum(exp_beta, axis=0)).T

In [10]:
# List the ten most probable words of every topic, one blank line between topics.
for topic_id, topic_dist in enumerate(phi):
    print('topic: {}'.format(topic_id))
    ranked_word_ids = np.argsort(topic_dist)[::-1]
    for word_id in ranked_word_ids[:10]:
        print(index2word[word_id], topic_dist[word_id])
    print()


topic: 0
mln 0.000111143
billion 0.000108751
vs 0.00010815
4 0.000106495
2 0.000106403
dlrs 0.000106268
0 0.000105673
1 0.00010521
87 0.000104661
tonnes 0.000103869

topic: 1
offices 9.75551e-05
nogales 9.74515e-05
guard 9.73854e-05
automotive 9.7355e-05
unpaid 9.72935e-05
alarm 9.72468e-05
kilometers 9.72122e-05
dixon 9.71581e-05
library 9.71567e-05
independently 9.70571e-05

topic: 2
the 0.000103879
of 0.000102854
offer 0.000102746
a 0.000102459
dlrs 0.000101379
pesos 0.000101333
williams 0.000101298
to 0.00010123
share 0.000101195
norcros 0.000101193

topic: 3
the 0.000106743
trade 0.000104344
to 0.000103626
japan 0.000103166
yeutter 0.000102792
clayton 0.000102494
states 0.000102354
semiconductors 0.000102334
united 0.000102321
venice 0.000102175

topic: 4
vs 0.000121324
shr 0.000114225
cts 0.00011384
net 0.000113774
000 0.000113085
mln 0.000112964
loss 0.000110497
revs 0.00010924
shrs 0.000108045
avg 0.00010794

topic: 5
the 0.000120084
to 0.000113982
of 0.000113214
a 0.000110542
