{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classifying News Headlines and Explaining the Result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data is from Kaggle's [News Aggregator Dataset](https://www.kaggle.com/uciml/news-aggregator-dataset)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I sampled 10% of the data to speed up the analysis." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "news = pd.read_csv('data/uci-news-aggregator.csv').sample(frac=0.1)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "42242" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(news)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDTITLEURLPUBLISHERCATEGORYSTORYHOSTNAMETIMESTAMP
5843458435Russell Crowe Sings Johnny Cash on 'The Tonigh...http://screencrush.com/russell-crowe-johnny-cash/ScreenCrushedxzxHQTC1v6cP7MdjlKbJkMlfYwLMscreencrush.com1396019111324
244967245413HP cuts more jobs than expectedhttp://www.digitaljournal.com/business/busines...DigitalJournal.combde8PjvC03vbwIdMC0hkfXZTLVY0sMwww.digitaljournal.com1400928726875
314969315429NTSB faults pilots in last year's Asiana flighthttp://ktar.com/23/1744462/NTSB-faults-pilots-...KTAR.combdeigsQuEj4RZW3M_TqkzwLBT_oUTMktar.com1403705331596
\n", "
" ], "text/plain": [ " ID TITLE \\\n", "58434 58435 Russell Crowe Sings Johnny Cash on 'The Tonigh... \n", "244967 245413 HP cuts more jobs than expected \n", "314969 315429 NTSB faults pilots in last year's Asiana flight \n", "\n", " URL PUBLISHER \\\n", "58434 http://screencrush.com/russell-crowe-johnny-cash/ ScreenCrush \n", "244967 http://www.digitaljournal.com/business/busines... DigitalJournal.com \n", "314969 http://ktar.com/23/1744462/NTSB-faults-pilots-... KTAR.com \n", "\n", " CATEGORY STORY HOSTNAME \\\n", "58434 e dxzxHQTC1v6cP7MdjlKbJkMlfYwLM screencrush.com \n", "244967 b de8PjvC03vbwIdMC0hkfXZTLVY0sM www.digitaljournal.com \n", "314969 b deigsQuEj4RZW3M_TqkzwLBT_oUTM ktar.com \n", "\n", " TIMESTAMP \n", "58434 1396019111324 \n", "244967 1400928726875 \n", "314969 1403705331596 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "news.head(3)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [], "source": [ "encoder = LabelEncoder()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X = news['TITLE']\n", "y = encoder.fit_transform(news['CATEGORY'])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We count the number of occurrences of each word and use it as our features." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.feature_extraction.text import CountVectorizer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [], "source": [ "vectorizer = CountVectorizer(min_df=3)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "<31681x9925 sparse matrix of type ''\n", "\twith 267231 stored elements in Compressed Sparse Row format>" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_vectors = vectorizer.fit_transform(X_train)\n", "test_vectors = vectorizer.transform(X_test)\n", "\n", "train_vectors" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use a random forest for classification." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_split=1e-07, min_samples_leaf=1,\n", " min_samples_split=2, min_weight_fraction_leaf=0.0,\n", " n_estimators=20, n_jobs=1, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rf = RandomForestClassifier(n_estimators=20)\n", "rf.fit(train_vectors, y_train)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": 
[ "0.85048764321560455" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = rf.predict(test_vectors)\n", "accuracy_score(y_test, pred, )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "85% accuracy, not a bad score." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explaining the result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use lime to explain the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use lime, we need to construct a pipeline that does the process of vectorizing and classifying together." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "collapsed": true }, "outputs": [], "source": [ "c = make_pipeline(vectorizer, rf)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from lime.lime_text import LimeTextExplainer\n", "explainer = LimeTextExplainer(class_names=list(encoder.classes_))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We take an example text from data." ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "'Scientific Games to buy Bally Tech'" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example = X_test.sample(1).iloc[0]\n", "example" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.95, 0. , 0. 
, 0.05]])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c.predict_proba([example])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/libelo/anaconda/lib/python3.5/re.py:203: FutureWarning: split() requires a non-empty pattern match.\n", " return _compile(pattern, flags).split(string, maxsplit)\n" ] } ], "source": [ "exp = explainer.explain_instance(example, c.predict_proba, top_labels=1)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Above is the explanation of the classification generated by lime." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Reference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "https://github.com/marcotcr/lime" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "**dreamgonfly@gmail.com**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.1" } }, "nbformat": 4, "nbformat_minor": 0 }