{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 获取数据" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "news = fetch_20newsgroups(subset='all')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "18846" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(news.data)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 向量化" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.feature_extraction.text import CountVectorizer\n", "vec = CountVectorizer()\n", "X = vec.fit_transform(news.data)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [3 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]\n", " [0 0 0 0 0 0 0 0 0 0]]\n" ] } ], "source": [ "print(X[:10, :10].toarray())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 用TF-IDF方法为数据加权" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "tf-idf(英语:term frequency–inverse document frequency)是一种用于信息检索与文本挖掘的常用加权技术。tf-idf是一种统计方法,用以评估一字词对于一个文件集或一个语料库中的其中一份文件的重要程度。。\n", "tf-idf算法是创建在这样一个假设之上的:对区别文档最有意义的词语应该是那些在文档中出现频率高,而在整个文档集合的其他文档中出现频率少的词语,所以如果特征空间坐标系取tf词频作为测度,就可以体现同类文本的特点。" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.feature_extraction.text import TfidfTransformer\n", "TFIDF = TfidfTransformer()\n", "X_tfidf = TFIDF.fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0.14795455 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]\n", " [ 0. 0. 0. 0. 0. 0. 0.\n", " 0. 0. 0. ]]\n" ] } ], "source": [ "print(X_tfidf[:10, :10].toarray())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 分割数据集" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/zjm/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n", " \"This module will be removed in 0.20.\", DeprecationWarning)\n" ] } ], "source": [ "\n", "from sklearn.cross_validation import train_test_split\n", "tf_Xtrain, tf_Xtest, tf_ytrain, tf_ytest = train_test_split(X_tfidf, news.target, test_size=0.25, random_state=233)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "参数解释:\n", "train_data:所要划分的样本特征集\n", "train_target:所要划分的样本结果\n", "test_size:样本占比,如果是整数的话就是样本的数量\n", "random_state:是随机数的种子。\n", "随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。\n", "随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:\n", "种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 训练朴素贝叶斯模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "简单解释朴素贝叶斯分类原理:http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.naive_bayes import MultinomialNB\n", "\n", "tf_mnb = MultinomialNB()\n", "tf_mnb.fit(tf_Xtrain, tf_ytrain)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 效果评估" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Naive Bayes with TF-IDF dataset Accuracy: 0.853140916808\n", "\n", " precision recall f1-score support\n", "\n", " alt.atheism 0.93 0.71 0.81 227\n", " comp.graphics 0.89 0.73 0.80 254\n", " comp.os.ms-windows.misc 0.79 0.86 0.82 235\n", "comp.sys.ibm.pc.hardware 0.79 0.83 0.81 252\n", " comp.sys.mac.hardware 0.90 0.83 0.87 237\n", " comp.windows.x 0.93 0.85 0.89 243\n", " misc.forsale 0.94 0.70 0.80 256\n", " rec.autos 0.88 0.91 0.90 250\n", " rec.motorcycles 0.97 0.94 0.95 279\n", " rec.sport.baseball 0.95 0.95 0.95 233\n", " rec.sport.hockey 0.90 0.98 0.94 241\n", " sci.crypt 0.73 0.98 0.84 246\n", " sci.electronics 0.85 0.82 0.84 233\n", " sci.med 0.95 0.92 0.93 229\n", " sci.space 0.88 0.95 0.92 240\n", " soc.religion.christian 0.58 0.98 0.73 252\n", " talk.politics.guns 0.78 0.96 0.86 238\n", " talk.politics.mideast 0.93 0.97 0.95 237\n", " talk.politics.misc 1.00 0.70 0.82 188\n", " talk.religion.misc 0.96 0.18 0.31 142\n", "\n", " avg / total 0.87 0.85 0.85 4712\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "tf_ypredict = tf_mnb.predict(tf_Xtest)\n", "print(\"Naive Bayes with TF-IDF dataset Accuracy: {0}\\n\".format(tf_mnb.score(tf_Xtest, tf_ytest)))\n", "print(classification_report(tf_ytest, tf_ypredict, target_names=news.target_names))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "参考文献:https://zhuanlan.zhihu.com/p/25050912" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 2 }