{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Using wrappers for the scikit-learn API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial shows how to use gensim models as part of a scikit-learn workflow with the help of the wrappers found in ```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel```." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The wrapper available (as of now) is:\n", "* LdaModel (```gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel.SklearnWrapperLdaModel```), which implements gensim's ```LdaModel``` in a scikit-learn interface" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LdaModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use LdaModel, begin by importing the LdaModel wrapper:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create a small dummy set of texts and convert it into a bag-of-words corpus:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from gensim.corpora import Dictionary\n", "texts = [['complier', 'system', 'computer'],\n", " ['eulerian', 'node', 'cycle', 'graph', 'tree', 'path'],\n", " ['graph', 'flow', 'network', 'graph'],\n", " ['loading', 'computer', 'system'],\n", " ['user', 'server', 'system'],\n", " ['tree', 'hamiltonian'],\n", " ['graph', 'trees'],\n", " ['computer', 'kernel', 'malfunction', 'computer'],\n", " ['server', 'system', 'computer']]\n", "dictionary = Dictionary(texts)\n", "corpus = [dictionary.doc2bow(text) for text in texts]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we fit the LdaModel wrapper on this corpus and print the resulting topics:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name":
"stderr", "output_type": "stream", "text": [ "WARNING:gensim.models.ldamodel:too few updates, training might not converge; consider increasing the number of passes or iterations to improve accuracy\n" ] }, { "data": { "text/plain": [ "[(0,\n", " u'0.164*\"computer\" + 0.117*\"system\" + 0.105*\"graph\" + 0.061*\"server\" + 0.057*\"tree\" + 0.046*\"malfunction\" + 0.045*\"kernel\" + 0.045*\"complier\" + 0.043*\"loading\" + 0.039*\"hamiltonian\"'),\n", " (1,\n", " u'0.102*\"graph\" + 0.083*\"system\" + 0.072*\"tree\" + 0.064*\"server\" + 0.059*\"user\" + 0.059*\"computer\" + 0.057*\"trees\" + 0.056*\"eulerian\" + 0.055*\"node\" + 0.052*\"flow\"')]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = SklearnWrapperLdaModel(num_topics=2, id2word=dictionary, iterations=20, random_state=1)\n", "model.fit(corpus)\n", "model.print_topics(2)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Integration with scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To give a more realistic example of how the wrapper can be used with scikit-learn, let's build the input with scikit-learn's CountVectorizer. For this example we use the [20 Newsgroups data set](http://qwone.com/~jason/20Newsgroups/), restricted to the categories rec.sport.baseball and sci.crypt, and generate topics from it."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "from gensim import matutils\n", "from gensim.models.ldamodel import LdaModel\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "rand = np.random.RandomState(1)  # fixed seed so the shuffle below is reproducible\n", "cats = ['rec.sport.baseball', 'sci.crypt']\n", "data = fetch_20newsgroups(subset='train',\n", " categories=cats,\n", " shuffle=True,\n", " random_state=rand)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we use CountVectorizer to convert the collection of text documents into a matrix of token counts, with one row per document and one column per vocabulary term." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vec = CountVectorizer(min_df=10, stop_words='english')\n", "\n", "X = vec.fit_transform(data.data)\n", "vocab = vec.get_feature_names()  # vocabulary, used to build id2word\n", "\n", "id2word = dict(enumerate(vocab))  # maps each column index to its term" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we pass id2word to the LdaModel wrapper and fit it on X."
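] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A note on matrix orientation: CountVectorizer returns one row per document, whereas gensim's ```matutils.Sparse2Corpus``` treats each *column* of a sparse matrix as a document unless it is called with ```documents_columns=False```. If this version of the wrapper converts a sparse ```X``` with the default orientation (an assumption worth checking against the wrapper's source), every vocabulary term would be treated as a document, which would explain why the topics printed by the ```obj.fit(X)``` cell are dominated by terms from the start of the alphabetically sorted vocabulary. As a sketch, the cell below builds the corpus explicitly, one document per row, and passes it to the wrapper the same way the bag-of-words corpus was passed earlier; ```corpus_rows``` and ```lda_rows``` are illustrative names, not part of gensim." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from gensim import matutils\n", "from gensim.sklearn_integration.sklearn_wrapper_gensim_ldaModel import SklearnWrapperLdaModel\n", "\n", "# Sketch: treat each row of X (one newsgroup post) as a bag-of-words document.\n", "# corpus_rows and lda_rows are hypothetical names used only in this example.\n", "corpus_rows = matutils.Sparse2Corpus(X, documents_columns=False)\n", "lda_rows = SklearnWrapperLdaModel(id2word=id2word, num_topics=5, passes=20)\n", "lda_rows.fit(corpus_rows)\n", "lda_rows.print_topics()"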
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "[(0,\n", " u'0.018*\"cryptography\" + 0.018*\"face\" + 0.017*\"fierkelab\" + 0.008*\"abuse\" + 0.007*\"constitutional\" + 0.007*\"collection\" + 0.007*\"finish\" + 0.007*\"150\" + 0.007*\"fast\" + 0.006*\"difference\"'),\n", " (1,\n", " u'0.022*\"corporate\" + 0.022*\"accurate\" + 0.012*\"chance\" + 0.008*\"decipher\" + 0.008*\"example\" + 0.008*\"basically\" + 0.008*\"dawson\" + 0.008*\"cases\" + 0.008*\"consideration\" + 0.008*\"follow\"'),\n", " (2,\n", " u'0.034*\"argue\" + 0.031*\"456\" + 0.031*\"arithmetic\" + 0.024*\"courtesy\" + 0.020*\"beastmaster\" + 0.019*\"bitnet\" + 0.015*\"false\" + 0.015*\"classified\" + 0.014*\"cubs\" + 0.014*\"digex\"'),\n", " (3,\n", " u'0.108*\"abroad\" + 0.089*\"asking\" + 0.060*\"cryptography\" + 0.035*\"certain\" + 0.030*\"ciphertext\" + 0.030*\"book\" + 0.028*\"69\" + 0.028*\"demand\" + 0.028*\"87\" + 0.027*\"cracking\"'),\n", " (4,\n", " u'0.022*\"clark\" + 0.019*\"authentication\" + 0.017*\"candidates\" + 0.016*\"decryption\" + 0.015*\"attempt\" + 0.013*\"creation\" + 0.013*\"1993apr5\" + 0.013*\"acceptable\" + 0.013*\"algorithms\" + 0.013*\"employer\"')]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "obj = SklearnWrapperLdaModel(id2word=id2word, num_topics=5, passes=20)\n", "lda = obj.fit(X)\n", "lda.print_topics()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "#### Using it together with scikit-learn's logistic regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's train scikit-learn's logistic regression classifier on the token-count matrix X to separate the two categories. Here sci.crypt is encoded as label 1 and rec.sport.baseball as label 0, so ideally we should get positive weights for words about cryptography and negative weights for words about baseball."
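] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sign convention comes from how the labels are encoded: for a binary problem, ```LogisticRegression``` stores a single row of coefficients in ```coef_```, and positive values push a document towards ```clf.classes_[1]```, the class with label 1. Printing the label-to-category mapping of the training data makes this explicit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Pair each integer label used in data.target with its category name\n", "print(list(enumerate(data.target_names)))"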
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn import linear_model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def print_features(clf, vocab, n=10):\n", "    ''' Print the n largest positive and n most negative coefficients with their terms '''\n", "    coef = clf.coef_[0]\n", "    print('Positive features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[::-1][:n] if coef[j] > 0])))\n", "    print('Negative features: %s' % (' '.join(['%s:%.2f' % (vocab[j], coef[j]) for j in np.argsort(coef)[:n] if coef[j] < 0])))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Positive features: clipper:1.50 code:1.24 key:1.04 encryption:0.95 chip:0.37 nsa:0.37 government:0.36 uk:0.36 org:0.23 cryptography:0.23\n", "Negative features: baseball:-1.32 game:-0.71 year:-0.61 team:-0.38 edu:-0.27 games:-0.26 players:-0.23 ball:-0.17 season:-0.14 phillies:-0.11\n" ] } ], "source": [ "clf = linear_model.LogisticRegression(penalty='l1', C=0.1, solver='liblinear')  # L1 penalty drives many coefficients to exactly zero\n", "clf.fit(X, data.target)\n", "print_features(clf, vocab)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }