"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Интро"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"BigARTM — библиотека, предназначенная для тематической категоризации текстов; делает разбиение на темы без «учителя».\n",
"\n",
"Я собираюсь использовать эту библиотеку для собственных нужд в будущем, но так как она не предназначена для обучения с учителем, решила, что для начала ее стоит протестировать на какой-нибудь уже размеченной выборке. Для этих целей был использован датасет \"20 news groups\".\n",
"\n",
"Идея экперимента такова:\n",
"- делим выборку на обучающую и тестовую;\n",
"- обучаем модель на обучающей выборке;\n",
"- «подгоняем» выделенные темы под действительные;\n",
"- смотрим, насколько хорошо прошло разбиение;\n",
"- тестируем модель на тестовой выборке."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Поехали!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Внимание!** Данный проект был реализован с помощью Python 3.6 и BigARTM 0.9.0. Методы, рассмотренные здесь, могут отличаться от методов в других версиях библиотеки."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Немножко теории"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"У нас есть словарь терминов $W = \\{w \\in W\\}$, который представляет из себя мешок слов, биграмм или n-грамм;\n",
"\n",
"Есть коллекция документов $D = \\{d \\in D\\}$, где $d \\subset W$;\n",
"\n",
"Есть известное множество тем $T = \\{t \\in T\\}$;\n",
"\n",
"$n_{dw}$ — сколько раз термин $w$ встретился в документе $d$;\n",
"\n",
"$n_{d}$ — длина документа $d$."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Мы считаем, что существует матрица $\\Phi$ распределения терминов $w$ в темах $t$: (фи) $\\Phi = (\\phi_{wt})$\n",
"\n",
"и матрица распределения тем $t$ в документах $d$: (тета) $\\Theta = (\\theta_{td})$,\n",
"\n",
"переумножение которых дает нам тематическую модель, или, другими словами, представление наблюдаемого условного распределения $p(w|d)$ терминов $w$ в документах $d$ коллекции $D$:\n",
"\n",
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Таким образом, наша задача тематического моделирования становится задачей стохастического матричного разложения матрицы $\\hat{p}(w|d)$ на стохастические матрицы $\\Phi$ и $\\Theta$.\n",
"\n",
"Напомню, что матрица является стохастической, если каждый ее столбец представляет дискретное распределение вероятностей, сумма значений каждого столбца равна 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Воспользовавшись принципом максимального правдоподобия, т. е. максимизируя логарифм правдоподобия, мы получим:\n",
"\n",
"
\n",
"\n",
"где $\\tau_{i}$ — коэффициенты регуляризации."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь давайте познакомимся с библиотекой BigARTM и разберем еще некоторые аспекты тематического моделирования на ходу."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Если Вас очень сильно заинтересовала теоретическая часть категоризации текстов и тематического моделирования, рекомендую посмотреть видеолекции из курса Яндекса на Coursera «Поиск структуры в данных» четвертой недели: Тематическое моделирование."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### BigARTM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Установка"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Естественно, для начала работы с библиотекой ее надо установить. Вот несколько видео, которые рассказывают, как это сделать в зависимости от вашей операционной системы:\n",
"- Установка BigARTM в Windows\n",
"- Установка BigARTM в Linux\n",
"- Установка BigARTM в Mac OS X\n",
"\n",
"Либо можно воспользоваться инструкцией с официального сайта, которая, скорее всего, будет гораздо актуальнее: здесь. Там же указано, как можно установить BigARTM в качестве Docker-контейнера."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Использование BigARTM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import artm\n",
"import re\n",
"import numpy as np\n",
"import seaborn as sns; sns.set()\n",
"\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"from sklearn.preprocessing import normalize\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"artm.version()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Скачаем датасет ***the 20 news groups*** с заранее известным количеством категорий новостей:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_20newsgroups"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"newsgroups = fetch_20newsgroups('../../data/news_data')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"newsgroups['target_names']"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Приведем данные к формату *Vowpal Wabbit*. Так как BigARTM не рассчитан на обучение с учителем, то мы поступим следующим образом:\n",
"- обучим модель на всем корпусе текстов;\n",
"- выделим ключевые слова тем и по ним определим, к какой теме они скорее всего относятся;\n",
"- сравним наши полученные результаты разбиения с истинными значенями."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"TEXT_FIELD = \"text\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_vw_format(document, label=None):\n",
" return str(label or '0') + ' |' + TEXT_FIELD + ' ' + ' '.join(re.findall('\\w{3,}', document.lower())) + '\\n'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_documents = newsgroups['data']\n",
"all_targets = newsgroups['target']\n",
"len(newsgroups['target'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_documents, test_documents, train_labels, test_labels = \\\n",
" train_test_split(all_documents, all_targets, random_state=7)\n",
"\n",
"with open('../../data/news_data/20news_train_mult.vw', 'w') as vw_train_data:\n",
" for text, target in zip(train_documents, train_labels):\n",
" vw_train_data.write(to_vw_format(text, target))\n",
"with open('../../data/news_data/20news_test_mult.vw', 'w') as vw_test_data:\n",
" for text in test_documents:\n",
" vw_test_data.write(to_vw_format(text))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Загрузим данные в необходимый для BigARTM формат:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_vectorizer = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_train_mult.vw\",\n",
" data_format=\"vowpal_wabbit\",\n",
" target_folder=\"news_batches\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Данные в BigARTM загружаются порционно, укажем в \n",
"- *data_path* путь к обучающей выборке,\n",
"- *data_format* — формат наших данных, может быть:\n",
" * *bow_n_wd* — это вектор $n_{wd}$ в виду массива *numpy.ndarray*, также необходимо передать соответствующий словарь терминов, где ключ — это индекс вектора *numpy.ndarray* $n_{wd}$, а значение — соответствующий токен.\n",
" ```python\n",
" batch_vectorizer = artm.BatchVectorizer(data_format='bow_n_wd',\n",
" n_wd=n_wd,\n",
" vocabulary=vocabulary)\n",
" ```\n",
" * *vowpal_wabbit* — формат Vowpal Wabbit;\n",
" * *bow_uci* — UCI формат (например, с *vocab.my_collection.txt* и *docword.my_collection.txt* файлами):\n",
" ```python\n",
" batch_vectorizer = artm.BatchVectorizer(data_path='',\n",
" data_format='bow_uci',\n",
" collection_name='my_collection',\n",
" target_folder='my_collection_batches')\n",
" ```\n",
" * *batches* — данные, уже сконверченные в батчи с помощью BigARTM;\n",
"- *target_folder* — путь для сохранения батчей.\n",
"\n",
"Пока это все параметры, что нам нужны для загрузки наших данных.\n",
"\n",
"После того, как BigARTM создал батчи из данных, можно использовать их для загрузки:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_vectorizer = artm.BatchVectorizer(data_path=\"news_batches\", data_format='batches')"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Инициируем модель с известным нам количеством тем. Количество тем — это гиперпараметр, поэтому если он заранее нам неизвестен, то его необходимо настраивать, т. е. брать такое количество тем, при котором разбиение кажется наиболее удачным.\n",
"\n",
"**Важно!** У нас 20 предметных тем, однако некоторые из них довольно узкоспециализированны и смежны, как например 'comp.os.ms-windows.misc' и 'comp.windows.x', или 'comp.sys.ibm.pc.hardware' и 'comp.sys.mac.hardware', тогда как другие размыты и всеобъемлющи: talk.politics.misc' и 'talk.religion.misc'.\n",
"\n",
"Скорее всего, нам не удастся в чистом виде выделить все 20 тем — некоторые из них окажутся слитными, а другие наоборот раздробятся на более мелкие. Поэтому мы попробуем построить 40 «предметных» тем и одну фоновую. Чем больше вы будем строить категорий, тем лучше мы сможем подстроиться под данные, однако это довольно трудоемкое занятие сидеть потом и распределять в получившиеся темы по реальным категориям (я правда очень-очень задолбалась!).\n",
"\n",
"Зачем нужны фоновые темы? Дело в том, что наличие общей лексики в темах приводит к плохой ее интерпретируемости. Выделив общую лексику в отдельную тему, мы сильно снизим ее количество в предметных темах, таким образом оставив там лексическое ядро, т. е. ключевые слова, которые данную тему характеризуют. Также этим преобразованием мы снизим коррелированность тем, они станут более независимыми и различимыми."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"T = 41\n",
"model_artm = artm.ARTM(num_topics=T,\n",
" topic_names=[str(i) for i in range(T)],\n",
" class_ids={TEXT_FIELD:1}, \n",
" num_document_passes=1,\n",
" reuse_theta=True,\n",
" cache_theta=True,\n",
" seed=4)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Передаем в модель следующие параметры:\n",
"- *num_topics* — количество тем;\n",
"- *topic_names* — названия тем;\n",
"- *class_ids* — название модальности и ее вес. Дело в том, что кроме самих текстов, в данных может содержаться такая информация, как автор, изображения, ссылки на другие документы и т. д., по которым также можно обучать модель;\n",
"- *num_document_passes* — количество проходов при обучении модели;\n",
"- *reuse_theta* — переиспользовать ли матрицу $\\Theta$ с предыдущей итерации;\n",
"- *cache_theta* — сохранить ли матрицу $\\Theta$ в модели, чтобы в дальнейшем ее использовать.\n",
"\n",
"Далее необходимо создать словарь; передадим ему какое-нибудь название, которое будем использовать в будущем для работы с этим словарем."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DICTIONARY_NAME = 'dictionary'\n",
"\n",
"dictionary = artm.Dictionary(DICTIONARY_NAME)\n",
"dictionary.gather(batch_vectorizer.data_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"Инициализируем модель с тем именем словаря, что мы передали выше, можно зафиксировать *random seed* для вопроизводимости результатов:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(1)\n",
"model_artm.initialize(DICTIONARY_NAME)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Добавим к модели несколько метрик:\n",
"- перплексию (*PerplexityScore*), чтобы индентифицировать сходимость модели\n",
" * Перплексия — это известная в вычислительной лингвистике мера качества модели языка. Можно сказать, что это мера неопределенности или различности слов в тексте.\n",
"- специальный *score* ключевых слов (*TopTokensScore*), чтобы в дальнейшем мы могли идентифицировать по ним наши тематики;\n",
"- разреженность матрицы $\\Phi$ (*SparsityPhiScore*);\n",
"- разреженность матрицы $\\Theta$ (*SparsityThetaScore*)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_artm.scores.add(artm.PerplexityScore(name='perplexity_score',\n",
" dictionary=DICTIONARY_NAME))\n",
"model_artm.scores.add(artm.SparsityPhiScore(name='sparsity_phi_score', class_id=\"text\"))\n",
"model_artm.scores.add(artm.SparsityThetaScore(name='sparsity_theta_score'))\n",
"model_artm.scores.add(artm.TopTokensScore(name=\"top_words\", num_tokens=15, class_id=TEXT_FIELD))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Следующая операция *fit_offline* займет некоторое время, мы будем обучать модель в режиме *offline* в 40 проходов. Количество проходов влияет на сходимость модели: чем их больше, тем лучше сходится модель."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=40)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Построим график сходимости модели и увидим, что модель сходится довольно быстро:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(model_artm.score_tracker[\"perplexity_score\"].value);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выведем значения разреженности матриц:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
"print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"После того, как модель сошлась, добавим к ней регуляризаторы. Для начала сглаживающий регуляризатор — это *SmoothSparsePhiRegularizer* с большим положительным коэффициентом $\\tau$, который нужно применить только к фоновой теме, чтобы выделить в нее как можно больше общей лексики. Пусть тема с последним индексом будет фоновой, передадим в *topic_names* этот индекс:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi', \n",
" tau=1e5, \n",
" dictionary=dictionary, \n",
" class_ids=TEXT_FIELD,\n",
" topic_names=str(T-1)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Дообучим модель, сделав 20 проходов по ней с новым регуляризатором:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выведем значения разреженности матриц, заметим, что значение для $\\Theta$ немного увеличилось:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('Phi', model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
"print('Theta', model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь добавим к модели разреживающий регуляризатор, это тот же *SmoothSparsePhiRegularizer* резуляризатор, только с отрицательным значением $\\tau$ и примененный ко всем предметным темам:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_artm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='SparsePhi2', \n",
" tau=-5e5, \n",
" dictionary=dictionary, \n",
" class_ids=TEXT_FIELD,\n",
" topic_names=[str(i) for i in range(T-1)]),\n",
" overwrite=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"model_artm.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Видим, что значения разреженности увеличились еще больше:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(model_artm.score_tracker[\"sparsity_phi_score\"].last_value)\n",
"print(model_artm.score_tracker[\"sparsity_theta_score\"].last_value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Посмотрим, сколько категорий-строк матрицы $\\Theta$ после регуляризации осталось, т. е. не занулилось/выродилось. И это одна категория:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(model_artm.score_tracker[\"top_words\"].last_tokens.keys())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Теперь выведем ключевые слова тем, чтобы определить, каким образом прошло разбиение, и сделать соответствие с нашим начальным списком тем:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for topic_name in model_artm.score_tracker[\"top_words\"].last_tokens.keys():\n",
" tokens = model_artm.score_tracker[\"top_words\"].last_tokens\n",
" res_str = topic_name + ': ' + ', '.join(tokens[topic_name])\n",
" print(res_str)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Далее мы будем подгонять разбиение под действительные темы с помощью *confusion matrix*."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_dict = {\n",
" 'alt.atheism': 0,\n",
" 'comp.graphics': 1,\n",
" 'comp.os.ms-windows.misc': 2,\n",
" 'comp.sys.ibm.pc.hardware': 3,\n",
" 'comp.sys.mac.hardware': 4,\n",
" 'comp.windows.x': 5,\n",
" 'misc.forsale': 6,\n",
" 'rec.autos': 7,\n",
" 'rec.motorcycles': 8,\n",
" 'rec.sport.baseball': 9,\n",
" 'rec.sport.hockey': 10,\n",
" 'sci.crypt': 11,\n",
" 'sci.electronics': 12,\n",
" 'sci.med': 13,\n",
" 'sci.space': 14,\n",
" 'soc.religion.christian': 15,\n",
" 'talk.politics.guns': 16,\n",
" 'talk.politics.mideast': 17,\n",
" 'talk.politics.misc': 18,\n",
" 'talk.religion.misc': 19\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mixed = [\n",
" 'comp.sys.ibm.pc.hardware',\n",
" 'talk.politics.mideast',\n",
" 'sci.electronics',\n",
" 'rec.sport.hockey',\n",
"\n",
" 'sci.med',\n",
" 'rec.motorcycles',\n",
" 'comp.graphics',\n",
" 'rec.sport.hockey',\n",
"\n",
" 'talk.politics.mideast',\n",
" 'talk.religion.misc',\n",
" 'rec.autos',\n",
" 'comp.graphics',\n",
"\n",
" 'sci.space',\n",
" 'soc.religion.christian',\n",
" 'comp.os.ms-windows.misc',\n",
" 'sci.crypt',\n",
"\n",
" 'comp.windows.x',\n",
" 'misc.forsale',\n",
" 'sci.space',\n",
" 'sci.crypt',\n",
"\n",
" 'talk.religion.misc',\n",
" 'alt.atheism',\n",
" 'comp.os.ms-windows.misc',\n",
" 'alt.atheism',\n",
" \n",
" 'sci.med',\n",
" 'comp.os.ms-windows.misc',\n",
" 'soc.religion.christian',\n",
" 'talk.politics.guns',\n",
"\n",
" 'rec.autos',\n",
" 'rec.autos',\n",
" 'talk.politics.mideast',\n",
" 'rec.sport.baseball',\n",
"\n",
" 'talk.religion.misc',\n",
" 'talk.politics.misc',\n",
" 'rec.sport.hockey',\n",
" 'comp.sys.mac.hardware',\n",
"\n",
" 'misc.forsale',\n",
" 'sci.space',\n",
" 'talk.politics.guns',\n",
" 'rec.autos',\n",
" \n",
" '-'\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Построим небольшой отчет о правильности нашего разбиения:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"theta_train = model_artm.get_theta()\n",
"model_labels = []\n",
"keys = np.sort([int(i) for i in theta_train.keys()])\n",
"for i in keys:\n",
" max_val = 0\n",
" max_idx = 0\n",
" for j in theta_train[i].keys():\n",
" if j == str(T-1):\n",
" continue\n",
" if theta_train[i][j] > max_val:\n",
" max_val = theta_train[i][j]\n",
" max_idx = j\n",
" topic = mixed[int(max_idx)]\n",
" if topic == '-':\n",
" print(i, '-')\n",
" label = target_dict[topic]\n",
" model_labels.append(label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(classification_report(train_labels, model_labels))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(classification_report(train_labels, model_labels))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mat = confusion_matrix(train_labels, model_labels)\n",
"sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n",
"plt.xlabel('True label')\n",
"plt.ylabel('Predicted label');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"accuracy_score(train_labels, model_labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Нам удалось добиться 80% *accuracy*. По матрице ответов мы видим, что для модели темы *comp.sys.ibm.pc.hardware* и *comp.sys.mac.hardware* практически не различимы (честно говоря, для меня тоже), в остальном все более или менее прилично.\n",
"\n",
"Проверим модель на тестовой выборке:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_vectorizer_test = artm.BatchVectorizer(data_path=\"../../data/news_data/20news_test_mult.vw\",\n",
" data_format=\"vowpal_wabbit\",\n",
" target_folder=\"news_batches_test\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"theta_test = model_artm.transform(batch_vectorizer_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_score = []\n",
"for i in range(len(theta_test.keys())):\n",
" max_val = 0\n",
" max_idx = 0\n",
" for j in theta_test[i].keys():\n",
" if j == str(T-1):\n",
" continue\n",
" if theta_test[i][j] > max_val:\n",
" max_val = theta_test[i][j]\n",
" max_idx = j\n",
" topic = mixed[int(max_idx)]\n",
" label = target_dict[topic]\n",
" test_score.append(label)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(classification_report(test_labels, test_score))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mat = confusion_matrix(test_labels, test_score)\n",
"sns.heatmap(mat.T, annot=True, fmt='d', cbar=False)\n",
"plt.xlabel('True label')\n",
"plt.ylabel('Predicted label');"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"accuracy_score(test_labels, test_score)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Итого почти 77%, незначительно хуже, чем на обучающей."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Вывод:** безумно много времени пришлось потратить на подгонку категорий к реальным темам, но в итоге я осталась довольна результатом. Такие смежные темы, как *alt.atheism*/*soc.religion.christian*/*talk.religion.misc* или *talk.politics.guns*/*talk.politics.mideast*/*talk.politics.misc* разделились вполне неплохо. Думаю, что я все-таки попробую использовать BigARTM в будущем для своих корыстных целей."
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}