{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "## Открытый курс по машинному обучению\n", "
Автор материала: Ефремова Дина (@ldinka)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#
Исследование возможностей BigARTM
\n", "\n", "##
Тематическое моделирование с помощью BigARTM
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Интро" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BigARTM — библиотека, предназначенная для тематической категоризации текстов; делает разбиение на темы без «учителя».\n", "\n", "Я собираюсь использовать эту библиотеку для собственных нужд в будущем, но так как она не предназначена для обучения с учителем, решила, что для начала ее стоит протестировать на какой-нибудь уже размеченной выборке. Для этих целей был использован датасет \"20 news groups\".\n", "\n", "Идея экперимента такова:\n", "- делим выборку на обучающую и тестовую;\n", "- обучаем модель на обучающей выборке;\n", "- «подгоняем» выделенные темы под действительные;\n", "- смотрим, насколько хорошо прошло разбиение;\n", "- тестируем модель на тестовой выборке." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Поехали!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Внимание!** Данный проект был реализован с помощью Python 3.6 и BigARTM 0.9.0. Методы, рассмотренные здесь, могут отличаться от методов в других версиях библиотеки." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Немножко теории" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "У нас есть словарь терминов $W = \\{w \\in W\\}$, который представляет из себя мешок слов, биграмм или n-грамм;\n", "\n", "Есть коллекция документов $D = \\{d \\in D\\}$, где $d \\subset W$;\n", "\n", "Есть известное множество тем $T = \\{t \\in T\\}$;\n", "\n", "$n_{dw}$ — сколько раз термин $w$ встретился в документе $d$;\n", "\n", "$n_{d}$ — длина документа $d$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Мы считаем, что существует матрица $\\Phi$ распределения терминов $w$ в темах $t$: (фи) $\\Phi = (\\phi_{wt})$\n", "\n", "и матрица распределения тем $t$ в документах $d$: (тета) $\\Theta = (\\theta_{td})$,\n", "\n", "переумножение которых дает нам тематическую модель, или, другими словами, представление наблюдаемого условного распределения $p(w|d)$ терминов $w$ в документах $d$ коллекции $D$:\n", "\n", "
$\\large p(w|d) = \\Phi \\Theta$
\n", "\n", "
$$\\large p(w|d) = \\sum_{t \\in T} \\phi_{wt} \\theta_{td}$$
\n", "\n", "где $\\phi_{wt} = p(w|t)$ — вероятности терминов $w$ в каждой теме $t$\n", "\n", "и $\\theta_{td} = p(t|d)$ — вероятности тем $t$ в каждом документе $d$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Нам известны наблюдаемые частоты терминов в документах, это:\n", "\n", "
$ \\large \\hat{p}(w|d) = \\frac {n_{dw}} {n_{d}} $
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Таким образом, наша задача тематического моделирования становится задачей стохастического матричного разложения матрицы $\\hat{p}(w|d)$ на стохастические матрицы $\\Phi$ и $\\Theta$.\n", "\n", "Напомню, что матрица является стохастической, если каждый ее столбец представляет дискретное распределение вероятностей, сумма значений каждого столбца равна 1." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Воспользовавшись принципом максимального правдоподобия, т. е. максимизируя логарифм правдоподобия, мы получим:\n", "\n", "
$\n", "\\begin{cases}\n", "\\sum_{d \\in D} \\sum_{w \\in d} n_{dw} \\ln \\sum_{t \\in T} \\phi_{wt} \\theta_{td} \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n", "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n", "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n", "\\end{cases}\n", "$
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Чтобы из множества решений выбрать наиболее подходящее, введем критерий регуляризации $R(\\Phi, \\Theta)$:\n", "\n", "
$\n", "\\begin{cases}\n", "\\sum_{d \\in D} \\sum_{w \\in d} n_{dw} \\ln \\sum_{t \\in T} \\phi_{wt} \\theta_{td} + R(\\Phi, \\Theta) \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n", "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n", "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n", "\\end{cases}\n", "$
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Два наиболее известных частных случая этой системы уравнений:\n", "- **PLSA**, вероятностный латентный семантический анализ, когда $R(\\Phi, \\Theta) = 0$\n", "- **LDA**, латентное размещение Дирихле:\n", "$$R(\\Phi, \\Theta) = \\sum_{t,w} (\\beta_{w} - 1) \\ln \\phi_{wt} + \\sum_{d,t} (\\alpha_{t} - 1) \\ln \\theta_{td} $$\n", "где $\\beta_{w} > 0$, $\\alpha_{t} > 0$ — параметры регуляризатора." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Однако оказывается запас неединственности решения настолько большой, что на модель можно накладывать сразу несколько ограничений, такой подход называется **ARTM**, или аддитивной регуляризацией тематических моделей:\n", "\n", "
$\n", "\\begin{cases}\n", "\\sum_{d,w} n_{dw} \\ln \\sum_{t} \\phi_{wt} \\theta_{td} + \\sum_{i=1}^k \\tau_{i} R_{i}(\\Phi, \\Theta) \\rightarrow \\max\\limits_{\\Phi,\\Theta};\\\\\n", "\\sum_{w \\in W} \\phi_{wt} = 1, \\qquad \\phi_{wt}\\geq0;\\\\\n", "\\sum_{t \\in T} \\theta_{td} = 1, \\quad\\quad\\;\\; \\theta_{td}\\geq0.\n", "\\end{cases}\n", "$
\n", "\n", "где $\\tau_{i}$ — коэффициенты регуляризации." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Теперь давайте познакомимся с библиотекой BigARTM и разберем еще некоторые аспекты тематического моделирования на ходу." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Если Вас очень сильно заинтересовала теоретическая часть категоризации текстов и тематического моделирования, рекомендую посмотреть видеолекции из курса Яндекса на Coursera «Поиск структуры в данных» четвертой недели: Тематическое моделирование." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BigARTM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Установка" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Естественно, для начала работы с библиотекой ее надо установить. Вот несколько видео, которые рассказывают, как это сделать в зависимости от вашей операционной системы:\n", "- Установка BigARTM в Windows\n", "- Установка BigARTM в Linux\n", "- Установка BigARTM в Mac OS X\n", "\n", "Либо можно воспользоваться инструкцией с официального сайта, которая, скорее всего, будет гораздо актуальнее: здесь. Там же указано, как можно установить BigARTM в качестве Docker-контейнера." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Использование BigARTM" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import artm\n", "import re\n", "import numpy as np\n", "import seaborn as sns; sns.set()\n", "\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "from sklearn.preprocessing import normalize\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score\n", "from matplotlib import pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "artm.version()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Скачаем датасет ***the 20 news groups*** с заранее известным количеством категорий новостей:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "newsgroups = fetch_20newsgroups('../../data/news_data')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "newsgroups['target_names']" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Приведем данные к формату *Vowpal Wabbit*. Так как BigARTM не рассчитан на обучение с учителем, то мы поступим следующим образом:\n", "- обучим модель на всем корпусе текстов;\n", "- выделим ключевые слова тем и по ним определим, к какой теме они скорее всего относятся;\n", "- сравним наши полученные результаты разбиения с истинными значенями." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TEXT_FIELD = \"text\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def to_vw_format(document, label=None):\n", " return str(label or '0') + ' |' + TEXT_FIELD + ' ' + ' '.join(re.findall('\\w{3,}', document.lower())) + '\\n'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_documents = newsgroups['data']\n", "all_targets = newsgroups['target']\n", "len(newsgroups['target'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_documents, test_documents, train_labels, test_labels = \\\n", " train_test_split(all_documents, all_targets, random_state=7)\n", "\n", "with open('../../data/news_data/20news_train_mult.vw', 'w') as vw_train_data:\n", " for text, target in zip(train_documents, train_labels):\n", " vw_train_data.write(to_vw_format(text, target))\n", "with open('../../data/news_data/20news_test_mult.vw', 'w') as vw_test_data:\n", " for text in test_documents:\n", " vw_test_data.write(to_vw_format(text))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Загрузим данные в необходимый для BigARTM формат:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_vectorizer = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_train_mult.vw\",\n", " data_format=\"vowpal_wabbit\",\n", " target_folder=\"news_batches\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Данные в BigARTM загружаются порционно, укажем в \n", "- *data_path* путь к обучающей выборке,\n", "- *data_format* — формат наших данных, может быть:\n", " * *bow_n_wd* — это вектор $n_{wd}$ в виду массива *numpy.ndarray*, также необходимо передать соответствующий словарь терминов, где ключ — это индекс вектора *numpy.ndarray* $n_{wd}$, а значение — соответствующий токен.\n", " ```python\n", " batch_vectorizer = artm.BatchVectorizer(data_format='bow_n_wd',\n", " n_wd=n_wd,\n", " vocabulary=vocabulary)\n", " ```\n", " * *vowpal_wabbit* — формат Vowpal Wabbit;\n", " * *bow_uci* — UCI формат (например, с *vocab.my_collection.txt* и *docword.my_collection.txt* файлами):\n", " ```python\n", " batch_vectorizer = artm.BatchVectorizer(data_path='',\n", " data_format='bow_uci',\n", " collection_name='my_collection',\n", " target_folder='my_collection_batches')\n", " ```\n", " * *batches* — данные, уже сконверченные в батчи с помощью BigARTM;\n", "- *target_folder* — путь для сохранения батчей.\n", "\n", "Пока это все параметры, что нам нужны для загрузки наших данных.\n", "\n", "После того, как BigARTM создал батчи из данных, можно использовать их для загрузки:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_vectorizer = artm.BatchVectorizer(data_path=\"news_batches\", data_format='batches')" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Инициируем модель с известным нам количеством тем. Количество тем — это гиперпараметр, поэтому если он заранее нам неизвестен, то его необходимо настраивать, т. е. брать такое количество тем, при котором разбиение кажется наиболее удачным.\n", "\n", "**Важно!** У нас 20 предметных тем, однако некоторые из них довольно узкоспециализированны и смежны, как например 'comp.os.ms-windows.misc' и 'comp.windows.x', или 'comp.sys.ibm.pc.hardware' и 'comp.sys.mac.hardware', тогда как другие размыты и всеобъемлющи: talk.politics.misc' и 'talk.religion.misc'.\n", "\n", "Скорее всего, нам не удастся в чистом виде выделить все 20 тем — некоторые из них окажутся слитными, а другие наоборот раздробятся на более мелкие. Поэтому мы попробуем построить 40 «предметных» тем и одну фоновую. Чем больше вы будем строить категорий, тем лучше мы сможем подстроиться под данные, однако это довольно трудоемкое занятие сидеть потом и распределять в получившиеся темы по реальным категориям (я правда очень-очень задолбалась!).\n", "\n", "Зачем нужны фоновые темы? Дело в том, что наличие общей лексики в темах приводит к плохой ее интерпретируемости. Выделив общую лексику в отдельную тему, мы сильно снизим ее количество в предметных темах, таким образом оставив там лексическое ядро, т. е. ключевые слова, которые данную тему характеризуют. Также этим преобразованием мы снизим коррелированность тем, они станут более независимыми и различимыми." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "T = 41\n", "model_artm = artm.ARTM(num_topics=T,\n", " topic_names=[str(i) for i in range(T)],\n", " class_ids={TEXT_FIELD:1}, \n", " num_document_passes=1,\n", " reuse_theta=True,\n", " cache_theta=True,\n", " seed=4)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Передаем в модель следующие параметры:\n", "- *num_topics* — количество тем;\n", "- *topic_names* — названия тем;\n", "- *class_ids* — название модальности и ее вес. Дело в том, что кроме самих текстов, в данных может содержаться такая информация, как автор, изображения, ссылки на другие документы и т. д., по которым также можно обучать модель;\n", "- *num_document_passes* — количество проходов при обучении модели;\n", "- *reuse_theta* — переиспользовать ли матрицу $\\Theta$ с предыдущей итерации;\n", "- *cache_theta* — сохранить ли матрицу $\\Theta$ в модели, чтобы в дальнейшем ее использовать.\n", "\n", "Далее необходимо создать словарь; передадим ему какое-нибудь название, которое будем использовать в будущем для работы с этим словарем." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DICTIONARY_NAME = 'dictionary'\n", "\n", "dictionary = artm.Dictionary(DICTIONARY_NAME)\n", "dictionary.gather(batch_vectorizer.data_path)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "Инициализируем модель с тем именем словаря, что мы передали выше, можно зафиксировать *random seed* для вопроизводимости результатов:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1)\n", "model_artm.initialize(DICTIONARY_NAME)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Добавим к модели несколько метрик:\n", "- перплексию (*PerplexityScore*), чтобы индентифицировать сходимость модели\n", " * Перплексия — это известная в вычислительной лингвистике мера качества модели языка. Можно сказать, что это мера неопределенности или различности слов в тексте.\n", "- специальный *score* ключевых слов (*TopTokensScore*), чтобы в дальнейшем мы могли идентифицировать по ним наши тематики;\n", "- разреженность матрицы $\\Phi$ (*SparsityPhiScore*);\n", "- разреженность матрицы $\\Theta$ (*SparsityThetaScore*)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_artm.scores.add(artm.PerplexityScore(name='perplexity_score',\n", " dictionary=DICTIONARY_NAME))\n", "model_artm.scores.add(artm.SparsityPhiScore(name='sparsity_phi_score', class_id=\"text\"))\n", "model_artm.scores.add(artm.SparsityThetaScore(name='sparsity_theta_score'))\n", "model_artm.scores.add(artm.TopTokensScore(name=\"top_words\", num_tokens=15, class_id=TEXT_FIELD))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Следующая операция *fit_offline* займет некоторое время, мы будем обучать модель в режиме *offline* в 40 проходов. Количество проходов влияет на сходимость модели: чем их больше, тем лучше сходится модель." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=40)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Построим график сходимости модели и увидим, что модель сходится довольно быстро:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(model_artm.score_tracker[\"perplexity_score\"].value);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Выведем значения разреженности матриц:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n", "print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "После того, как модель сошлась, добавим к ней регуляризаторы. Для начала сглаживающий регуляризатор — это *SmoothSparsePhiRegularizer* с большим положительным коэффициентом $\\tau$, который нужно применить только к фоновой теме, чтобы выделить в нее как можно больше общей лексики. Пусть тема с последним индексом будет фоновой, передадим в *topic_names* этот индекс:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi', \n", " tau=1e5, \n", " dictionary=dictionary, \n", " class_ids=TEXT_FIELD,\n", " topic_names=str(T-1)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Дообучим модель, сделав 20 проходов по ней с новым регуляризатором:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Выведем значения разреженности матриц, заметим, что значение для $\\Theta$ немного увеличилось:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n", "print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Теперь добавим к модели разреживающий регуляризатор, это тот же *SmoothSparsePhiRegularizer* резуляризатор, только с отрицательным значением $\\tau$ и примененный ко всем предметным темам:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi2', \n", " tau=-5e5, \n", " dictionary=dictionary, \n", " class_ids=TEXT_FIELD,\n", " topic_names=[str(i) for i in range(T-1)]),\n", " overwrite=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Видим, что значения разреженности увеличились еще больше:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n", "print(model_artm.score_tracker[\"sparsity_theta_score\"].last_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Посмотрим, сколько категорий-строк матрицы $\\Theta$ после регуляризации осталось, т. е. не занулилось/выродилось. И это одна категория:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(model_artm.score_tracker[\"top_words\"].last_tokens.keys())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Теперь выведем ключевые слова тем, чтобы определить, каким образом прошло разбиение, и сделать соответствие с нашим начальным списком тем:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for topic_name in model_artm.score_tracker[\"top_words\"].last_tokens.keys():\n", " tokens = model_artm.score_tracker[\"top_words\"].last_tokens\n", " res_str = topic_name + ': ' + ', '.join(tokens[topic_name])\n", " print(res_str)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Далее мы будем подгонять разбиение под действительные темы с помощью *confusion matrix*." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "target_dict = {\n", " 'alt.atheism': 0,\n", " 'comp.graphics': 1,\n", " 'comp.os.ms-windows.misc': 2,\n", " 'comp.sys.ibm.pc.hardware': 3,\n", " 'comp.sys.mac.hardware': 4,\n", " 'comp.windows.x': 5,\n", " 'misc.forsale': 6,\n", " 'rec.autos': 7,\n", " 'rec.motorcycles': 8,\n", " 'rec.sport.baseball': 9,\n", " 'rec.sport.hockey': 10,\n", " 'sci.crypt': 11,\n", " 'sci.electronics': 12,\n", " 'sci.med': 13,\n", " 'sci.space': 14,\n", " 'soc.religion.christian': 15,\n", " 'talk.politics.guns': 16,\n", " 'talk.politics.mideast': 17,\n", " 'talk.politics.misc': 18,\n", " 'talk.religion.misc': 19\n", "}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mixed = [\n", " 'comp.sys.ibm.pc.hardware',\n", " 'talk.politics.mideast',\n", " 'sci.electronics',\n", " 'rec.sport.hockey',\n", "\n", " 'sci.med',\n", " 'rec.motorcycles',\n", " 'comp.graphics',\n", " 'rec.sport.hockey',\n", "\n", " 'talk.politics.mideast',\n", " 'talk.religion.misc',\n", " 'rec.autos',\n", " 'comp.graphics',\n", "\n", " 'sci.space',\n", " 'soc.religion.christian',\n", " 'comp.os.ms-windows.misc',\n", " 'sci.crypt',\n", "\n", " 'comp.windows.x',\n", " 'misc.forsale',\n", " 'sci.space',\n", " 'sci.crypt',\n", "\n", " 'talk.religion.misc',\n", " 'alt.atheism',\n", " 'comp.os.ms-windows.misc',\n", " 'alt.atheism',\n", " \n", " 'sci.med',\n", " 'comp.os.ms-windows.misc',\n", " 'soc.religion.christian',\n", " 'talk.politics.guns',\n", "\n", " 'rec.autos',\n", " 'rec.autos',\n", " 'talk.politics.mideast',\n", " 'rec.sport.baseball',\n", "\n", " 'talk.religion.misc',\n", " 'talk.politics.misc',\n", " 'rec.sport.hockey',\n", " 'comp.sys.mac.hardware',\n", "\n", " 'misc.forsale',\n", " 'sci.space',\n", " 'talk.politics.guns',\n", " 'rec.autos',\n", " \n", " '-'\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Построим небольшой отчет о правильности нашего разбиения:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "theta_train = model_artm.get_theta()\n", "model_labels = []\n", "keys = np.sort([int(i) for i in theta_train.keys()])\n", "for i in keys:\n", " max_val = 0\n", " max_idx = 0\n", " for j in theta_train[i].keys():\n", " if j == str(T-1):\n", " continue\n", " if theta_train[i][j] > max_val:\n", " max_val = theta_train[i][j]\n", " max_idx = j\n", " topic = mixed[int(max_idx)]\n", " if topic == '-':\n", " print(i, '-')\n", " label = target_dict[topic]\n", " model_labels.append(label)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(classification_report(train_labels, model_labels))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(classification_report(train_labels, model_labels))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mat = confusion_matrix(train_labels, model_labels)\n", "sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n", "plt.xlabel('True label')\n", "plt.ylabel('Predicted label');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "accuracy_score(train_labels, model_labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Нам удалось добиться 80% *accuracy*. По матрице ответов мы видим, что для модели темы *comp.sys.ibm.pc.hardware* и *comp.sys.mac.hardware* практически не различимы (честно говоря, для меня тоже), в остальном все более или менее прилично.\n", "\n", "Проверим модель на тестовой выборке:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_vectorizer_test = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_test_mult.vw\",\n", " data_format=\"vowpal_wabbit\",\n", " target_folder=\"news_batches_test\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "theta_test = model_artm.transform(batch_vectorizer_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_score = []\n", "for i in range(len(theta_test.keys())):\n", " max_val = 0\n", " max_idx = 0\n", " for j in theta_test[i].keys():\n", " if j == str(T-1):\n", " continue\n", " if theta_test[i][j] > max_val:\n", " max_val = theta_test[i][j]\n", " max_idx = j\n", " topic = mixed[int(max_idx)]\n", " label = target_dict[topic]\n", " test_score.append(label)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(classification_report(test_labels, test_score))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mat = confusion_matrix(test_labels, test_score)\n", "sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n", "plt.xlabel('True label')\n", "plt.ylabel('Predicted label');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "accuracy_score(test_labels, test_score)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Итого почти 77%, незначительно хуже, чем на обучающей." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Вывод:** безумно много времени пришлось потратить на подгонку категорий к реальным темам, но в итоге я осталась довольна результатом. Такие смежные темы, как *alt.atheism*/*soc.religion.christian*/*talk.religion.misc* или *talk.politics.guns*/*talk.politics.mideast*/*talk.politics.misc* разделились вполне неплохо. Думаю, что я все-таки попробую использовать BigARTM в будущем для своих корыстных целей." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" } }, "nbformat": 4, "nbformat_minor": 1 }