{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# task5: Multi-class Sentiment Analysis" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In all of the previous tutorials, our datasets had only two sentiment classes: positive or negative. When we only have two classes, our output can be a single scalar, bound between 0 and 1, that indicates which class the example belongs to. When we have more than 2 classes, our output must be a $C$-dimensional vector, where $C$ is the number of classes.\n", "\n", "In this tutorial, we will perform classification on a dataset with 6 classes. Note that this dataset is not actually a sentiment analysis dataset; it is a dataset of questions, and the task is to classify which category each question belongs to. However, everything covered in this tutorial applies to any dataset whose examples are input sequences that each belong to exactly one of $C$ classes.\n", "\n", "Below, we set up the fields and load the dataset. There are two differences from before:\n", "\n", "First, we do not need to set the `dtype` in the `LABEL` field. When doing a multi-class problem, PyTorch expects the labels to be numericalized as a `LongTensor`.\n", "\n", "Second, this time we are using the `TREC` dataset instead of the `IMDB` dataset. The `fine_grained` argument lets us use either the fine-grained labels (50 classes) or the coarse labels (6 classes); we use the coarse labels by passing `fine_grained=False`." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "from torchtext.legacy import data\n", "from torchtext.legacy import datasets\n", "import random\n", "\n", "SEED = 1234\n", "\n", "torch.manual_seed(SEED)\n", "torch.backends.cudnn.deterministic = True\n", "\n", "TEXT = data.Field(tokenize = 'spacy', tokenizer_language = 'en_core_web_sm')\n", "\n", "LABEL = data.LabelField()\n", "\n", "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n", "\n", "# random.seed(SEED) seeds Python's global RNG and returns None, so split() shuffles using that seeded global state\n", "train_data, valid_data = train_data.split(random_state = random.seed(SEED))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at an example from the training set." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'text': ['What', 'is', 'a', 'Cartesian', 'Diver', '?'], 'label': 'DESC'}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vars(train_data[-1])" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll build the vocabulary. As this dataset is small (only ~3,800 training examples), it also has a very small vocabulary (~7,500 unique tokens, i.e. a one-hot representation would only be ~7,500-dimensional). This means we do not need to set a `max_size` on the vocabulary as before: the `MAX_VOCAB_SIZE = 25_000` limit below is carried over from the previous notebooks but is never reached." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "MAX_VOCAB_SIZE = 25_000\n", "\n", "TEXT.build_vocab(train_data, \n", "                 max_size = MAX_VOCAB_SIZE, \n", "                 vectors = \"glove.6B.100d\", \n", "                 unk_init = torch.Tensor.normal_)\n", "\n", "LABEL.build_vocab(train_data)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we can check the labels.\n", "\n", "The 6 labels (for the non-fine-grained case) correspond to the 6 types of questions in the dataset:\n", "- `HUM` for questions about humans\n", "- `ENTY` for questions about entities\n", "- `DESC` for questions asking for a description\n", "- `NUM` for questions where the answer is numerical\n", "- `LOC` for questions where the answer is a location\n", "- `ABBR` for questions asking about abbreviations" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "defaultdict(None, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n" ] } ], "source": [ "print(LABEL.vocab.stoi)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As always, we set up the iterators." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 64\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n", "    (train_data, valid_data, test_data), \n", "    batch_size = BATCH_SIZE, \n", "    device = device)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We'll be using the CNN model from the previous notebook, but any of the models covered in these tutorials will work on this dataset. The only difference is that `output_dim` will now be $C$ instead of $1$." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class CNN(nn.Module):\n", "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n", "                 dropout, pad_idx):\n", "\n", "        super().__init__()\n", "\n", "        # padding_idx: the <pad> embedding receives no gradient, so it is not updated during training\n", "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n", "\n", "        self.convs = nn.ModuleList([\n", "            nn.Conv2d(in_channels = 1, \n", "                      out_channels = n_filters, \n", "                      kernel_size = (fs, embedding_dim)) \n", "            for fs in filter_sizes\n", "        ])\n", "\n", "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n", "\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "    def forward(self, text):\n", "\n", "        #text = [sent len, batch size]\n", "\n", "        text = text.permute(1, 0)\n", "\n", "        #text = [batch size, sent len]\n", "\n", "        embedded = self.embedding(text)\n", "\n", "        #embedded = [batch size, sent len, emb dim]\n", "\n", "        embedded = embedded.unsqueeze(1)\n", "\n", "        #embedded = [batch size, 1, sent len, emb dim]\n", "\n", "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n", "\n", "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n", "\n", "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n", "\n", "        #pooled_n = [batch size, n_filters]\n", "\n", "        cat = self.dropout(torch.cat(pooled, dim = 1))\n", "\n", "        #cat = [batch size, n_filters * len(filter_sizes)]\n", "\n", "        return self.fc(cat)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We define our model, making sure to set the output dimension, `OUTPUT_DIM`, to $C$. We can get $C$ easily from the size of the `LABEL` vocab, just as we used the length of the `TEXT` vocab to get the size of the input vocabulary.\n", "\n", "The examples in this dataset are generally a lot shorter than those in the IMDb dataset, so we'll use smaller filter sizes." ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "INPUT_DIM = len(TEXT.vocab)\n", "EMBEDDING_DIM = 100\n", "N_FILTERS = 100\n", "FILTER_SIZES = [2,3,4]\n", "OUTPUT_DIM = len(LABEL.vocab)\n", "DROPOUT = 0.5\n", "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", "\n", "model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Checking the number of parameters, we can see that this model has about a third as many parameters as the CNN model on the IMDb dataset. Most of that reduction comes from the much smaller vocabulary (about 7,500 tokens here versus the 25,000-token cap there), since the embedding layer holds the bulk of the parameters; the smaller filter sizes account for only a small part of it." ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The model has 841,806 trainable parameters\n" ] } ], "source": [ "def count_parameters(model):\n", "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", "print(f'The model has {count_parameters(model):,} trainable parameters')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll load our pre-trained embeddings." ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n", " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n", " [ 0.1638, 0.6046, 1.0789, ..., -0.3140, 0.1844, 0.3624],\n", " ...,\n", " [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n", " [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n", " [ 0.5831, -0.2514, 0.4156, ..., -0.2735, -0.8659, -1.4063]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pretrained_embeddings = TEXT.vocab.vectors\n", "\n", 
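"# Added sanity check (a sketch, not part of the original tutorial): build_vocab should give one\n", "# EMBEDDING_DIM-sized GloVe row per vocab entry, i.e. the same shape as the embedding layer we copy into.\n", "assert pretrained_embeddings.shape == model.embedding.weight.shape\n", "\n", 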
"model.embedding.weight.data.copy_(pretrained_embeddings)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Then we set the initial weights of the unknown and padding tokens to zeros." ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n", "\n", "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n", "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Another difference from the previous notebooks is our loss function. `BCEWithLogitsLoss` is generally used for binary classification, whereas `CrossEntropyLoss` is used for multi-class classification. `CrossEntropyLoss` performs a *softmax* function over our model outputs, and the loss is given by the *cross entropy* between that and the label.\n", "\n", "Generally:\n", "- use `CrossEntropyLoss` when each example belongs to exactly one of $C$ classes\n", "- use `BCEWithLogitsLoss` when examples belong to only 2 classes (0 and 1), and also when each example can belong to anywhere between 0 and $C$ classes at once (aka multi-label classification)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import torch.optim as optim\n", "\n", "optimizer = optim.Adam(model.parameters())\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "\n", "model = model.to(device)\n", "criterion = criterion.to(device)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Before, we had a function that calculated accuracy in the binary label case, where we said that if the value was over 0.5 then we would assume it was positive. When we have more than 2 classes, our model outputs a $C$-dimensional vector, where the value of each element is the model's (unnormalized) confidence that the example belongs to that class.\n", "\n", "For example, our labels are: 'HUM' = 0, 'ENTY' = 1, 'DESC' = 2, 'NUM' = 3, 'LOC' = 4 and 'ABBR' = 5. If the output of our model was something like **[5.1, 0.3, 0.1, 2.1, 0.2, 0.6]**, this means that the model strongly believes the example belongs to class 0, a question about a human, and slightly believes it belongs to class 3, a numerical question.\n", "\n", "We calculate the accuracy by performing an `argmax` to get the index of the maximum predicted value for each element in the batch, and then counting how many times this equals the actual label. We then average this over the batch." ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def categorical_accuracy(preds, y):\n", "    \"\"\"\n", "    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n", "    \"\"\"\n", "    top_pred = preds.argmax(1, keepdim = True)\n", "    correct = top_pred.eq(y.view_as(top_pred)).sum()\n", "    acc = correct.float() / y.shape[0]\n", "    return acc" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The training loop is similar to before. `CrossEntropyLoss` expects the predictions to be of shape **[batch size, n classes]** and the labels to be of shape **[batch size]**.\n", "\n", "The labels need to be a `LongTensor`, which they are by default because we did not set the `dtype` to a `FloatTensor` as before." ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def train(model, iterator, optimizer, criterion):\n", "\n", "    epoch_loss = 0\n", "    epoch_acc = 0\n", "\n", "    model.train()\n", "\n", "    for batch in iterator:\n", "\n", "        optimizer.zero_grad()\n", "\n", "        predictions = model(batch.text)\n", "\n", "        loss = criterion(predictions, batch.label)\n", "\n", "        acc = categorical_accuracy(predictions, batch.label)\n", "\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "\n", "        epoch_loss += loss.item()\n", "        epoch_acc += acc.item()\n", "\n", "    return epoch_loss / len(iterator), epoch_acc / len(iterator)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The evaluation loop is, again, the same as before." ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def evaluate(model, iterator, criterion):\n", "\n", "    epoch_loss = 0\n", "    epoch_acc = 0\n", "\n", "    model.eval()\n", "\n", "    with torch.no_grad():\n", "\n", "        for batch in iterator:\n", "\n", "            predictions = model(batch.text)\n", "\n", "            loss = criterion(predictions, batch.label)\n", "\n", "            acc = categorical_accuracy(predictions, batch.label)\n", "\n", "            epoch_loss += loss.item()\n", "            epoch_acc += acc.item()\n", "\n", "    return epoch_loss / len(iterator), epoch_acc / len(iterator)" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "def epoch_time(start_time, end_time):\n", "    elapsed_time = end_time - start_time\n", "    elapsed_mins = int(elapsed_time / 60)\n", "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n", "    return elapsed_mins, elapsed_secs" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Next, we train the model." ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 01 | Epoch Time: 0m 0s\n", "\tTrain Loss: 1.312 | Train Acc: 47.11%\n", "\t Val. Loss: 0.947 | Val. Acc: 66.41%\n", "Epoch: 02 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.870 | Train Acc: 69.18%\n", "\t Val. Loss: 0.741 | Val. Acc: 74.14%\n", "Epoch: 03 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.675 | Train Acc: 76.32%\n", "\t Val. Loss: 0.621 | Val. Acc: 78.49%\n", "Epoch: 04 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.506 | Train Acc: 83.97%\n", "\t Val. Loss: 0.547 | Val. Acc: 80.32%\n", "Epoch: 05 | Epoch Time: 0m 0s\n", "\tTrain Loss: 0.373 | Train Acc: 88.23%\n", "\t Val. Loss: 0.487 | Val. Acc: 82.92%\n" ] } ], "source": [ "N_EPOCHS = 5\n", "\n", "best_valid_loss = float('inf')\n", "\n", "for epoch in range(N_EPOCHS):\n", "\n", "    start_time = time.time()\n", "\n", "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n", "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n", "\n", "    end_time = time.time()\n", "\n", "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n", "\n", "    if valid_loss < best_valid_loss:\n", "        best_valid_loss = valid_loss\n", "        torch.save(model.state_dict(), 'tut5-model.pt')\n", "\n", "    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n", "    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n", "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's run our model on the test set." ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.415 | Test Acc: 86.07%\n" ] } ], "source": [ "model.load_state_dict(torch.load('tut5-model.pt'))\n", "\n", "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n", "\n", "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Similar to how we made a function to predict the sentiment of any given sentence, we can now make a function that predicts the class of a given question.\n", "\n", "The only difference here is that instead of using a sigmoid function to squash the output between 0 and 1, we use `argmax` to get the index of the highest predicted class. We then use this index with the label vocab (`LABEL.vocab.itos`) to get the human-readable label string. Because the largest filter size is 4, questions shorter than `min_len = 4` tokens are first padded with the `<pad>` token." ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "import spacy\n", "nlp = spacy.load('en_core_web_sm')\n", "\n", "def predict_class(model, sentence, min_len = 4):\n", "    model.eval()\n", "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n", "    # pad short inputs so they are at least as long as the largest filter size\n", "    if len(tokenized) < min_len:\n", "        tokenized += [TEXT.pad_token] * (min_len - len(tokenized))\n", "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n", "    tensor = torch.LongTensor(indexed).to(device)\n", "    tensor = tensor.unsqueeze(1)\n", "    preds = model(tensor)\n", "    max_preds = preds.argmax(dim = 1)\n", "    return max_preds.item()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's try it out on a few different questions..." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class is: 0 = HUM\n" ] } ], "source": [ "pred_class = predict_class(model, \"Who is Keyser Söze?\")\n", "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class is: 3 = NUM\n" ] } ], "source": [ "pred_class = predict_class(model, \"How many minutes are in six hundred and eighteen hours?\")\n", "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class is: 4 = LOC\n" ] } ], "source": [ "pred_class = predict_class(model, \"What continent is Bulgaria in?\")\n", "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted class is: 5 = ABBR\n" ] } ], "source": [ "pred_class = predict_class(model, \"What does WYSIWYG stand for?\")\n", "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 2 }