{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Hierarchies in BigARTM\n", "\n", "Authors — **Nadezhda Chirkova** and **Artem Popov**\n", "\n", "In this tutorial we describe the principles of building topic hierarchies in BigARTM. This notebook was written for Python 2.7." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "\n", "import artm\n", "from artm import hARTM\n", "\n", "import sys\n", "sys.path.append('utils/')\n", "# you need sklearn for simple loading\n", "from load_collections import load_20newsgroups\n", "\n", "import glob \n", "import os" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'0.8.3'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "artm.version()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Method explanation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Usual ARTM model\n", "__Data:__ a set of documents $D$, a set of words $W$, and a document-word count matrix $\\{n_{dw}\\}_{D \\times W}$. \n", "\n", "__Model:__ Let $\\hat{p}(w|d) = \\frac{n_{dw}}{\\sum_w n_{dw}}$ be the empirical word distribution of document $d$ and let $T$ be the set of topics. The topic model approximates it by\n", "$$ p(w|d) = \\sum_{t \\in T} p(w|t) p(t|d) = \\sum_{t \\in T} \\phi_{wt} \\theta_{td}, \\hspace{3cm} (1) $$\n", "with parameters\n", "\n", "* $\\Phi = \\{\\phi_{wt}\\}_{W \\times T}$\n", "* $\\Theta = \\{ \\theta_{td}\\}_{T \\times D}$\n", "\n", "__Parameter learning:__ maximization of the regularized log-likelihood\n", "$$ \\sum_d \\sum_w n_{dw} \\ln \\sum_t \\phi_{wt} \\theta_{td} + \\sum_i \\tau_i R_i(\\Phi, \\Theta) \\rightarrow \\max_{\\Phi, \\Theta} $$\n", "where the regularizer $R(\\Phi, \\Theta) = \\sum_i \\tau_i R_i(\\Phi, \\Theta)$ allows introducing additional subject-specific criteria, and $\\tau_i$ are the regularization coefficients." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### How a hierarchy is constructed from several usual models\n", "#### Hierarchy definition:\n", "* A __topic hierarchy__ is a directed multipartite (multilevel) graph of topics in which edges connect only topics from neighboring levels. \n", "* The zero level consists of a single node called the __root__.\n", "* Each non-zero level has more topics than the previous one. The previous level is called the __parent level__.\n", "* If the hierarchy contains a topic-subtopic edge, the topic is called the __parent topic__ or __ancestor__ of the subtopic. \n", "\n", "#### Hierarchy construction:\n", "* The root is associated with the whole collection and doesn't need modeling.\n", "* _Every non-zero level is a usual topic model._\n", "* The first level has a few topics that are the main topics of the collection. All first-level topics have a single parent topic (the root). \n", "* For each level with index > 1 we need to establish parent-children relations with the topics of the previous level.\n", "\n", "### Establishing parent-children relations\n", "Once the parent level is built, denote its topics by $a \\in A$ (ancestors) and its matrices by $\\Phi^p$ and $\\Theta^p$. 
Now we build the next level model with topic set $T$.\n", "\n", "Let's introduce a new matrix factorization problem:\n", " $$ \\phi^p_{wa} = p(w|a) \\approx \\sum_{t} p(w|t) p(t|a) = \\sum_t \\phi_{wt} \\psi_{ta}$$ \n", " \n", "with a new parameter matrix $\\Psi = \\{ \\psi_{ta} \\}_{T \\times A}$ containing the probabilities p(topic | ancestor topic), called __link weights__.\n", "\n", "If KL-divergence is used as the similarity measure between distributions, the previous equation produces a regularizer for the next level model:\n", " $$ R(\\Phi, \\Psi) = \\sum_w \\sum_a \\phi^p_{wa} \\ln \\sum_t \\phi_{wt} \\psi_{ta} \\rightarrow \\max_{\\Phi, \\Psi} $$\n", "\n", "The full optimization problem for the next level is\n", " $$ \\sum_d \\sum_w n_{dw} \\ln \\sum_t \\phi_{wt} \\theta_{td} + \\tau R(\\Phi, \\Psi) \\rightarrow \\max_{\\Phi, \\Psi, \\Theta} $$\n", "Both the likelihood and the regularizer have the same structure, so there is a simple way to train $\\Psi$ simultaneously with $\\Phi$ and $\\Theta$:\n", "\n", "_we just add $|A|$ pseudo-documents to the collection, each representing a column of the parent $\\Phi$: $n_{aw} = \\tau p(w|a)$._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. BigARTM implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hierarchies in BigARTM are implemented in the hierarchy_utils module. To build a hierarchy, create an hARTM instance:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "hier = hARTM()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should pass to hARTM the parameters that are common to all levels. These are the same parameters that you pass to a usual ARTM model.\n", "\n", "Levels are built one by one. 
To add the first level, use the add_level method, specifying the remaining model parameters (unique to the level):" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "level0 = hier.add_level(num_topics=10)\n", "level0.scores.add(artm.TopTokensScore(name='TopTokensScore', num_tokens=20, \n", " class_id='text'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This method returns an ARTM object, so you can work with it as usual: initialize it, fit it offline, add regularizers and scores, etc. For example:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "load_20newsgroups()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "data_path = 'data/20newsgroups/20newsgroups_train.vw'\n", "batches_path = 'data/20newsgroups/batches'" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Build batches from the Vowpal Wabbit file unless the batches folder already contains some\n", "if len(glob.glob(os.path.join(batches_path, '*.batch'))) < 1:\n", "    batch_vectorizer = artm.BatchVectorizer(data_path=data_path, data_format='vowpal_wabbit',\n", "                                            target_folder=batches_path)\n", "else:\n", "    batch_vectorizer = artm.BatchVectorizer(data_path=batches_path, data_format='batches')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "dictionary = artm.Dictionary('dictionary')\n", "dictionary.gather(batches_path)\n", "dictionary.filter(min_df=5, max_tf=1000)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "level0.initialize(dictionary=dictionary)\n", "level0.fit_offline(batch_vectorizer, num_collection_passes=30)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "topic_0: gun, state, entry, output, guns, control, bill, crime, states, police, rights, section, tax, firearms, federal, laws, 
weapons, court, public, ok\n", "topic_1: card, dos, mb, scsi, disk, hard, memory, pc, mac, video, price, drives, monitor, apple, controller, ram, bus, speed, mhz, software\n", "topic_2: encryption, chip, public, nasa, technology, clipper, president, security, keys, privacy, internet, national, launch, administration, earth, research, access, computer, des, phone\n", "topic_3: jesus, true, life, bible, christian, church, fact, mean, faith, christians, religion, christ, evidence, rather, seems, man, word, reason, argument, truth\n", "topic_4: window, db, server, motif, application, widget, sun, display, list, mit, code, manager, ms, running, user, screen, lib, run, cs, subject\n", "topic_5: she, car, her, little, went, maybe, enough, lot, day, thought, anything, put, tell, remember, heard, told, old, again, left, bike\n", "topic_6: game, team, play, games, season, hockey, league, players, win, period, st, la, teams, nhl, vs, player, san, pts, gm, second\n", "topic_7: armenian, israel, turkish, war, jews, armenians, israeli, health, during, turkey, medical, against, jewish, history, children, state, anti, killed, turks, population\n", "topic_8: pl, tm, di, ei, cx, wm, bhj, giz, hz, ah, ey, lk, um, ww, qs, sl, mw, chz, mq, uw\n", "topic_9: image, ftp, files, software, graphics, version, pub, email, send, info, ca, format, images, contact, current, line, code, package, jpeg, source\n" ] } ], "source": [ "for topic_name in level0.topic_names:\n", "    print topic_name + ': ',\n", "    print \", \".join(level0.score_tracker['TopTokensScore'].last_tokens[topic_name])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the first level is fitted, you can add the next level:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "level1 = hier.add_level(num_topics=20, topic_names=['child_topic_' + str(i) for i in range(20)], \n", " parent_level_weight=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you add this 
level, the parent level's $\\Phi$ matrix is saved into a special parent-level batch.\n", "This is how the pseudo-documents are created.\n", "The created batch is added to the other batches when you fit the model.\n", "Explanation of the add_level parameters:\n", "* parent_level_weight is the regularization coefficient $\\tau$. Token values in the parent-level batch are multiplied by parent_level_weight during learning.\n", "* tmp_files_path is the path where the model can save this parent-level batch.\n", "\n", "These two parameters are ignored when the first level is created.\n", "\n", "Now you can train the level1 model in any way you like. For example:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "level1.scores.add(artm.TopTokensScore(name='TopTokensScore', num_tokens=20, \n", " class_id='text'))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "level1.initialize(dictionary=dictionary)\n", "level1.fit_offline(batch_vectorizer, num_collection_passes=30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The part of the $\\Theta$ matrix corresponding to the parent-level batch is the $\\Psi$ matrix. To get it, use the get_psi method:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "psi = level1.get_psi()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that level0 has no get_psi method." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "child_topic_0: \n", "seems anything isn actually post rather real maybe reason try least enough idea wrong seem course nothing perhaps tell little\n", "child_topic_1: \n", "window db server motif application widget display sun list mit manager code running lib user screen run font xterm subject\n", "child_topic_2: \n", "group posting important money support discussion feel lot newsgroup groups care article gay business men agree personal freedom postings community\n", "child_topic_3: \n", "she her went saw told day came started home says left heard took thought old maybe remember tell again little\n", "child_topic_4: \n", "ftp pub graphics ca software contact package version email fax comp send cs uk anonymous ac unix computer university address\n", "child_topic_5: \n", "entry output gun ok rules build check line section info entries printf eof stream int char size open title contest\n", "child_topic_6: \n", "jesus bible christian true life faith christians religion christ truth evidence man word argument belief religious cannot science christianity example\n", "child_topic_7: \n", "ground wire current circuit wiring neutral voltage box electrical amp high signal connected line cable test audio electronics subject usually\n", "child_topic_8: \n", "church stephanopoulos president moral mean catholic father pope yes holy society love spirit sex son authority george marriage day morality\n", "child_topic_9: \n", "game team games baseball san league hit runs won win players lost th average home season run pitching fan cubs\n", "child_topic_10: \n", "armenian israel turkish jews armenians war israeli turkey medical during health jewish turks against history anti greek armenia arab genocide\n", "child_topic_11: \n", "card dos mb scsi disk hard memory pc mac video drives price monitor apple controller bus ram mhz sale speed\n", 
"child_topic_12: \n", "image files info jpeg software gif color format looking version images faq hi graphics send advance email post quality keyboard\n", "child_topic_13: \n", "against fire fbi war koresh killed children country military death kill evidence media trial weapons far compound judge batf killing\n", "child_topic_14: \n", "encryption chip technology clipper nasa public security keys privacy internet launch president administration national earth des research access secure phone\n", "child_topic_15: \n", "game team play hockey season period games la nhl vs players pts gm st league win goal teams pittsburgh flyers\n", "child_topic_16: \n", "car bike water cars engine front enough buy miles oil little speed big air road side high money driving lot\n", "child_topic_17: \n", "pl tm di ei wm giz bhj ey um sl tq kn bxn qax gk bj qq ql tg mf\n", "child_topic_18: \n", "cx hz ww qs ah lk uw ck mw mv sc mc pl md zd tl mt chz ma cj\n", "child_topic_19: \n", "gun state states bill police control rights crime laws tax firearms public federal congress court amendment clinton guns weapons house\n" ] } ], "source": [ "for topic_name in level1.topic_names:\n", " print(topic_name + ': ')\n", " print(\" \".join(level1.score_tracker['TopTokensScore'].last_tokens[topic_name]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get levels specifying level_index (from 0 as usually in python so first level has index 0):" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "level_index = 0\n", "some_level = hier.get_level(level_index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To delete level, use:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "level_index = 1\n", "hier.del_level(level_index)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Be careful:__ if you delete not the last level, all next levels will be 
deleted too." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get the number of levels in the hierarchy using .num_levels:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1\n" ] } ], "source": [ "print(hier.num_levels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Improving hierarchy structure" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Hierarchy sparsing regularizer\n", "When building a level of the hierarchy, a problem can occur: some topics may have low link weights for all ancestor topics:\n", "$$ p(t|a) \\approx 0 \\quad \\forall a $$\n", "This may happen because there really is no appropriate parent topic, or because such a topic tries to be a subtopic of all ancestors. \n", "\n", "To avoid this situation, a special __hierarchy sparsing regularizer__ can be used. It affects the $\\Psi$ matrix and makes all distributions p(ancestor | topic) sparse, so that each topic has a small number of parents. As with other sparsing regularizers, we maximize KL(uniform distribution || p(a|t)) for each topic $t$. Dropping the additive constant, we get the regularization criterion:\n", "$$R_2(\\Psi) = -\\frac{1}{|A|} \\sum_a \\sum_t \\ln p(a | t) = -\\frac{1}{|A|} \\sum_a \\sum_t \\ln \\frac{p(t|a) p(a)} {\\sum_{a'} p(t|a') p(a')} = \n", "-\\frac{1}{|A|} \\sum_a \\sum_t \\ln \\frac{\\psi_{ta} p(a)} {\\sum_{a'} \\psi_{ta'} p(a')} \\rightarrow \\max_{\\Psi}$$\n", "\n", "The values $p(a)$ have little effect on $\\Psi$, so they can be set to uniform. The updated M-step is:\n", "$$ \\psi_{ta} = \\text{norm}_{t} \\Biggl[ n_{ta} - \\biggl( \\frac 1 {|A|} - p(a | t) \\biggr) \\Biggr]$$\n", "\n", "If $p(a | t)$ is high (above $\\frac 1 {|A|}$) for some ancestor $a$ and topic $t$, the link is strengthened. Links with low $p(a | t)$ are weakened." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Regularizer usage\n", "Since $\\Psi$ in BigARTM is a part of $\\Theta$, HierarchySparsingThetaRegularizer is a theta regularizer. 
It can be used in the same way as other BigARTM regularizers:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "level1 = hier.add_level(num_topics=20, topic_names=['child_topic_' + str(i) for i in range(20)], \n", " parent_level_weight=1)\n", "level1.regularizers.add(artm.HierarchySparsingThetaRegularizer(name=\"HierSp\", tau=1.0))\n", "level1.scores.add(artm.TopTokensScore(name='TopTokensScore', num_tokens=20, \n", " class_id='text'))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "level1.initialize(dictionary=dictionary)\n", "level1.fit_offline(batch_vectorizer, num_collection_passes=30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This regularizer affects only the special parent-level $\\Phi$ batches. This means that if you add HierarchySparsingThetaRegularizer to a usual (non-hierarchical) model, it will have no effect. Likewise, the $\\Theta$ columns of regular batches are not affected by this regularizer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Hierarchy structure quality measures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can use all BigARTM scores to assess the individual level models. There are also some measures that describe the hierarchy structure. They can be easily computed with numpy, so they are not implemented in BigARTM." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Support\n", "Usually one sets a psi threshold: the minimum value of $p(t | a)$ for which the link a-t is included in the topic graph. But with a high threshold some topics will have no parents. We define the __support__ as the maximum threshold for the Psi matrix at which every topic still has at least one ancestor." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "psi = level1.get_psi()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Psi support: 0.0324899\n" ] } ], "source": [ "print \"Psi support:\", psi.values.max(axis=1).min()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mean parents count\n", "In BigARTM a hierarchy is defined as a multilevel topic graph rather than a topic tree, so it is reasonable to evaluate the mean number of parents per topic." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean parents count: 1.15\n" ] } ], "source": [ "psi_threshold = 0.01\n", "parent_counts = np.zeros(0)\n", "for level_idx in range(1, hier.num_levels):\n", "    psi = hier.get_level(level_idx).get_psi().values\n", "    parent_counts = np.hstack((parent_counts, (psi > psi_threshold).sum(axis=1)))\n", "print \"Mean parents count:\", parent_counts.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Top tokens relations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can construct parent-children relations from the $p(parent|child)$ matrix, but the psi matrix gives us only $p(child|parent)$. 
We can get $p(parent|child)$ using Bayes' rule:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# Load the parent-level batch (one pseudo-document per parent topic)\n", "batch = artm.messages.Batch()\n", "batch_name = 'phi1.batch'\n", "\n", "with open(batch_name, \"rb\") as f:\n", "    batch.ParseFromString(f.read())\n", " \n", "# Ntw[i] is the total token weight of the i-th pseudo-document, used as the weight of parent i\n", "Ntw = np.zeros(len(level0.topic_names))\n", " \n", "for i,item in enumerate(batch.item):\n", "    for (token_id, token_weight) in zip(item.field[0].token_id, item.field[0].token_weight):\n", "        Ntw[i] += token_weight\n", "\n", "# Bayes' rule: p(parent|child) is proportional to p(child|parent) * weight(parent)\n", "Nt1t0 = np.array(psi) * Ntw\n", "psi_bayes = (Nt1t0 / Nt1t0.sum(axis=1)[:, np.newaxis]).T" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "indexes_child = np.argmax(psi_bayes, axis=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Look at topic_4 and its children:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "topic_4:\n", "window db server motif application widget sun display list mit code manager ms running user screen lib run cs subject\n", "\n", " child_topic_4: \n", "ftp pub graphics ca software contact package version email fax comp send cs uk anonymous ac unix computer university address\n", "\n", " child_topic_7: \n", "ground wire current circuit wiring neutral voltage box electrical amp high signal connected line cable test audio electronics subject usually\n", "\n", " child_topic_12: \n", "image files info jpeg software gif color format looking version images faq hi graphics send advance email post quality keyboard\n", "\n" ] } ], "source": [ "topic_parent_name = 'topic_4'\n", "# Index of the chosen parent topic (instead of reusing the leftover loop variable i)\n", "parent_index = level0.topic_names.index(topic_parent_name)\n", "print(topic_parent_name + ':')\n", "print(\" \".join(level0.score_tracker['TopTokensScore'].last_tokens[topic_parent_name]))\n", "print('')\n", "\n", "for child in np.where(indexes_child == parent_index)[0]:\n", " print(' ' + level1.topic_names[child] + ': ')\n", " print(\" 
\".join(level1.score_tracker['TopTokensScore'].last_tokens[level1.topic_names[child]]))\n", " print('')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 1 }