{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# (MultiFiT) Portuguese Text Classifier on TCU jurisprudência dataset\n", "### MultiFiT configuration\n", "- **Architecture: 4 QRNN layers with 1550 hidden units per layer, SentencePiece tokenizer (15 000 tokens)**\n", "- **Hyperparameters and training method from the MultiFiT paper**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Author: [Pierre Guillou](https://www.linkedin.com/in/pierreguillou)\n", "- Date: **edition of October 15, 2019** (initial publication in September 2019)\n", "- Post on Medium: [link](https://medium.com/@pierre_guillou/nlp-fastai-portuguese-language-model-980c8ec75362)\n", "- Ref: [Fastai v1](https://docs.fast.ai/) (Deep Learning library on PyTorch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Warning (15/10/2019)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This notebook is a modified version of the v1 published in September 2019.** Thanks to [David Vieira](https://medium.com/@davidhsv/ol%C3%A1-pierre-tudo-bom-2bc8ae36dc14), we noticed that the fine-tuning of the LM and classifier did not use the SentencePiece model and vocab trained for the General Portuguese Language Model ([lm3-portuguese.ipynb](https://github.com/piegu/language-models/blob/master/lm3-portuguese.ipynb)).\n", "\n", "For example, the code used to create the fine-tuned Portuguese forward LM was:\n", "\n", "```data_lm = (TextList.from_df(df_trn_val, path, cols=reviews, \n", " processor=[OpenFileProcessor(), SPProcessor(max_vocab_sz=15000)])\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_for_lm() \n", " .databunch(bs=bs, num_workers=1))```\n", " \n", "It has been corrected by using the [SPProcessor.load()](https://github.com/fastai/fastai/blob/master/fastai/text/data.py#L481) function:\n", "\n", "```data_lm = (TextList.from_df(df_trn_val, path, cols=reviews, processor=SPProcessor.load(dest))\n", " .split_by_rand_pct(0.1, seed=42)\n", " 
.label_for_lm() \n", " .databunch(bs=bs, num_workers=1))```\n", " \n", "Therefore, we retrained the fine-tuned Portuguese forward LM and the classifier on the TCU jurisprudência dataset and **got better results! :-)** (see the Results section for all results)\n", "\n", "- **(fine-tuned) Language Model** \n", " - forward : (accuracy) **51.56%** instead of 44.66% | (perplexity) 11.38 instead of 15.97\n", " - backward: (accuracy) **52.15%** instead of 44.97% | (perplexity) 12.54 instead of 18.73\n", "\n", "- **(fine-tuned) Text Classifier**\n", " - **Accuracy** (ensemble) **97.95%** instead of 97.39%\n", " - **f1 score** (ensemble): **0.9795** instead of 0.9737" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Information" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Overview\n", "\n", "According to the article \"[MultiFiT: Efficient Multi-lingual Language Model Fine-tuning](https://arxiv.org/abs/1909.04761)\" (September 10, 2019), the QRNN architecture and the SentencePiece tokenizer give better results than AWD-LSTM and the spaCy tokenizer respectively. \n", "\n", "Therefore, they have been used in this notebook to **fine-tune a Portuguese Bidirectional Language Model** by transfer learning from a general Portuguese Bidirectional Language Model (also with the QRNN architecture and the SentencePiece tokenizer) trained on a Wikipedia corpus of 100 million tokens ([lm3-portuguese.ipynb](https://github.com/piegu/language-models/blob/master/lm3-portuguese.ipynb)). 
\n", "\n", "This Portuguese Bidirectional Language Model has been **fine-tuned on the [tcu_jurisp_reduzido.csv dataset about TCU jurisprudência](https://github.com/fastai-bsb/nlp-tcu-enunciados/blob/master/tcu_jurisp_reduzido.csv?raw=true)** and **its encoder has been transferred to a text classifier, which was then trained on this corpus**.\n", "\n", "This process **LM General --> LM fine-tuned --> Classifier fine-tuned** is called [ULMFiT](http://nlp.fast.ai/category/classification.html), but we trained our 3 models with the hyperparameter values and method given at the end of the [MultiFiT](https://arxiv.org/abs/1909.04761) paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Hyperparameter values\n", "\n", "- Language Model\n", " - (batch size) bs = 50\n", " - (QRNN) 4 QRNN layers (default: 3) with 1550 hidden units each (default: 1152)\n", " - (SentencePiece) vocab of 15000 tokens\n", " - (dropout) drop_mult = 1.0\n", " - (weight decay) wd = 0.1\n", " - (number of training epochs) 20 epochs\n", " - (learning rate) modified version of 1-cycle learning rate schedule (Smith, 2018) that uses cosine instead of linear annealing, cyclical momentum and discriminative finetuning\n", " - (loss) FlattenedLoss of LabelSmoothingCrossEntropy\n", " \n", "\n", "- Text Classifier\n", " - (batch size) bs = 18\n", " - (SentencePiece) vocab of 15000 tokens\n", " - (dropout) drop_mult = 0.3\n", " - (weight decay) wd = 0.1\n", " - (number of training epochs) 14 epochs (forward) and 19 epochs (backward)\n", " - (learning rate) modified version of 1-cycle learning rate schedule (Smith, 2018) that uses cosine instead of linear annealing, cyclical momentum and discriminative finetuning\n", " - (loss) FlattenedLoss of weighted LabelSmoothingCrossEntropy " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**We can conclude 
that this Bidirectional Portuguese LM using the MultiFiT configuration is a good model for text classification, but with about 46 million parameters it is far from being an LM that can compete with [GPT-2](https://openai.com/blog/better-language-models/) or [BERT](https://arxiv.org/abs/1810.04805) in NLP tasks like text generation.**\n", " \n", " \n", "- **About the data**: the dataset [tcu_jurisp_reduzido.csv](https://github.com/fastai-bsb/nlp-tcu-enunciados/blob/master/tcu_jurisp_reduzido.csv?raw=true) about \"TCU jurisprudência\" is unbalanced. Therefore, we used a weighted loss function (FlattenedLoss of weighted LabelSmoothingCrossEntropy).\n", " - number of texts: 10263\n", " - class 0: 3468 (33.79%)\n", " - class 1: 2723 (26.53%)\n", " - class 2: 2297 (22.38%)\n", " - class 3: 1775 (17.3%)\n", "\n", "\n", "- **(fine-tuned) Language Model** \n", " - forward : (accuracy) 51.56% | (perplexity) 11.38\n", " - backward: (accuracy) 52.15% | (perplexity) 12.54\n", " \n", "\n", "- **(fine-tuned) Text Classifier**\n", "\n", " - **Accuracy**\n", " - forward : (global) 97.08% | (class 0) 98.49% | (class 1) 98.24% | (class 2) 96.71% | (class 3) 93.40%\n", " - backward: (global) 97.07% | (class 0) 99.10% | (class 1) 97.89% | (class 2) 96.71% | (class 3) 92.89%\n", " - ensemble: (global) **97.95%** | (class 0) **99.40%** | (class 1) **99.30%** | (class 2) **97.18%** | (class 3) **94.42%**\n", "\n", " - **f1 score**\n", " - forward: 0.9707\n", " - backward: 0.9708\n", " - ensemble: **0.9795**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialisation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai import *\n", "from fastai.text import *\n", "from fastai.callbacks import *\n", "\n", "import matplotlib.cm as cm" ] }, { "cell_type": "code", 
"execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r\n", "\r\n", "```text\r\n", "=== Software === \r\n", "python : 3.7.4\r\n", "fastai : 1.0.57\r\n", "fastprogress : 0.1.21\r\n", "torch : 1.2.0\r\n", "nvidia driver : 410.104\r\n", "torch cuda : 10.0.130 / is available\r\n", "torch cudnn : 7602 / is enabled\r\n", "\r\n", "=== Hardware === \r\n", "nvidia gpus : 1\r\n", "torch devices : 1\r\n", " - gpu0 : 16130MB | Tesla V100-SXM2-16GB\r\n", "\r\n", "=== Environment === \r\n", "platform : Linux-4.9.0-9-amd64-x86_64-with-debian-9.9\r\n", "distro : #1 SMP Debian 4.9.168-1+deb9u5 (2019-08-11)\r\n", "conda env : base\r\n", "python : /opt/anaconda3/bin/python\r\n", "sys.path : /home/jupyter/tutorials/fastai/course-nlp\r\n", "/opt/anaconda3/lib/python37.zip\r\n", "/opt/anaconda3/lib/python3.7\r\n", "/opt/anaconda3/lib/python3.7/lib-dynload\r\n", "/opt/anaconda3/lib/python3.7/site-packages\r\n", "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions\r\n", "```\r\n", "\r\n", "Please make sure to include opening/closing ``` when you paste into forums/github to make the reports appear formatted as code sections.\r\n", "\r\n", "Optional package(s) to enhance the diagnostics can be installed with:\r\n", "pip install distro\r\n", "Once installed, re-run this utility to get the additional information\r\n" ] } ], "source": [ "!python -m fastai.utils.show_install" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# bs=48\n", "# bs=24\n", "bs=50" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "torch.cuda.set_device(0)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "data_path = Config.data_path()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will create a `{lang}wiki` folder, containing a `{lang}wiki` text file with the wikipedia contents. 
(For other languages, replace `{lang}` with the appropriate code from the [list of wikipedias](https://meta.wikimedia.org/wiki/List_of_Wikipedias).)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "lang = 'pt'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "name = f'{lang}wiki'\n", "path = data_path/name\n", "path.mkdir(exist_ok=True, parents=True)\n", "\n", "lm_fns3 = [f'{lang}_wt_sp15_multifit', f'{lang}_wt_vocab_sp15_multifit']\n", "lm_fns3_bwd = [f'{lang}_wt_sp15_multifit_bwd', f'{lang}_wt_vocab_sp15_multifit_bwd']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score\n", "\n", "@np_func\n", "def f1(inp,targ): return f1_score(targ, np.argmax(inp, axis=-1), average='weighted')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# source: https://github.com/fastai/fastai/blob/master//fastai/layers.py#L300:7\n", "# blog: https://bfarzin.github.io/Label-Smoothing/\n", "class WeightedLabelSmoothingCrossEntropy(nn.Module):\n", " def __init__(self, weight, eps:float=0.1, reduction='mean'):\n", " super().__init__()\n", " self.weight,self.eps,self.reduction = weight,eps,reduction\n", " \n", " def forward(self, output, target):\n", " c = output.size()[-1]\n", " log_preds = F.log_softmax(output, dim=-1)\n", " if self.reduction=='sum': loss = -log_preds.sum()\n", " else:\n", " loss = -log_preds.sum(dim=-1)\n", " if self.reduction=='mean': loss = loss.mean()\n", " return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, weight=self.weight, reduction=self.reduction)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore') # \"error\", \"ignore\", \"always\", \"default\", \"module\" or \"on" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "TCU jurisprudência:\n", "- reduzido: https://github.com/fastai-bsb/nlp-tcu-enunciados/blob/master/tcu_jurisp_reduzido.csv\n", "- completo: https://github.com/fastai-bsb/nlp-tcu-enunciados/blob/master/tcu_jurisp.csv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import urllib.request\n", "from converter import *" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# create TCU folder\n", "name_data = 'TCU'\n", "path_data = data_path/name_data\n", "path_data.mkdir(exist_ok=True, parents=True)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 48 ms, sys: 20 ms, total: 68 ms\n", "Wall time: 1.41 s\n" ] }, { "data": { "text/plain": [ "(PosixPath('/home/jupyter/.fastai/data/TCU/tcu_jurisp.csv'),\n", " )" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "# Download each file from url and save it locally under file_name\n", "\n", "url = 'https://github.com/fastai-bsb/nlp-tcu-enunciados/blob/master/tcu_jurisp_reduzido.csv?raw=true'\n", "file_name = 'tcu_jurisp_reduzido.csv'\n", "url_file = path_data/file_name\n", "urllib.request.urlretrieve(url, url_file)\n", "\n", "url = 'https://raw.githubusercontent.com/fastai-bsb/nlp-tcu-enunciados/master/tcu_jurisp.csv'\n", "file_name = 'tcu_jurisp.csv'\n", "url_file = path_data/file_name\n", "urllib.request.urlretrieve(url, url_file)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/jupyter/.fastai/data/TCU/tcu_jurisp_reduzido_preprocessed.csv'),\n", " PosixPath('/home/jupyter/.fastai/data/TCU/tcu_jurisp_reduzido.csv'),\n", " 
PosixPath('/home/jupyter/.fastai/data/TCU/tcu_jurisp.csv')]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path_data.ls()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "text,labels\r\n", "\"A medida cautelar do TCU que determina a suspensão de licitação por falhas no edital não impede o órgão ou a entidade de rever seu ato convocatório, valendo-se do poder de autotutela (art. 49 da Lei 8.666/1993 c/c o art. 9º da Lei 10.520/2002) , com o objetivo de, antecipando-se a eventual deliberação do Tribunal, promover de modo próprio a anulação da licitação e o refazimento do edital, livre dos vícios apontados.\",3\r\n", "\"A retenção de recursos pela Administração com vistas ao ressarcimento do prejuízo ou a existência de ação judicial para o reconhecimento do dano ao erário não constituem óbices ao prosseguimento da tomada de contas especial no TCU. Ocorrendo ressarcimento em uma instância, basta que o responsável apresente essa comprovação perante o juízo de execução para evitar o duplo pagamento.

Enunciado \",3\r\n", "\"Para fins de admissibilidade de recurso de revisão, considera-se documento novo todo aquele ainda não examinado no processo.\",3\r\n" ] } ], "source": [ "!head -n4 {path_data.ls()[0]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Overview" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# to solve display error of pandas dataframe\n", "get_ipython().config.get('IPKernelApp', {})['parent_appname'] = \"\"" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10263\n", "Counter({0: 3468, 1: 2723, 2: 2297, 3: 1775})\n" ] }, { "data": { "text/html": [ "

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
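<table>\n", "<thead><tr><th></th><th>Unnamed: 0</th><th>labels</th><th>text</th></tr></thead>\n", "<tbody>\n",
"<tr><th>0</th><td>0</td><td>3</td><td>A medida cautelar do TCU que determina a suspe...</td></tr>\n",
"<tr><th>1</th><td>1</td><td>3</td><td>A retenção de recursos pela Administração com ...</td></tr>\n",
"<tr><th>2</th><td>2</td><td>3</td><td>Para fins de admissibilidade de recurso de rev...</td></tr>\n",
"<tr><th>3</th><td>3</td><td>3</td><td>Fotografias não têm pleno valor probatório, so...</td></tr>\n",
"<tr><th>4</th><td>4</td><td>3</td><td>Não cabe instauração de tomada de contas espec...</td></tr>\n",
"</tbody>\n", "</table>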
\n", "
" ], "text/plain": [ " Unnamed: 0 labels text\n", "0 0 3 A medida cautelar do TCU que determina a suspe...\n", "1 1 3 A retenção de recursos pela Administração com ...\n", "2 2 3 Para fins de admissibilidade de recurso de rev...\n", "3 3 3 Fotografias não têm pleno valor probatório, so...\n", "4 4 3 Não cabe instauração de tomada de contas espec..." ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path_data/'tcu_jurisp_reduzido.csv', encoding='utf-8')\n", "print(len(df))\n", "print(Counter(df.labels))\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "13258\n", "Counter({0: 3468, 1: 2723, 2: 2297, 3: 1775, 4: 932, 5: 673, 6: 572, 7: 343, 8: 337, 9: 138})\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
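<table>\n", "<thead><tr><th></th><th>Unnamed: 0</th><th>labels</th><th>text</th></tr></thead>\n", "<tbody>\n",
"<tr><th>0</th><td>0</td><td>6</td><td>O TCU tem competência para fiscalizar a aplica...</td></tr>\n",
"<tr><th>1</th><td>1</td><td>6</td><td>Não compete ao TCU apreciar, para fins de regi...</td></tr>\n",
"<tr><th>2</th><td>2</td><td>6</td><td>Compete ao TCU a apreciação da constitucionali...</td></tr>\n",
"<tr><th>3</th><td>3</td><td>6</td><td>É possível a expedição de determinação pelo TC...</td></tr>\n",
"<tr><th>4</th><td>4</td><td>6</td><td>O TCU não tem competência, no âmbito do Progra...</td></tr>\n",
"</tbody>\n", "</table>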
\n", "
" ], "text/plain": [ " Unnamed: 0 labels text\n", "0 0 6 O TCU tem competência para fiscalizar a aplica...\n", "1 1 6 Não compete ao TCU apreciar, para fins de regi...\n", "2 2 6 Compete ao TCU a apreciação da constitucionali...\n", "3 3 6 É possível a expedição de determinação pelo TC...\n", "4 4 6 O TCU não tem competência, no âmbito do Progra..." ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path_data/'tcu_jurisp.csv', encoding='utf-8')\n", "print(len(df))\n", "print(Counter(df.labels))\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Analysis (reduzido file)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10263\n", "Counter({0: 3468, 1: 2723, 2: 2297, 3: 1775})\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
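<table>\n", "<thead><tr><th></th><th>Unnamed: 0</th><th>labels</th><th>text</th></tr></thead>\n", "<tbody>\n",
"<tr><th>0</th><td>0</td><td>3</td><td>A medida cautelar do TCU que determina a suspe...</td></tr>\n",
"<tr><th>1</th><td>1</td><td>3</td><td>A retenção de recursos pela Administração com ...</td></tr>\n",
"<tr><th>2</th><td>2</td><td>3</td><td>Para fins de admissibilidade de recurso de rev...</td></tr>\n",
"<tr><th>3</th><td>3</td><td>3</td><td>Fotografias não têm pleno valor probatório, so...</td></tr>\n",
"<tr><th>4</th><td>4</td><td>3</td><td>Não cabe instauração de tomada de contas espec...</td></tr>\n",
"</tbody>\n", "</table>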
\n", "
" ], "text/plain": [ " Unnamed: 0 labels text\n", "0 0 3 A medida cautelar do TCU que determina a suspe...\n", "1 1 3 A retenção de recursos pela Administração com ...\n", "2 2 3 Para fins de admissibilidade de recurso de rev...\n", "3 3 3 Fotografias não têm pleno valor probatório, so...\n", "4 4 3 Não cabe instauração de tomada de contas espec..." ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path_data/'tcu_jurisp_reduzido.csv', encoding='utf-8')\n", "print(len(df))\n", "print(Counter(df.labels))\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# columns names\n", "reviews = \"text\"\n", "label = \"labels\"\n", "\n", "# keep columns\n", "df2 = df[[reviews,label]].copy()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(orginal csv) number of all reviews: 10263\n", "there is no empty review.\n", "0 reviews with nan label were deleted\n", "\n", "number of text of class 0: 3468 (33.79%)\n", "number of text of class 1: 2723 (26.53%)\n", "number of text of class 2: 2297 (22.38%)\n", "number of text of class 3: 1775 (17.3%)\n", "\n", "(final) number of all texts: 10263\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
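<table>\n", "<thead><tr><th></th><th>text</th><th>labels</th></tr></thead>\n", "<tbody>\n",
"<tr><th>0</th><td>A medida cautelar do TCU que determina a suspe...</td><td>3</td></tr>\n",
"<tr><th>1</th><td>A retenção de recursos pela Administração com ...</td><td>3</td></tr>\n",
"<tr><th>2</th><td>Para fins de admissibilidade de recurso de rev...</td><td>3</td></tr>\n",
"<tr><th>3</th><td>Fotografias não têm pleno valor probatório, so...</td><td>3</td></tr>\n",
"<tr><th>4</th><td>Não cabe instauração de tomada de contas espec...</td><td>3</td></tr>\n",
"</tbody>\n", "</table>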
\n", "
" ], "text/plain": [ " text labels\n", "0 A medida cautelar do TCU que determina a suspe... 3\n", "1 A retenção de recursos pela Administração com ... 3\n", "2 Para fins de admissibilidade de recurso de rev... 3\n", "3 Fotografias não têm pleno valor probatório, so... 3\n", "4 Não cabe instauração de tomada de contas espec... 3" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# number of reviews\n", "print(f'(orginal csv) number of all reviews: {len(df2)}')\n", "\n", "# keep not null reviews\n", "## delete nan reviews\n", "empty_nan = (df2[reviews].isnull()).sum()\n", "df2 = df2[df2[reviews].notnull()]\n", "## delete empty reviews\n", "list_idx_none = []\n", "for idxs, row in df2.iterrows():\n", " if row[reviews].strip() == \"\":\n", " df2.drop(idxs, axis=0, inplace=True)\n", " list_idx_none.append(idxs)\n", "empty_none = len(list_idx_none)\n", "## print results\n", "empty = empty_nan+empty_none\n", "if empty != 0:\n", " print(f'{empty} empty reviews were deleted')\n", "else:\n", " print('there is no empty review.')\n", "\n", "# # check that there is no twice the same review\n", "# # keep the first of unique review_id reviews\n", "# same = len(df2) - len(df2[idx].unique())\n", "# if same != 0:\n", "# df2.drop_duplicates(subset=[idx], inplace=True)\n", "# print(f'from the {same} identical reviews ids, only the first one has been kept.')\n", "# else:\n", "# print('there is no identical review id.')\n", "\n", "## delete nan labels\n", "empty_label_nan = (df2[label].isnull()).sum()\n", "df2 = df2[df2[label].notnull()]\n", "print(f'{empty_label_nan} reviews with nan label were deleted')\n", "\n", "# number of reviews by class\n", "counter = Counter(df2[label])\n", "clas_0, clas_1, clas_2, clas_3 = counter[0], counter[1], counter[2], counter[3]\n", "num = len(df2)\n", "pc_clas_0, pc_clas_1 = round((clas_0/num)*100,2), round((clas_1/num)*100,2)\n", "pc_clas_2, pc_clas_3 = round((clas_2/num)*100,2), 
round((clas_3/num)*100,2)\n", "print(f'\\nnumber of text of class 0: {clas_0} ({pc_clas_0}%)')\n", "print(f'number of text of class 1: {clas_1} ({pc_clas_1}%)')\n", "print(f'number of text of class 2: {clas_2} ({pc_clas_2}%)')\n", "print(f'number of text of class 3: {clas_3} ({pc_clas_3}%)')\n", "print(f'\\n(final) number of all texts: {num}') \n", "\n", "# convert HTML caracters to normal letters\n", "df2[reviews] = df2[reviews].apply(convert)\n", "\n", "df2.head(5)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "number of text of class 0: 3468 (33.79%)\n", "number of text of class 1: 2723 (26.53%)\n", "number of text of class 2: 2297 (22.38%)\n", "number of text of class 3: 1775 (17.3%)\n", "\n", "(final) number of all texts: 10263\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAARO0lEQVR4nO3dbaic533n8e+vsvNAEtby+iSokliZVn1wCpXDWdkQKKmT2rLzQi40YL9IRHBRCzYk0C1V+sZ5WIMXkhgCqUHF2sgljVc0KRGJWld1HUKgtnWcVRTLitdnHW98KmGdrhwnJtTF3v++OJfoRD4Pcx4006Pr+4Fh7vt/X/fM/x7Eb+5zzT2jVBWSpD78wrgbkCSNjqEvSR0x9CWpI4a+JHXE0Jekjlw27gYWc9VVV9W2bdvG3YYkrStPPvnkP1fVxHzb/l2H/rZt25iamhp3G5K0riT5Pwttc3pHkjpi6EtSRwx9SeqIoS9JHVky9JO8JckTSb6X5GSST7X6l5L8MMnxdtvR6knyhSTTSU4kec/AY+1J8my77bl4hyVJms8wV++8CtxQVa8kuRz4TpK/adv+uKr+6oLxNwPb2+064H7guiRXAncDk0ABTyY5XFUvrcWBSJKWtuSZfs15pa1e3m6L/TTnbuDBtt9jwBVJNgE3AUer6lwL+qPArtW1L0lajqHm9JNsSHIcOMtccD/eNt3TpnDuS/LmVtsMvDCw+0yrLVS/8Ln2JplKMjU7O7vMw5EkLWao0K+q16tqB7AF2JnkN4BPAL8G/GfgSuBP2vDM9xCL1C98rv1VNVlVkxMT836hTJK0Qsv6Rm5V/TjJt4BdVfXZVn41yX8H/ktbnwG2Duy2BTjd6u+7oP6t5bcsqQfb9n1z3C2M1fP3fvCiPO4wV+9MJLmiLb8V+ADwgzZPT5IAtwJPtV0OAx9pV/FcD7xcVWeAh4Ebk2xMshG4sdUkSSMyzJn+JuBgkg3MvUkcqqpvJPmHJBPMTdscB/6wjT8C3AJMAz8DPgpQVeeSfAY41sZ9uqrOrd2hSJKWsmToV9UJ4Np56jcsML6AOxfYdgA4sMweJUlrxG/kSlJHDH1J6oihL0kdMfQ
lqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVkydBP8pYkTyT5XpKTST7V6lcneTzJs0n+R5I3tfqb2/p0275t4LE+0erPJLnpYh2UJGl+w5zpvwrcUFW/CewAdiW5HvhvwH1VtR14Cbijjb8DeKmqfhm4r40jyTXAbcC7gV3AnyXZsJYHI0la3JKhX3NeaauXt1sBNwB/1eoHgVvb8u62Ttv+/iRp9Yeq6tWq+iEwDexck6OQJA1lqDn9JBuSHAfOAkeB/w38uKpea0NmgM1teTPwAkDb/jLwHwfr8+wz+Fx7k0wlmZqdnV3+EUmSFjRU6FfV61W1A9jC3Nn5r883rN1ngW0L1S98rv1VNVlVkxMTE8O0J0ka0rKu3qmqHwPfAq4HrkhyWdu0BTjdlmeArQBt+38Azg3W59lHkjQCw1y9M5Hkirb8VuADwCngUeD32rA9wNfb8uG2Ttv+D1VVrX5bu7rnamA78MRaHYgkaWmXLT2ETcDBdqXNLwCHquobSZ4GHkryX4H/CTzQxj8A/EWSaebO8G8DqKqTSQ4BTwOvAXdW1etreziSpMUsGfpVdQK4dp76c8xz9U1V/QvwoQUe6x7gnuW3KUlaC34jV5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHhvntHUkrsG3fN8fdwlg9f+8Hx92C5uGZviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdWTJ0E+yNcmjSU4lOZnkY63+yST/lOR4u90ysM8nkkwneSbJTQP1Xa02nWTfxTkkSdJChvkZhteAP6qq7yZ5B/BkkqNt231V9dnBwUmuAW4D3g38IvD3SX6lbf4i8DvADHAsyeGqenotDkSStLQlQ7+qzgBn2vJPk5wCNi+yy27goap6FfhhkmlgZ9s2XVXPASR5qI019CVpRJY1p59kG3At8Hgr3ZXkRJIDSTa22mbghYHdZlptobokaUSGDv0kbwe+Cny8qn4C3A/8ErCDub8EPnd+6Dy71yL1C59nb5KpJFOzs7PDtidJGsJQoZ/kcuYC/8tV9TWAqnqxql6vqv8H/Dn/NoUzA2wd2H0LcHqR+s+pqv1VNVlVkxMTE8s9HknSIoa5eifAA8Cpqvr8QH3TwLDfBZ5qy4eB25K8OcnVwHbgCeAYsD3J1UnexNyHvYfX5jAkScMY5uqd9wIfBr6f5Hir/Slwe5IdzE3RPA/8AUBVnUxyiLkPaF8D7qyq1wGS3AU8DGwADlTVyTU8FknSEoa5euc7zD8ff2SRfe4B7pmnfmSx/SRJF5ffyJWkjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSPDfDlLndq275vjbmGsnr/3g+NuQVpznulLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1ZMnQT7I1yaNJTiU5meRjrX5lkqNJnm33G1s9Sb6QZDrJiSTvGXisPW38s0n2XLzDkiTNZ5gz/deAP6qqXweuB+5Mcg2wD3ikqrYDj7R1gJuB7e22F7gf5t4kgLuB64CdwN3n3ygkSaOxZOhX1Zmq+m5b/ilwCtgM7AYOtmEHgVvb8m7gwZrzGHBFkk3ATcDRqjpXVS8BR4Fda3o0kqRFLWtOP8k24FrgceBdVXUG5t4YgHe2YZuBFwZ2m2m1heoXPsfeJFNJpmZnZ5fTniRpCUOHfpK3A18FPl5VP1ls6Dy1WqT+84Wq/VU1WVWTExMTw7YnSRrCUKGf5HLmAv/LVfW1Vn6xTdvQ7s+2+gywdWD3LcDpReqSpBEZ5uqdAA8Ap6rq8wObDgPnr8DZA3x9oP6RdhXP9cDLbfrnYeDGJBvbB7g3tpokaUSG+Y/
R3wt8GPh+kuOt9qfAvcChJHcAPwI+1LYdAW4BpoGfAR8FqKpzST4DHGvjPl1V59bkKCRJQ1ky9KvqO8w/Hw/w/nnGF3DnAo91ADiwnAYlSWvHb+RKUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOjLMN3LXrW37vjnuFsbq+Xs/OO4WJP0745m+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjqyZOgnOZDkbJKnBmqfTPJPSY632y0D2z6RZDrJM0luGqjvarXpJPvW/lAkSUsZ5kz/S8Cueer3VdWOdjsCkOQa4Dbg3W2fP0uyIckG4IvAzcA1wO1trCRphJb8aeWq+naSbUM+3m7goap6FfhhkmlgZ9s2XVXPASR5qI19etkdS5JWbDVz+nclOdGmfza22mbghYExM622UP0NkuxNMpVkanZ2dhXtSZIutNLQvx/4JWAHcAb4XKtnnrG1SP2Nxar9VTVZVZMTExMrbE+SNJ8V/c9ZVfXi+eUkfw58o63OAFsHhm4BTrflheqSpBFZ0Zl+kk0Dq78LnL+y5zBwW5I3J7ka2A48ARwDtie5OsmbmPuw9/DK25YkrcSSZ/pJvgK8D7gqyQxwN/C+JDuYm6J5HvgDgKo6meQQcx/QvgbcWVWvt8e5C3gY2AAcqKqTa340kqRFDXP1zu3zlB9YZPw9wD3z1I8AR5bVnSRpTfmNXEnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOLBn6SQ4kOZvkqYHalUmOJnm23W9s9ST5QpLpJCeSvGdgnz1t/LNJ9lycw5EkLWaYM/0vAbsuqO0DHqmq7cAjbR3gZmB7u+0F7oe5NwngbuA6YCdw9/k3CknS6CwZ+lX1beDcBeXdwMG2fBC4daD+YM15DLgiySbgJuBoVZ2rqpeAo7zxjUSSdJGtdE7/XVV1BqDdv7PVNwMvDIybabWF6m+QZG+SqSRTs7OzK2xPkjSftf4gN/PUapH6G4tV+6tqsqomJyYm1rQ5SerdSkP/xTZtQ7s/2+ozwNaBcVuA04vUJUkjtNLQPwycvwJnD/D1gfpH2lU81wMvt+mfh4Ebk2xsH+De2GqSpBG6bKkBSb4CvA+4KskMc1fh3AscSnIH8CPgQ234EeAWYBr4GfBRgKo6l+QzwLE27tNVdeGHw5Kki2zJ0K+q2xfY9P55xhZw5wKPcwA4sKzuJElrym/kSlJHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVkVaGf5Pkk309yPMlUq12Z5GiSZ9v9xlZPki8kmU5yIsl71uIAJEnDW4sz/d+uqh1VNdnW9wGPVNV24JG2DnAzsL3d9gL3r8FzS5KW4WJM7+wGDrblg8CtA/UHa85jwBVJNl2E55ckLWC1oV/A3yV5MsneVntXVZ0BaPfvbPXNwAsD+8602s9JsjfJVJKp2dnZVbYnSRp02Sr3f29VnU7yTuBokh8sMjbz1OoNhar9wH6AycnJN2yXJK3cqs70q+p0uz8L/DWwE3jx/LRNuz/bhs8AWwd23wKcXs3zS5KWZ8Whn+RtSd5xfhm4EXgKOAzsacP2AF9vy4eBj7SreK4HXj4/DSRJGo3VTO+8C/jrJOcf5y+r6m+THAMOJbkD+BHwoTb+CHALMA38DPjoKp5bkrQCKw79qnoO+M156v8XeP889QLuXOnzSZJWz2/kSlJHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcT
Ql6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHVk5KGfZFeSZ5JMJ9k36ueXpJ6NNPSTbAC+CNwMXAPcnuSaUfYgST0b9Zn+TmC6qp6rqn8FHgJ2j7gHSepWqmp0T5b8HrCrqn6/rX8YuK6q7hoYsxfY21Z/FXhmZA2uvauAfx53E+uYr9/q+Pqtznp+/f5TVU3Mt+GyETeSeWo/965TVfuB/aNp5+JKMlVVk+PuY73y9VsdX7/VuVRfv1FP78wAWwfWtwCnR9yDJHVr1KF/DNie5OokbwJuAw6PuAdJ6tZIp3eq6rUkdwEPAxuAA1V1cpQ9jNglMU01Rr5+q+PrtzqX5Os30g9yJUnj5TdyJakjhr4kdcTQX2NJ3pLkiSTfS3IyyafG3dN6kmRrkkeTnGqv38fG3dN6k+RAkrNJnhp3L+vRpf5TMc7pr7EkAd5WVa8kuRz4DvCxqnpszK2tC0k2AZuq6rtJ3gE8CdxaVU+PubV1I8lvAa8AD1bVb4y7n/Wk/VTM/wJ+h7lLzI8Bt19K//48019jNeeVtnp5u/nOOqSqOlNV323LPwVOAZvH29X6UlXfBs6Nu4916pL/qRhD/yJIsiHJceAscLSqHh93T+tRkm3AtYCvn0ZlM/DCwPoMl9hJh6F/EVTV61W1g7lvHO9M4p/Yy5Tk7cBXgY9X1U/G3Y+6seRPxax3hv5FVFU/Br4F7BpzK+tK+yzkq8CXq+pr4+5HXbnkfyrG0F9jSSaSXNGW3wp8APjBeLtaP9oH4Q8Ap6rq8+PuR9255H8qxtBfe5uAR5OcYO4f0NGq+saYe1pP3gt8GLghyfF2u2XcTa0nSb4C/CPwq0lmktwx7p7Wi6p6DTj/UzGngEOX2k/FeMmmJHXEM31J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjry/wFgB763f8TqiQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df_trn_val = df2.copy()\n", "\n", "# number of reviews by class\n", "counter = Counter(df_trn_val[label])\n", "clas_0, clas_1, clas_2, clas_3 = counter[0], counter[1], counter[2], counter[3]\n", "num = len(df_trn_val)\n", "pc_clas_0, pc_clas_1 = round((clas_0/num)*100,2), round((clas_1/num)*100,2)\n", "pc_clas_2, pc_clas_3 = round((clas_2/num)*100,2), round((clas_3/num)*100,2)\n", "print(f'\\nnumber of text of class 0: {clas_0} ({pc_clas_0}%)')\n", "print(f'number of text of class 1: {clas_1} ({pc_clas_1}%)')\n", "print(f'number of text of class 2: {clas_2} ({pc_clas_2}%)')\n", "print(f'number of text of class 3: {clas_3} ({pc_clas_3}%)')\n", "print(f'\\n(final) number of all texts: {num}') \n", "\n", "# plot histogram\n", "keys = list(df_trn_val[label].value_counts().keys())\n", "values = list(df_trn_val[label].value_counts().array)\n", "plt.bar(keys, values[::-1]) \n", "plt.xticks(keys, keys[::-1])\n", "# print(df_trn_val['label'].value_counts())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
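<table>\n", "<thead><tr><th></th><th>text</th><th>labels</th></tr></thead>\n", "<tbody>\n",
"<tr><th>0</th><td>A medida cautelar do TCU que determina a suspe...</td><td>3</td></tr>\n",
"<tr><th>1</th><td>A retenção de recursos pela Administração com ...</td><td>3</td></tr>\n",
"<tr><th>2</th><td>Para fins de admissibilidade de recurso de rev...</td><td>3</td></tr>\n",
"<tr><th>3</th><td>Fotografias não têm pleno valor probatório, so...</td><td>3</td></tr>\n",
"<tr><th>4</th><td>Não cabe instauração de tomada de contas espec...</td><td>3</td></tr>\n",
"</tbody>\n", "</table>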
\n", "
" ], "text/plain": [ " text labels\n", "0 A medida cautelar do TCU que determina a suspe... 3\n", "1 A retenção de recursos pela Administração com ... 3\n", "2 Para fins de admissibilidade de recurso de rev... 3\n", "3 Fotografias não têm pleno valor probatório, so... 3\n", "4 Não cabe instauração de tomada de contas espec... 3" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_trn_val.head()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "df_trn_val.to_csv(path_data/'tcu_jurisp_reduzido_preprocessed.csv', index = None, header=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tuning \"forward LM\"" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "name_data = 'TCU'\n", "path_data = data_path/name_data\n", "\n", "# Load csv\n", "df_trn_val = pd.read_csv(path_data/'tcu_jurisp_reduzido_preprocessed.csv')\n", "\n", "# columns names\n", "reviews = \"text\"\n", "label = \"labels\"" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/jupyter/.fastai/data/ptwiki/corpus2_100/tmp/spm.model'),\n", " PosixPath('/home/jupyter/.fastai/data/ptwiki/corpus2_100/tmp/spm.vocab')]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dest = path/'corpus2_100'\n", "(dest/'tmp').ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Databunch" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1 s, sys: 156 ms, total: 1.16 s\n", "Wall time: 1.67 s\n" ] } ], "source": [ "%%time\n", "data_lm = (TextList.from_df(df_trn_val, path, cols=reviews, processor=SPProcessor.load(dest))\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_for_lm() \n", " .databunch(bs=bs, num_workers=1))" ] }, { "cell_type": 
"code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "data_lm.save(f'{path}/{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "data_lm = load_data(path, f'{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_lm_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 4.14 s, sys: 1.22 s, total: 5.35 s\n", "Wall time: 4.6 s\n" ] } ], "source": [ "%%time\n", "perplexity = Perplexity()\n", "learn_lm = language_model_learner(data_lm, AWD_LSTM, config=config, pretrained_fnames=lm_fns3, drop_mult=1., \n", " metrics=[error_rate, accuracy, perplexity]).to_fp16()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "46020150" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# number of model parameters\n", "sum([p.numel() for p in learn_lm.model.parameters()])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SequentialRNN(\n", " (0): AWD_LSTM(\n", " (encoder): Embedding(15000, 400, padding_idx=1)\n", " (encoder_dp): EmbeddingDropout(\n", " (emb): Embedding(15000, 400, padding_idx=1)\n", " )\n", " (rnns): ModuleList(\n", " (0): QRNN(\n", " (layers): ModuleList(\n", " (0): QRNNLayer(\n", " (linear): WeightDropout(\n", " (module): Linear(in_features=800, out_features=4650, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (1): QRNN(\n", " (layers): ModuleList(\n", " (0): 
QRNNLayer(\n", " (linear): WeightDropout(\n", " (module): Linear(in_features=1550, out_features=4650, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (2): QRNN(\n", " (layers): ModuleList(\n", " (0): QRNNLayer(\n", " (linear): WeightDropout(\n", " (module): Linear(in_features=1550, out_features=4650, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (3): QRNN(\n", " (layers): ModuleList(\n", " (0): QRNNLayer(\n", " (linear): WeightDropout(\n", " (module): Linear(in_features=1550, out_features=1200, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (input_dp): RNNDropout()\n", " (hidden_dps): ModuleList(\n", " (0): RNNDropout()\n", " (1): RNNDropout()\n", " (2): RNNDropout()\n", " (3): RNNDropout()\n", " )\n", " )\n", " (1): LinearDecoder(\n", " (decoder): Linear(in_features=400, out_features=15000, bias=True)\n", " (output_dp): RNNDropout()\n", " )\n", ")" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_lm.model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Change loss function" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of CrossEntropyLoss()" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_lm.loss_func" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "learn_lm.loss_func = FlattenedLoss(LabelSmoothingCrossEntropy)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of LabelSmoothingCrossEntropy()" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_lm.loss_func" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Training" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, 
"output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn_lm.lr_find()" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn_lm.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2\n", "lr *= bs/48\n", "\n", "wd = 0.1" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | error_rate | accuracy | perplexity | time |
|---|---|---|---|---|---|---|
| 0 | 23.825766 | 13.385324 | 0.969197 | 0.030803 | 628923.312500 | 00:07 |
| 1 | 8.161844 | 6.876748 | 0.936191 | 0.063810 | 571.473938 | 00:07 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.fit_one_cycle(2, lr*10, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned1_tcu_jurisp_reduzido_sp15_multifit_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned1_enc_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | error_rate | accuracy | perplexity | time |
|---|---|---|---|---|---|---|
| 0 | 6.500302 | 6.185975 | 0.854571 | 0.145429 | 262.548920 | 00:10 |
| 1 | 5.490670 | 5.003970 | 0.722680 | 0.277320 | 68.926476 | 00:10 |
| 2 | 4.755301 | 4.457559 | 0.653755 | 0.346245 | 36.102261 | 00:10 |
| 3 | 4.416079 | 4.209779 | 0.621551 | 0.378449 | 27.896441 | 00:10 |
| 4 | 4.276728 | 4.118284 | 0.608721 | 0.391279 | 24.504759 | 00:10 |
| 5 | 4.199481 | 4.036489 | 0.597565 | 0.402435 | 22.284594 | 00:10 |
| 6 | 4.117595 | 3.955703 | 0.584517 | 0.415483 | 20.709642 | 00:10 |
| 7 | 4.023516 | 3.906753 | 0.574245 | 0.425755 | 19.213202 | 00:10 |
| 8 | 3.942492 | 3.827708 | 0.562653 | 0.437347 | 17.807081 | 00:10 |
| 9 | 3.849638 | 3.768771 | 0.554177 | 0.445823 | 16.522247 | 00:10 |
| 10 | 3.752683 | 3.697943 | 0.540054 | 0.459946 | 15.378661 | 00:10 |
| 11 | 3.659526 | 3.631438 | 0.528680 | 0.471320 | 14.282874 | 00:10 |
| 12 | 3.550264 | 3.568518 | 0.515796 | 0.484204 | 13.360168 | 00:10 |
| 13 | 3.440210 | 3.504019 | 0.503156 | 0.496843 | 12.489232 | 00:10 |
| 14 | 3.326972 | 3.462149 | 0.493769 | 0.506231 | 11.924422 | 00:10 |
| 15 | 3.233320 | 3.438249 | 0.488258 | 0.511741 | 11.569489 | 00:10 |
| 16 | 3.172660 | 3.426481 | 0.484748 | 0.515252 | 11.410362 | 00:10 |
| 17 | 3.146528 | 3.424744 | 0.484449 | 0.515551 | 11.384737 | 00:10 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.unfreeze()\n", "learn_lm.fit_one_cycle(18, lr, wd=wd, moms=(0.8,0.7), callbacks=[ShowGraph(learn_lm)])" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned2_lenerbr_sp15_multifit_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned2_enc_lenerbr_sp15_multifit_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save best LM learner and its encoder" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned_tcu_jurisp_reduzido_sp15_multifit_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned_enc_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tuning \"backward LM\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Databunch" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.03 s, sys: 560 ms, total: 1.59 s\n", "Wall time: 2.12 s\n" ] } ], "source": [ "%%time\n", "data_lm = (TextList.from_df(df_trn_val, path, cols=reviews, processor=SPProcessor.load(dest))\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_for_lm() \n", " .databunch(bs=bs, num_workers=1, backwards=True))" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "data_lm.save(f'{path}/{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 84 ms, sys: 12 ms, total: 96 ms\n", "Wall time: 94.2 ms\n" ] } ], "source": [ "%%time\n", "data_lm = load_data(path, f'{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, 
backwards=True)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_lm_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.32 s, sys: 220 ms, total: 1.54 s\n", "Wall time: 2.26 s\n" ] } ], "source": [ "%%time\n", "perplexity = Perplexity()\n", "learn_lm = language_model_learner(data_lm, AWD_LSTM, config=config, pretrained_fnames=lm_fns3_bwd, drop_mult=1., \n", " metrics=[error_rate, accuracy, perplexity]).to_fp16()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Change loss function" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of CrossEntropyLoss()" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_lm.loss_func" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "learn_lm.loss_func = FlattenedLoss(LabelSmoothingCrossEntropy)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of LabelSmoothingCrossEntropy()" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_lm.loss_func" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Training" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn_lm.lr_find()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { 
"data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXgcd53n8fe3L7VuybYOHyF27NhO4pBM7GRyzJCEcISbABngYXYDYcjDznA/sMssO+wMPAwwWXZ2GJ4ZNgsJ7EIYSEIYwplwhBwkJM5pO4ljO5ct2ZJsnd1q9fnbP7oky45kK7Kqu1r1eT1PP+quru76dkv61K9+VfUrc84hIiLhEal2ASIiUlkKfhGRkFHwi4iEjIJfRCRkFPwiIiETq3YBc7Fs2TK3evXqapchIlJTHnrooYPOuY6jp9dE8K9evZqtW7dWuwwRkZpiZs/PNF1dPSIiIaPgFxEJGQW/iEjIKPhFREJGwS8iEjIKfhGRkFHwi4iEjIJfRCSADoxM8JXbd/LMQGrB31vBLyISQC8MjvPPv9lNz3Bmwd9bwS8iEkDpbAGAxrqFH2BBwS8iEkApL/ibFPwiIuGQVvCLiIRLSl09IiLhMhX8ieiCv7eCX0QkgNLZAsl4hFh04WNawS8iEkCpbNGX/n1Q8IuIBFI6W1Dwi4iESSpb8GXHLij4RUQCScEvIhIy6uoREQmZtFr8IiLhoqN6RERCJpXN01S38CdvgYJfRCRwCsUSE/mSunpERMIinSsC/gzQBgp+EZHA8XNkTvAx+M3sejPrN7Pt06Zda2ZPmdnjZnarmbX5tXwRkVrl50VYwN8W/7eAy4+adgewyTn3cuBp4K99XL6ISE0aq9UWv3PuLmDwqGm3O+cK3sP7gVV+LV9EpFbVcov/eK4Gfj7bk2Z2jZltNbOtAwMDFSxLRKS6Dgf/Ijqc08w+AxSA7842j3PuOufcFufclo6OjsoVJyJSZals+aie5rq4L+/vz3bEMZjZVcAbgcucc67SyxcRCTq/W/wVDX4zuxz4L8DFzrnxSi5bRKRW+Hm9XfD3cM7vAfcBG8xsn5m9H/ga0AzcYWaPmtnX/Vq+iEitSmULxCJGXcyfiPatxe+ce/cMk7/p1/JERBaLyZE5zcyX99eZuyIiAZPycSx+UPCLiASOnxdhAQW/iEjglC+76M8RPaDgFxEJnFS26NsRPaDgFxEJnHS2QHNSwS8iEhrpbIHGhIJfRCQ0Uj5eaB0U/CIigeKc01E9IiJhkskXKTn/hmsABb+ISKBMjtPTpJ27IiLhkM5OXmhdx/GLiITC1JDMOqpHRCQcxib8vd4uKPhFRALF7+vtgoJfRCRQ0jnt3BURCZWpo3rU4hcRCQd19YiIhEzK27nbENfhnCIioZDKFmmqixGJ+HPZRVDwi4gEStrni7CAgl9EJFBSOX9H5gQFv4hIoPg9Mico+EVEAiU14e9FWEDBLyISKKlswdeTt0DBLyISKOmcunpEREIlnS3qqB4RkTDx+3q7oOAXEQmMXKFErlCiSTt3RUTCIV2Byy6Cgl9EJDBSFRigDRT8IiKBMTUWv4JfRCQcKjEkMyj4RUQCoxLX2wUfg9/MrjezfjPbPm3alWa2w8xKZrbFr2WLiNSidLYI1HDwA98CLj9q2nbgbcBdPi5XRKQmHe7q8fcELt9WK865u8xs9VHTngQw8+8CAyIitaoS19uFAPfxm9k1ZrbVzLYODAxUuxwREd+F/nBO59x1zrktzrktHR0d1S5HRMR36WyBuliEeNTfaA5s8IuIhE2qAhdhAQW/iEhgpCswQBv4ezjn94D7gA1mts/M3m9mV5jZPuAC4Kdm9ku/li8iUmtS2WJFgt/Po3rePctTt/q1TBGRWpbK5mny+VBOUFePiEhgpLNF9fGLiIRJzffxi4jISzM6UaDZ57H4QcEvIhIIzjmGx3O0
NSR8X5aCX0QkAFLZAoWSo70h7vuyFPwiIgEwPJ4HUItfRCQshsZzALQr+EVEwmHIa/Grq0dEJCSGJ1v8jWrxi4iEwlBaXT0iIqEyOJ7HDFrr1dUjIhIKw+M5WpJxohH/r1Co4BcRCYCh8XxFduyCgl9EJBAqddYuKPhFRAJhaDynFr+ISJgMpfMVOaIHFPwiIoGgrh4RkRDJFUqkc0V19YiIhMXkWbttFThrFxT8IiJVN+gF/xJ19YiIhMNQunIDtIGCX0Sk6qa6eoLU4jeztWZW592/xMw+YmZt/pYmIhIOU0MyNwarxX8LUDSzdcA3gTXAjb5VJSISIpW8CAvMPfhLzrkCcAXwv5xzHweW+1eWiEh4DI/nSMYjJOPRiixvrsGfN7N3A1cBP/GmVWabRERkkSsP0FaZ1j7MPfjfB1wAfME596yZrQG+419ZIiLhUcmzdgFic5nJOfcE8BEAM2sHmp1zX/KzMBGRsKjkkMww96N67jSzFjNbAjwG3GBm/9Pf0kREwmEonQtkV0+rc24UeBtwg3NuM/Aq/8oSEQmPofEcbUFr8QMxM1sO/BmHd+6KiMgJKpUcI5k8Syo0Tg/MPfg/B/wS2OOce9DMTgF2+VeWiEg4jE7kKbnKnbULc9+5exNw07THzwBv96soEZGwmDprN2hdPWa2ysxuNbN+M+szs1vMbNVxXnO9N//2adOWmNkdZrbL+9l+oh9ARKSWVfqsXZh7V88NwI+BFcBK4DZv2rF8C7j8qGmfBn7tnDsV+LX3WEQktA4P0BawFj/Q4Zy7wTlX8G7fAjqO9QLn3F3A4FGT3wJ827v/beCtL6VYEZHF5vCQzMFr8R80sz83s6h3+3Pg0DyW1+Wc2w/g/eycbUYzu8bMtprZ1oGBgXksSkQk+ILc1XM15UM5DwD7gXdQHsbBN86565xzW5xzWzo6jrlxISJSs4bGc0QMmpNzOtZmQcwp+J1zLzjn3uyc63DOdTrn3kr5ZK6Xqs87HwDvZ/883kNEZNEYGs/T1pAgErGKLfNErsD1iXm85seUR/jE+/nvJ7B8EZGaN1zhs3bhxIL/mKsnM/secB+wwcz2mdn7gS8BrzazXcCrvcciIqE1lK7skMwwxxO4ZuGO+aRz757lqctOYJkiIovK0HiOVe0NFV3mMYPfzMaYOeANqPelIhGREBkez3Pmysp29Rwz+J1zzZUqREQkjIbGc7RXcIA2OLE+fhEROQGZXJFsoVRTO3dFROQEVOPkLVDwi4hUzWB6MvjV4hcRCYVhb0jmSo7FDwp+EZGqUVePiEjIDI+rq0dEJFSG1NUjIhIug+kcTXUxErHKRrGCX0SkSnqHM6xoS1Z8uQp+EZEq6R3JsLy18qPfKPhFRKqkd3iCFW0KfhGRUJjIFxlM51iprh4RkXDoHc4AqMUvIhIWvcMTgIJfRCQ0Jlv8KxX8IiLh0DOcwQy6WtTHLyISCr3DGTqb6yp+8hYo+EVEqqJax/CDgl9EpCp6hyeq0r8PCn4RkYpzzlVtuAZQ8IuIVNxgOke2UKrKoZyg4BcRqbhqHsMPCn4RkYrrqeIx/KDgFxGpuGoO1wAKfhGRiusdzpCMRyp+ycVJCn4RkQrrHcmwoq0eM6vK8hX8IiIV1jM8wYoqnbwFCn4RkYrbX8Vj+EHBLyJSUdlCkf6xbNV27IKCX0SkovpGskD1jugBBb+ISEVV+xh+qFLwm9lHzWy7me0ws49VowYRkWqo9jH8UIXgN7NNwAeA84CzgDea2amVrkNEpBomg395a7h27p4G3O+cG3fOFYDfAVdUoQ4RkYrrHcmwrClBMh6tWg3VCP7twCvMbKmZNQCvB046eiYzu8bMtprZ1oGBgYoXKSLih57hiap280AVgt859yTwZeAO4BfAY0Bhhvmuc85tcc5t6ejoqHCVIiL+2D+cqWo3D1Rp565z7pvOuXOcc68ABoFd1ahDRKSSDl+Apbot/lg1Fmpmnc65fjN7GfA24IJq
1CEiUkmjmQLpXLGqh3JClYIfuMXMlgJ54K+cc0NVqkNEpGL2DY8DVO0i65OqEvzOuT+txnJFRKrp6b4xANZ1NlW1Dp25KyJSIdt7RqmLRVjb0VjVOhT8IiIVsqN3hNOWtxCLVjd6FfwiIhVQKjl29IxyxoqWapei4BcRqYS9Q+OMZQtsWtla7VIU/CIilbCjdxSATSsU/CIiobC9Z4RYxFjfXd0jekDBLyJSEdt7Rzm1q5m6WPUGZ5uk4BcR8Zlzjh09I2wKwI5dUPCLiPiubzTLoXQuEEf0gIJfRMR323tGAAJxRA8o+EVEfLe9dwQzOG25WvwiIqGwo3eUNcsaaayr1riYR1Lwi4j4rLxjNxjdPKDgFxHx1WA6R+/IBJtWBqObBxT8IiK+2tFb3rF7hlr8IiLhsL2nPFRDUA7lBAW/iIivtveOsKq9nraGRLVLmaLgFxHx0bZ9I4Fq7YOCX0TEN88fSvPC4Djnn7K02qUcQcEvIuKTO3cOAHDphs4qV3IkBb+IiE9+u7OfNcsaWb2sutfYPZqCX0TEB5lckfv2HOKSDR3VLuVFgnH+cAjsH8nwk8f28/Pt+4lFI2xa0coZK1rY0N1MLGoUio5CyQFQF4uUb/EopZIjWygykS8xkS8yli0wNlFgNJMHYGVbPSva6lnRlqQ5Ga/mR/RdseTIF0tkCyUOpbIMjGUZSGUZzxZJJqLUx8u35mSM9oYErQ1xmutiRCI29R7OOVLZAodSOQ6ls6SyRSby5Vu2UCIRjZCMR6iLRYlEjEKxRL7oKJRKGEYsasSjRjwaIektL+ktc0ljgvi0i2g75xjLFugfzdI3OsGBkQkOjE4wkS+ypDHB0qY6ljUmaG9M0NYQp70hQTJe/bHaZWHc/8whsoVS4Lp5YJEHv3OOfLEcFrlCiXyxRCZfZDxXZDxXYCSTZ1dfip19Y+w8MMbYRIGTlzZw8tIGVi9tZGlTgoZEjIZElIZElLpYdCoUMvkie/pT7BlIsWcgzXiucMSy49EIiWiEeDTCs4fSPPjcIM7BppUtOODGB55nIl9a0M+biEVorY/TkozRUh+fCqX6eJS6WIRErFxPPBqhUCqRzhbJ5Atk8yWWNCbobk3S2ZKkq7mOrpYkXS1JljUliEVf2obh2ESepw6MsasvxcFUlkOp8pC0mVyxvPxYhHjUyBcd6WyBdLbARL5IXSxKQ12UxkQMh2NgLMvBVI6BsSzpXAHn5ve9xKNGxIxYxMiXHLnCwn7v07XWx1nalCBXKDEwliU7w7IiBqVZPksiFqElGaOxLkZjIkaz97tsScZprS/f2hrKt9b6OEsaE+WVSGMd9QmtNILktzv7qY9HOW/NkmqX8iKLOvj/24+2890/vHDc+bpbkqzvbuaUjiZeOJTmx4/2MjpROO7rJq1oTdJSf7i17RzkS+UVTb7gaGuI87HL1vOms5ZzSkf5smvFkuPZgyl29aUAiEbKrUmAXKHERL5EtlAkYkZdPErS2wJoTsZoScZoTsZxDnqGM/QOZ+gZzjA0nmM0k2c0U2B0Ik8mV5z6mfVWfPliOfjiUZtaqcWjEXb0jjKQylI8KpEiBi31cZqTMZrr4jQlY+WViLcCiUQot4i9935+MM3ewcwR79GSjLG0qY6GRJRC8XCrPRGL0FgXpSERo60hQbZQZDCdY+/gOAAdzXVsWtnKsqYEzXUxYtEIsaiRiEZY1lRHR3P51pCITm0RjeeKjE3kGR7PT30fhZKj6N2iUWOpF5RLmhK0JGPUxaLUJ6IkopGp2rKFEsVSiVgk4rXyI+Xfa7FEYXLLI19uSGTyRUYzeQbTOW9FlyMRi5Tr8+rsaknS3ZqkuyVJXSzCcCZf3mpJZRkZzzM0nmc4k2NkPM+YtzJMZwuMThTYOzjOaCbPSCZPOlec9e8wEYvQXOetNOpiLG1M0NlSR3dLkuVt9ZzW3czpK1poSCzqf/tAcM7xm6f6uWjd0kBuxS3qv4BXnd7Firb6qU3z
WDRCQzxKY12U+kSMproYazsaX3RihXOO4fE8w5k847mCt4VQJJsvMlEokc0XScQirO1omveIe9GIsa6zmXWdzSf0Gbtbk2w+uf2E3mNSseQ4lMpyYHSCPq97om90guHxPKlsgbGJPGMTBVLZwtRKregcscjk92u8fGUb79xyEqctL3djdTYnScS0K+loky31U7te2u8/Xywxkimv2EYyOQbTeQbT5S2qEe/3lM6Wf0cHUzme2ZOifyw71Y0YMTi1s5kzVrawsbuZDd0tnNbdTEdzHWZ2nKXLXO0ZSLNvKMMHL15b7VJmtKiD/9INnfPqXzMz2r2+1zCJRozOlnJ3jwRT3NvaWdZUN+fXlEqOvrEJdvSM8njPCNv2DXP3roP88OGeqXmWNdXx8lWtbFrZypkrW1nf1cSq9gaiEa0M5uPOnf0AgdyxC4s8+EUEIhFjeWs9y1vredXpXVPTB9M5njowylP7x9jRO8q2nmHu3Nk/tf8hEYtwyrJGNnY3c+HaZVy4bimr2huq9Clqy2939k+tPINIwS8SUksaE+VAX7tsatp4rsCT+0fZ059m90CKPf0p7tl9iB892gvAyUsbuGjdMv5k3TIuOGVp6LaK5yKVLfDAs4NcfdGaapcyKwW/iExpSMTYfPISNp98+EgU5xxP96W4d/dB7t19kB8/2suNf3gBM9i0onVqRbBldXsgd2RW2r27D5IvOi4J4GGckxT8InJMZsaG7mY2dDdz9Z+sIV8s8fi+Ye7ZdYh7dx/kG3c/w9d/t4e6WITz1izh4vUdXLy+g3WdTaHcYfyDB/fSkowt2EEXfjA334OjK2jLli1u69at1S5DRGaQ9ro27t51kLt2DbC7v3yI8sq2ei7d2MGrTuvigrVLqYst/q2Bu3cN8B+++QCfft3GQBzRY2YPOee2vGi6gl9EFtK+oXF+9/QAd+4c4J5dB8nkizQmoly8oYPLNy3nlRs7aQrIRccXUqFY4g1fvYfxfIFffeLiQKzoZgv+qnz7ZvZx4C8AB2wD3uecm6hGLSKysFa1N/CePz6Z9/zxyUzky+PV3PFkH3c80cfPth0gEYtw8foO3nDmcl5zRteiOaHs+1v3srNvjH99zzmBCP1jqXiL38xWAvcApzvnMmb2A+BnzrlvzfYatfhFal+p5HjohSF+tm0/P992gAOjEzQmorzuzOW87ZyVnL9m6RHjKtWS0Yk8l1x7J+s6m/j+NecHZt9GoFr83nLrzSwPNAC9VapDRCokEjHOXb2Ec1cv4W/ecDoPPjfIDx/u4afb9nPzQ/tY1V7PO7ecxJVbTqK7tbZOIvzab3YzNJ7js288PTChfyxV6eM3s48CXwAywO3OuffMMM81wDUAL3vZyzY///zzlS1SRCoikyty+xMH+P6De/n9nkNErHzW/TvPPYlLN3YeMeJpEG3vGeGKf7mXt569kmuvPKva5RwhMDt3zawduAV4JzAM3ATc7Jz7zmyvUVePSDg8fyjND7bu5aat++gfy9LRXMc7Nq/iz7acxJqAXcwE4LmDad7x9d+TiEb40YcuorM5WFsqQQr+K4HLnXPv9x7/R+B859xfzvYaBb9IuBSKJX67c4DvP/gCv3mqPIzEOS9r44pzVvGmly9/0cCK1dA/OsE7vn4fYxN5bvrghazrbKp2SS8SpD7+F4DzzayBclfPZYBSXUSmxKIRXn16F68+vYsDIxP86NEefvjwPv7mR9v53G07uGxjF2/fvIpLNnRUpStoJJPnqhse5GAqy40fOD+QoX8s1erj/zvKXT0F4BHgL5xz2dnmV4tfRJxz7Ogd5dZHevj3R3s4mMqxtDHBm89ewZvPWsHZJ7VVZMfq9p4R/uut23hy/yjXv/dc/vTUYI7ACQHq6pkPBb+ITJcvlrjr6QFufmgfv36yn1yxxIrWJK87czmvP7Obs09qX/AhpQfGsnzl9p18f+te2hsSfPFtZ/LaM7oXdBkLTcEvIovSSCbPr57o42fb
9nP3roPkiuVLiV6yvoNLN3Zy3poldM7zQjPOOR7fN8JPHu/l3x7YSyZf5L0XrubDl51Ka33wr3Gt4BeRRW8kk+d3Tw/wmyf7uPPpAYbH8wC0NcRZ39XM+q4m1nc1s66ziVM7m1nWlDhihVAqOV4YHOepA6M8sneYn287wAuD48SjxmUbu/jU5RtY21E7/fkKfhEJlWLJ8ejeIbb3jPLUgTGe7hvj6QNjjGUPX087GY9QH4+SjEepi0XoG82SyZevaxyNGBeuXcqbXr6C157RTWtD8Fv4RwvSUT0iIr6LRmzGawv0jWbZ1T/Grr4UvcMZsoUSE971tC9tSrCxu5mN3S2s72qmPhHsMXfmS8EvIqFhZnS3JuluTQb6aBy/BftcaBERWXAKfhGRkFHwi4iEjIJfRCRkFPwiIiGj4BcRCRkFv4hIyCj4RURCpiaGbDCzEWDXDE+1AiPHmHa8+5M/lwEH51HaTMufy/NHTz/WY9V9/LqO9/x86p5pWiXrnss0P+s+Xs1zrXG2Ome7P32aH3XP9W9kLrVOvx/Uv+2TnXMvPlPNORf4G3DdXKdPn3a8+9N+bl3Iul5q3cd6rLqrU/cs0ypW91ym+Vn38Wqeb91z/Rvxq+6wZclst1rp6rntJUy/7SXcn+195+p4r59r3cd6rLpnX95cn59P3bN9lvmYT91zmeZn3XN57XzqrpW/kaOn1Urdc6qjJrp6/GZmW90MI9gFnequLNVdWbVYd63UXCstfr9dV+0C5kl1V5bqrqxarLsmalaLX0QkZNTiFxEJGQW/iEjILLrgN7PrzazfzLbP47WbzWybme02s6/atItxmtmHzWynme0ws39Y2Kr9qdvM/tbMeszsUe/2+lqoe9rznzQzZ2bLFq7iqff24/v+vJk97n3Xt5vZihqo+Voze8qr+1Yza1vImn2s+0rvf7FkZgu6M/VE6p3l/a4ys13e7app04/59++r+RxzGuQb8ArgHGD7PF77AHABYMDPgdd50y8FfgXUeY87a6TuvwU+WWvft/fcScAvgeeBZbVQN9AybZ6PAF+vgZpfA8S8+18Gvlwj3/VpwAbgTmBLEOr1all91LQlwDPez3bvfvuxPlslbouuxe+cuwsYnD7NzNaa2S/M7CEzu9vMNh79OjNbTvkf9z5X/q38X+Ct3tP/CfiScy7rLaO/Rur2nY91/yPwnwFfjj7wo27n3Oi0WRsXunafar7dOTd59fH7gVULWbOPdT/pnNu50LWeSL2zeC1wh3Nu0Dk3BNwBXF7t/9tFF/yzuA74sHNuM/BJ4F9mmGclsG/a433eNID1wJ+a2R/M7Hdmdq6v1R52onUDfMjbjL/ezNr9K/UIJ1S3mb0Z6HHOPeZ3oUc54e/bzL5gZnuB9wCf9bHWSQvxNzLpasotz0pYyLorYS71zmQlsHfa48nPUNXPtugvtm5mTcCFwE3TutDqZpp1hmmTLbYY5c2084FzgR+Y2SnemtoXC1T3vwKf9x5/HvgK5X9u35xo3WbWAHyGchdExSzQ941z7jPAZ8zsr4EPAf99gUs9XMgC1ey912eAAvDdhaxxJgtZdyUcq14zex/wUW/aOuBnZpYDnnXOXcHsn6Gqn23RBz/lrZph59zZ0yeaWRR4yHv4Y8ohOX0zdxXQ693fB/zQC/oHzKxEeTCmgSDX7Zzrm/a6/wP8xMd6J51o3WuBNcBj3j/ZKuBhMzvPOXcgwHUf7Ubgp/gY/CxQzd4OxzcCl/nZmJlmob9rv81YL4Bz7gbgBgAzuxN4r3PuuWmz7AMumfZ4FeV9Afuo5mer1M6ESt6A1UzbMQP8HrjSu2/AWbO87kHKrfrJnS2v96Z/EPicd3895U03q4G6l0+b5+PAv9XC933UPM/hw85dn77vU6fN82Hg5hqo+XLgCaDDj+/Y778RfNi5O996mX3n7rOUewzavftL5vLZfP19VGpBFftA8D1gP5CnvFZ9P+UW5C+Ax7w/8s/O8totwHZg
D/A1Dp/ZnAC+4z33MPDKGqn7/wHbgMcpt6CW10LdR83zHP4c1ePH932LN/1xyoNlrayBmndTbsg86t0W9EgkH+u+wnuvLNAH/LLa9TJD8HvTr/a+593A+17K379fNw3ZICISMmE5qkdERDwKfhGRkFHwi4iEjIJfRCRkFPwiIiGj4JeaZGapCi/vG2Z2+gK9V9HKI3huN7Pbjjcippm1mdlfLsSyRUBX4JIaZWYp51zTAr5fzB0erMxX02s3s28DTzvnvnCM+VcDP3HObapEfbL4qcUvi4aZdZjZLWb2oHe7yJt+npn93swe8X5u8Ka/18xuMrPbgNvN7BIzu9PMbrbyGPXfnRwj3Zu+xbuf8gZje8zM7jezLm/6Wu/xg2b2uTluldzH4cHpmszs12b2sJXHaX+LN8+XgLXeVsK13ryf8pbzuJn93QJ+jRICCn5ZTP4J+Efn3LnA24FveNOfAl7hnPsjyiNm/v2011wAXOWce6X3+I+AjwGnA6cAF82wnEbgfufcWcBdwAemLf+fvOUfd9wVb2yayyifVQ0wAVzhnDuH8jUgvuKteD4N7HHOne2c+5SZvQY4FTgPOBvYbGavON7yRCaFYZA2CY9XAadPG0GxxcyagVbg22Z2KuUREOPTXnOHc2762OsPOOf2AZjZo5THbLnnqOXkODzg3UPAq737F3B4TPUbgf8xS5310977IcpjtEN5zJa/90K8RHlLoGuG17/Guz3iPW6ivCK4a5bliRxBwS+LSQS4wDmXmT7RzP4Z+K1z7gqvv/zOaU+nj3qP7LT7RWb+H8m7wzvHZpvnWDLOubPNrJXyCuSvgK9SHsO/A9jsnMub2XNAcobXG/BF59z/fonLFQHU1SOLy+2Ux8AHwMwmh9FtBXq8++/1cfn3U+5iAnjX8WZ2zo1QvkTjJ80sTrnOfi/0LwVO9mYdA5qnvfSXwNXeOPGY2Uoz61ygzyAhoOCXWtVgZvum3T5BOUS3eDs8n6A8nDbAPwBfNLN7gaiPNX0M+ISZPQAsB0aO9wLn3COUR3x8F+WLoGwxs62UW/9PefMcAu71Dv+81jl3O+WupPvMbBtwM0euGESOSYdziiwQ7+phGeecM7N3Ae92zr3leK8TqTT18YssnM3A15QytlgAAAAySURBVLwjcYbx+TKXIvOlFr+ISMioj19EJGQU/CIiIaPgFxEJGQW/iEjIKPhFRELm/wOgUtix3Q9b9gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn_lm.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2\n", "lr *= bs/48\n", "\n", "wd = 0.1" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | error_rate | accuracy | perplexity | time |
|---|---|---|---|---|---|---|
| 0 | 25.742983 | 12.599795 | 0.973891 | 0.026109 | 274962.812500 | 00:07 |
| 1 | 8.696360 | 7.004283 | 0.947224 | 0.052776 | 679.854126 | 00:07 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.fit_one_cycle(2, lr*10, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned1_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned1_enc_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | error_rate | accuracy | perplexity | time |
|---|---|---|---|---|---|---|
| 0 | 6.780813 | 6.621964 | 0.906871 | 0.093129 | 438.378723 | 00:10 |
| 1 | 6.005637 | 5.547102 | 0.754204 | 0.245796 | 127.223434 | 00:10 |
| 2 | 5.041389 | 4.688619 | 0.657497 | 0.342503 | 48.792183 | 00:10 |
| 3 | 4.581665 | 4.350106 | 0.612980 | 0.387020 | 32.847675 | 00:10 |
| 4 | 4.374249 | 4.186480 | 0.594299 | 0.405701 | 27.221252 | 00:10 |
| 5 | 4.251293 | 4.090651 | 0.578639 | 0.421361 | 24.088449 | 00:10 |
| 6 | 4.170560 | 4.033818 | 0.570340 | 0.429660 | 22.518152 | 00:10 |
| 7 | 4.095521 | 3.955774 | 0.559905 | 0.440095 | 20.811506 | 00:10 |
| 8 | 4.003619 | 3.898648 | 0.549537 | 0.450463 | 19.151363 | 00:10 |
| 9 | 3.920332 | 3.830895 | 0.541905 | 0.458095 | 18.010420 | 00:10 |
| 10 | 3.818533 | 3.758636 | 0.528082 | 0.471918 | 16.681326 | 00:10 |
| 11 | 3.721527 | 3.696797 | 0.518993 | 0.481007 | 15.535457 | 00:10 |
| 12 | 3.627698 | 3.636295 | 0.508095 | 0.491905 | 14.628825 | 00:10 |
| 13 | 3.522273 | 3.582196 | 0.497810 | 0.502190 | 13.698022 | 00:10 |
| 14 | 3.442510 | 3.540197 | 0.488354 | 0.511646 | 13.101725 | 00:10 |
| 15 | 3.366605 | 3.512025 | 0.482422 | 0.517578 | 12.719119 | 00:10 |
| 16 | 3.311971 | 3.501606 | 0.479374 | 0.520626 | 12.563033 | 00:10 |
| 17 | 3.283584 | 3.500383 | 0.478476 | 0.521524 | 12.537847 | 00:10 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_lm.unfreeze()\n", "learn_lm.fit_one_cycle(18, lr, wd=wd, moms=(0.8,0.7), callbacks=[ShowGraph(learn_lm)])" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned2_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned2_enc_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save best LM learner and its encoder" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "learn_lm.save(f'{lang}fine_tuned_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')\n", "learn_lm.save_encoder(f'{lang}fine_tuned_enc_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tuning \"forward Classifier\"" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "bs = 18" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Databunch" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 80 ms, sys: 32 ms, total: 112 ms\n", "Wall time: 678 ms\n" ] } ], "source": [ "%%time\n", "data_lm = load_data(path, f'{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.27 s, sys: 140 ms, total: 1.41 s\n", "Wall time: 2.51 s\n" ] } ], "source": [ "%%time\n", "data_clas = (TextList.from_df(df_trn_val, path, vocab=data_lm.vocab, cols=reviews, processor=SPProcessor.load(dest))\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_from_df(cols=label)\n", " .databunch(bs=bs, num_workers=1))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "CPU times: user 188 ms, sys: 28 ms, total: 216 ms\n", "Wall time: 240 ms\n" ] } ], "source": [ "%%time\n", "data_clas.save(f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get weights to penalize loss function of the majority class" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 384 ms, sys: 0 ns, total: 384 ms\n", "Wall time: 384 ms\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs, num_workers=1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9237, 1026, 10263)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_trn = len(data_clas.train_ds.x)\n", "num_val = len(data_clas.valid_ds.x)\n", "num_trn, num_val, num_trn+num_val" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([3136, 2439, 2084, 1578]), array([332, 284, 213, 197]))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_LabelCounts = np.unique(data_clas.train_ds.y.items, return_counts=True)[1]\n", "val_LabelCounts = np.unique(data_clas.valid_ds.y.items, return_counts=True)[1]\n", "trn_LabelCounts, val_LabelCounts" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([0.6604958319800801,\n", " 0.7359532315686912,\n", " 0.7743856230377828,\n", " 0.8291653134134459],\n", " [0.6764132553606238,\n", " 0.723196881091618,\n", " 0.7923976608187134,\n", " 0.8079922027290448])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_weights = [1 - count/num_trn for count in 
trn_LabelCounts]\n", "val_weights = [1 - count/num_val for count in val_LabelCounts]\n", "trn_weights, val_weights" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training (Loss = FlattenedLoss of weighted LabelSmoothingCrossEntropy)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 344 ms, sys: 24 ms, total: 368 ms\n", "Wall time: 364 ms\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs, num_workers=1)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_clas_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "learn_c = text_classifier_learner(data_clas, AWD_LSTM, config=config, pretrained=False, drop_mult=0.3, \n", " metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load_encoder(f'{lang}fine_tuned_enc_tcu_jurisp_reduzido_sp15_multifit_v2');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Change loss function" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of CrossEntropyLoss()" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_c.loss_func" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "loss_weights = torch.FloatTensor(trn_weights).cuda()\n", "learn_c.loss_func = FlattenedLoss(WeightedLabelSmoothingCrossEntropy, weight=loss_weights)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of WeightedLabelSmoothingCrossEntropy()" ] }, "execution_count": 51, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "learn_c.loss_func" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Training" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "learn_c.freeze()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn_c.lr_find()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn_c.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "lr = 2e-1\n", "lr *= bs/48\n", "\n", "wd = 0.1" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.749393 | 0.631815 | 0.868421 | 0.869856 | 00:08
1 | 0.628098 | 0.542874 | 0.916179 | 0.915374 | 00:09
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.716813 | 0.927593 | 0.720273 | 0.683785 | 00:09
1 | 0.643462 | 0.551566 | 0.905458 | 0.907423 | 00:08
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.672351 | 0.656076 | 0.863548 | 0.857651 | 00:09
1 | 0.534614 | 0.461940 | 0.951267 | 0.950608 | 00:09
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-2)\n", "learn_c.fit_one_cycle(2, slice(lr/(2.6**4),lr), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.623553 | 0.508049 | 0.924951 | 0.924046 | 00:12
1 | 0.489635 | 0.447499 | 0.958090 | 0.958690 | 00:12
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-3)\n", "learn_c.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.467003 | 0.448250 | 0.962963 | 0.963586 | 00:23
1 | 0.471006 | 0.450991 | 0.963938 | 0.962333 | 00:21
2 | 0.437463 | 0.431128 | 0.965887 | 0.964934 | 00:20
3 | 0.409270 | 0.429323 | 0.969786 | 0.968492 | 00:21
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.unfreeze()\n", "learn_c.fit_one_cycle(4, slice(lr/10/(2.6**4),lr/10), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.412864 | 0.429698 | 0.964912 | 0.963988 | 00:20
1 | 0.402781 | 0.427383 | 0.966862 | 0.965871 | 00:22
2 | 0.401838 | 0.428162 | 0.968811 | 0.967614 | 00:22
3 | 0.399231 | 0.426054 | 0.968811 | 0.967874 | 00:20
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')\n", "learn_c.fit_one_cycle(4, slice(lr/100/(2.6**4),lr/100), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.405359 | 0.430013 | 0.971735 | 0.970652 | 00:20
1 | 0.389910 | 0.425390 | 0.970760 | 0.969534 | 00:20
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')\n", "learn_c.fit_one_cycle(2, slice(lr/1000/(2.6**4),lr/1000), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Confusion matrix" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.26 s, sys: 236 ms, total: 1.5 s\n", "Wall time: 2.9 s\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs, num_workers=1);\n", "\n", "config = awd_lstm_clas_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3\n", "\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, config=config)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2', purge=False);" ] }, { "cell_type": "code", "execution_count": 73, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "preds,y,losses = learn_c.get_preds(with_loss=True)\n", "predictions = np.argmax(preds, axis = 1)\n", "\n", "interp = ClassificationInterpretation(learn_c, preds, y, losses)\n", "interp.plot_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[327 3 0 2]\n", " [ 4 279 0 1]\n", " [ 0 3 206 4]\n", " [ 0 5 8 184]]\n", "accuracy global: 0.9707602339181286\n", "accuracy on class 0: 98.49397590361446\n", "accuracy on class 1: 98.23943661971832\n", "accuracy on class 2: 96.71361502347418\n", "accuracy on class 3: 93.4010152284264\n" ] } ], "source": [ "from sklearn.metrics import confusion_matrix\n", "cm = confusion_matrix(np.array(y), np.array(predictions))\n", "print(cm)\n", "\n", "## acc\n", "print(f'accuracy global: {(cm[0,0]+cm[1,1]+cm[2,2]+cm[3,3])/(cm.sum())}')\n", "\n", "# acc neg, acc pos\n", "print(f'accuracy on class 0: {cm[0,0]/(cm.sum(1)[0])*100}') \n", "print(f'accuracy on class 1: {cm[1,1]/(cm.sum(1)[1])*100}')\n", "print(f'accuracy on class 2: {cm[2,2]/(cm.sum(1)[2])*100}')\n", "print(f'accuracy on class 3: {cm[3,3]/(cm.sum(1)[3])*100}')" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
text | target | prediction
▁xxbos ▁a ▁assistência ▁judicial ▁por ▁parte ▁do ▁xxmaj ▁poder ▁xxmaj ▁público ▁aos ▁membros ▁do ▁xxmaj ▁ministério ▁xxmaj ▁público ▁da ▁xxmaj ▁união , ▁em ▁ações ▁propostas ▁por ▁pessoas ▁física s ▁ou ▁jurídica s ▁por ▁eles ▁investiga das , ▁deverá ▁pa u tar - se ▁pelos ▁seguintes ▁critérios : < ▁/ ▁p >< p > a ) ▁nas ▁hipótese s ▁em ▁que ▁as ▁demanda s ▁judiciais ▁mo vidas ▁contra ▁os ▁membros ▁do | 1 | 1
▁xxbos ▁xxmaj ▁nos ▁processos ▁li cita tório s ▁sob ▁a ▁modalidade ▁pre gão ▁que ▁se ▁de stin em ▁ao ▁fornecimento ▁de ▁bens ▁e ▁serviços ▁comuns ▁de ▁informática ▁e ▁auto ma ção , ▁verifica do ▁empate ▁entre ▁propostas ▁comerciais , ▁a ▁xxmaj ▁administração ▁xxmaj ▁pública ▁xxmaj ▁federal ▁deverá ▁adotar ▁os ▁seguintes ▁procedimento s : ▁i . ▁analisar , ▁primeiro , ▁se ▁algum ▁dos ▁li cita ntes ▁está ▁oferta ndo ▁bem ▁ou | 1 | 1
▁xxbos ▁xxmaj ▁os ▁requisitos ▁a ▁serem ▁pre en chi dos , ▁no ▁momento ▁do ▁ ó bit o ▁do ▁institui dor , ▁para ▁a ▁habilita ção ▁e ▁manutenção ▁da ▁qualidade ▁de ▁filha ▁maior ▁sol t eira , ▁como ▁depende nte ▁de ▁pen são ▁são : ▁a ) ▁ser ▁sol t eira , ▁viúva ▁ou ▁de s qui tada , ▁independente mente ▁da ▁idade ▁( pode ▁ser ▁maior ▁ou ▁menor ▁de | 0 | 0
▁xxbos ▁é ▁permitida ▁a ▁utilização ▁do ▁chama mento ▁público ▁para ▁per m uta ▁de ▁imóveis ▁da ▁xxmaj ▁união ▁como ▁mecanismo ▁de ▁pro spec ção ▁de ▁mercado , ▁para ▁fim ▁de ▁identificar ▁os ▁imóveis ▁ele g íveis ▁que ▁a ten da m ▁às ▁necessidades ▁da ▁xxmaj ▁união , ▁com ▁atendimento ▁aos ▁princípios ▁da ▁im pessoa lidade , ▁moral idade ▁e ▁publicidade , ▁deve ndo , ▁posteriormente , ▁ser ▁utilizadas ▁várias ▁fontes | 1 | 1
▁xxbos ▁xxmaj ▁como ▁alternativa ▁à ▁re ten ção ▁ cau te lar ▁de ▁pagamento s , ▁a ▁xxmaj ▁administração ▁pode ▁pro pi cia r ▁ao ▁contratado ▁a ▁oportunidade ▁de ▁oferecer ▁nova ▁ fia nça ▁bancária ▁ou ▁outra ▁garantia ▁de ▁alta ▁ liquid ez ▁dentre ▁aquela s ▁prevista s ▁no ▁art . ▁ 56 , ▁ § ▁1 o , ▁da ▁xxmaj ▁lei ▁8 . 6 66 ▁/ ▁1993, ▁de ▁a | 3 | 3
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.show_results()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predictions some random sentences" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Category 1, tensor(1), tensor([0.0193, 0.7045, 0.0242, 0.2520]))\n" ] } ], "source": [ "# Get the prediction\n", "test_text = \"A medida cautelar do TCU que determina a suspensão de licitação por falhas no edital não impede o órgão ou a entidade de rever seu ato convocatório, valendo-se do poder de autotutela (art. 49 da Lei 8.666/1993 c/c o art. 9º da Lei 10.520/2002) , com o objetivo de, antecipando-se a eventual deliberação do Tribunal, promover de modo próprio a anulação da licitação e o refazimento do edital, livre dos vícios apontados.\"\n", "pred = learn_c.predict(test_text)\n", "print(pred)" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/html": [ "▁xxbos ▁a ▁medida cau te lar ▁do ▁xxup t cu ▁que ▁determina ▁a ▁suspensão ▁de ▁li cita ção ▁por ▁falha s ▁no ▁e di tal ▁não ▁impede ▁o ▁órgão ▁ou ▁a ▁entidade ▁de ▁rever ▁seu ▁ato ▁convoca tório , ▁vale ndo - se ▁do ▁poder ▁de ▁auto tu te la ▁( art . 49 ▁da ▁xxmaj ▁lei ▁8 . 6 66 ▁/ ▁1993 ▁c ▁/ ▁c ▁o ▁art . ▁9 o ▁da ▁xxmaj ▁lei ▁10 . 5 20 ▁/ ▁2002 ) , ▁com ▁o ▁objetivo ▁de , ▁antecipa ndo - se ▁a ▁eventual ▁de libera ção ▁do ▁xxmaj ▁tribunal , ▁promover ▁de ▁modo ▁próprio ▁a ▁an ulação ▁da ▁li cita ção ▁e ▁o ▁re fa zi mento ▁do ▁e di tal , ▁livre ▁dos ▁v ício s ▁apontado s ." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The darker the word-shading in the below example, the more it contributes to the classification. 
\n", "txt_ci = TextClassificationInterpretation.from_learner(learn_c)\n", "txt_ci.show_intrinsic_attention(test_text,cmap=plt.cm.Purples)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.2395, 0.6145, 0.8318, 0.5050, 0.9821, 1.0000, 0.5648, 0.1259, 0.0754,\n", " 0.1899, 0.2945, 0.1241, 0.0776, 0.1660, 0.1436, 0.2270, 0.1089, 0.1496,\n", " 0.0836, 0.0363, 0.0872, 0.2356, 0.0342, 0.0443, 0.0896, 0.1303, 0.1193,\n", " 0.0666, 0.1301, 0.0322, 0.0787, 0.1167, 0.1554, 0.2903, 0.0753, 0.1617,\n", " 0.1949, 0.4512, 0.6584, 0.0610, 0.0571, 0.1333, 0.0393, 0.0493, 0.0936,\n", " 0.1043, 0.2322, 0.5054, 0.4923, 0.1743, 0.1532, 0.0799, 0.0238, 0.0178,\n", " 0.0152, 0.0184, 0.0459, 0.0170, 0.0167, 0.0335, 0.0389, 0.0949, 0.2626,\n", " 0.2967, 0.0161, 0.0526, 0.0600, 0.0407, 0.0855, 0.0151, 0.0091, 0.0165,\n", " 0.0194, 0.0110, 0.0180, 0.0304, 0.0254, 0.0369, 0.2418, 0.2850, 0.3064,\n", " 0.0873, 0.1301, 0.0248, 0.0222, 0.0110, 0.0131, 0.0415, 0.0622, 0.0540,\n", " 0.0430, 0.1060, 0.0663, 0.0510, 0.0401, 0.0531, 0.0810, 0.1038, 0.2476,\n", " 0.0991, 0.0421, 0.1390, 0.2452, 0.0683, 0.1305, 0.0378, 0.0639, 0.0499,\n", " 0.0519, 0.1165, 0.1615, 0.0988, 0.1658, 0.0335, 0.0099, 0.0145, 0.0193,\n", " 0.0740, 0.2143, 0.2131, 0.1213, 0.0701, 0.0779, 0.0260, 0.0257, 0.0394,\n", " 0.0801, 0.0763, 0.1167, 0.0558, 0.0679, 0.1402, 0.2245, 0.2207],\n", " device='cuda:0')" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "txt_ci.intrinsic_attention(test_text)[1]" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", "
Text | Prediction | Actual | Loss | Probability
▁xxbos ▁a ▁adoção , ▁por ▁conselho ▁de ▁fiscalização ▁profissional , ▁da ▁modalidade ▁de ▁li cita ção ▁convite ▁para ▁a ▁contratação ▁de ▁serviços ▁ad voca t ício s ▁que ▁possam ▁ser ▁considerados ▁como ▁objeto ▁comum ▁ inf ring e ▁o ▁disposto ▁no ▁art . ▁4 o ▁do ▁xxmaj ▁decreto ▁5 . 4 50 ▁/ ▁2005, ▁que ▁determina ▁a ▁utilização ▁do ▁pre gão , ▁prefere ncial mente ▁na ▁forma ▁eletrônica . | 2 | 1 | 4.31 | 0.03
▁xxbos ▁a ▁concessão ▁de ▁pen são ▁da ▁xxmaj ▁lei ▁3. 37 3 ▁/ ▁1958 ▁a ▁filho ▁maior ▁in vá lido ▁requer ▁ lau do ▁peri cial ▁emitido ▁por ▁junta ▁médica ▁oficial ▁que ▁a tes te ▁a ▁in vali dez ▁e ▁sua ▁pre ex ist ência ▁no ▁momento ▁do ▁ ó bit o ▁do ▁institui dor . | 1 | 0 | 3.58 | 0.02
▁xxbos ▁xxmaj ▁falecido ▁o ▁responsável , ▁a ▁obriga ção ▁de ▁re para r ▁o ▁da no ▁re ca i ▁sobre ▁o ▁seu ▁ esp ólio ▁ou , ▁caso ▁concluída ▁a ▁partilha , ▁aos ▁sucesso res ▁até ▁o ▁limite ▁do ▁valor ▁do ▁patrimônio ▁transferido . ▁xxmaj ▁ ante ▁o ▁seu ▁caráter ▁ persona l íssimo , ▁a ▁multa ▁não ▁se ▁trans fer e ▁aos ▁sucesso res . | 2 | 2 | 3.04 | 0.91
▁xxbos ▁a ▁xxmaj ▁administração ▁deve ▁a ten tar ▁para ▁os ▁per cent uais ▁aplicado s ▁de ▁xxup ▁b di ▁sobre ▁serviços , ▁materiais ▁e ▁equipamentos , ▁de ▁forma ▁a ▁ corri gir ▁eventual ▁disto r ção ▁com para tivamente ▁aos ▁preços ▁de ▁mercado , ▁avalia ndo ▁quanto ▁a ▁estes ▁dois ▁últimos ▁( mate ria is ▁e ▁equipamentos ) ▁a ▁possibilidade ▁de ▁ela ▁própria ▁rea vali ar ▁as ▁compra s ▁de | 2 | 1 | 3.03 | 0.01
▁xxbos ▁é ▁possível ▁a ▁ respons a bil ização ▁de ▁agentes ▁políticos ▁nas ▁hipótese s ▁de ▁( i ) ▁prática ▁de ▁ato ▁administrativo ▁de ▁gestão ▁ou ▁outro ▁ato , ▁o miss ivo ▁ou ▁com is sivo , ▁que ▁esta bel eça ▁corre lação ▁com ▁as ▁irregular idades ▁a pura das ; ▁( ii ) ▁conduta ▁rei ter ada ▁de ▁da no ▁ao ▁e r ário ▁em ▁decorrência ▁da ▁execução ▁deficiente | 2 | 2 | 2.97 | 0.90
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# tabulation showing the first k texts in top_losses along with their prediction, actual,loss, and probability of actual class.\n", "# max_len is the maximum number of tokens displayed. If max_len=None, it will display all tokens.\n", "txt_ci.show_top_losses(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tuning \"backward Classifier\"" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore') # \"error\", \"ignore\", \"always\", \"default\", \"module\" or \"on" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "bs = 18" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Databunch" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 100 ms, sys: 20 ms, total: 120 ms\n", "Wall time: 501 ms\n" ] } ], "source": [ "%%time\n", "data_lm = load_data(path, f'{lang}_databunch_lm_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, backwards=True)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.25 s, sys: 420 ms, total: 1.67 s\n", "Wall time: 2.46 s\n" ] } ], "source": [ "%%time\n", "data_clas = (TextList.from_df(df_trn_val, path, cols=reviews, processor=SPProcessor.load(dest), vocab=data_lm.vocab)\n", " .split_by_rand_pct(0.1, seed=42)\n", " .label_from_df(cols=label)\n", " .databunch(bs=bs, num_workers=1, backwards=True))" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 176 ms, sys: 16 ms, total: 192 ms\n", "Wall time: 193 ms\n" ] } ], "source": [ "%%time\n", 
"data_clas.save(f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get weights to penalize loss function of the majority class" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 420 ms, sys: 4 ms, total: 424 ms\n", "Wall time: 398 ms\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, num_workers=1, backwards=True)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(9237, 1026, 10263)" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_trn = len(data_clas.train_ds.x)\n", "num_val = len(data_clas.valid_ds.x)\n", "num_trn, num_val, num_trn+num_val" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([3136, 2439, 2084, 1578]), array([332, 284, 213, 197]))" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_LabelCounts = np.unique(data_clas.train_ds.y.items, return_counts=True)[1]\n", "val_LabelCounts = np.unique(data_clas.valid_ds.y.items, return_counts=True)[1]\n", "trn_LabelCounts, val_LabelCounts" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([0.6604958319800801,\n", " 0.7359532315686912,\n", " 0.7743856230377828,\n", " 0.8291653134134459],\n", " [0.6764132553606238,\n", " 0.723196881091618,\n", " 0.7923976608187134,\n", " 0.8079922027290448])" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_weights = [1 - count/num_trn for count in trn_LabelCounts]\n", "val_weights = [1 - count/num_val for count in val_LabelCounts]\n", "trn_weights, val_weights" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Training (Loss = FlattenedLoss of weighted LabelSmoothingCrossEntropy)" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 404 ms, sys: 0 ns, total: 404 ms\n", "Wall time: 378 ms\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, num_workers=1, backwards=True)" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_clas_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "learn_c = text_classifier_learner(data_clas, AWD_LSTM, config=config, drop_mult=0.3, metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load_encoder(f'{lang}fine_tuned_enc_tcu_jurisp_reduzido_sp15_multifit_bwd_v2');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Change loss function" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of CrossEntropyLoss()" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_c.loss_func" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "loss_weights = torch.FloatTensor(trn_weights).cuda()\n", "learn_c.loss_func = FlattenedLoss(WeightedLabelSmoothingCrossEntropy, weight=loss_weights)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FlattenedLoss of WeightedLabelSmoothingCrossEntropy()" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn_c.loss_func" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Training" ] }, { 
"cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [], "source": [ "learn_c.freeze()" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn_c.lr_find()" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn_c.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [], "source": [ "lr = 2e-1\n", "lr *= bs/48\n", "\n", "wd = 0.1" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.766112 | 0.790079 | 0.781676 | 0.780847 | 00:09
1 | 0.633976 | 0.567282 | 0.898636 | 0.898657 | 00:08
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 101, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.757286 | 0.706500 | 0.842105 | 0.833556 | 00:08
1 | 0.656997 | 0.558589 | 0.896686 | 0.897445 | 00:09
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(2, lr, wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.672892 | 0.761602 | 0.809942 | 0.817644 | 00:09
1 | 0.529195 | 0.470551 | 0.949318 | 0.950640 | 00:10
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-2)\n", "learn_c.fit_one_cycle(2, slice(lr/(2.6**4),lr), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 105, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.635578 | 0.555220 | 0.906433 | 0.906872 | 00:12
1 | 0.547100 | 0.463824 | 0.946394 | 0.947524 | 00:12
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.freeze_to(-3)\n", "learn_c.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.512873 | 0.447683 | 0.955166 | 0.955316 | 00:22
1 | 0.503354 | 0.461661 | 0.954191 | 0.953338 | 00:20
2 | 0.464976 | 0.430627 | 0.965887 | 0.965267 | 00:21
3 | 0.430950 | 0.434405 | 0.965887 | 0.965663 | 00:21
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.unfreeze()\n", "learn_c.fit_one_cycle(4, slice(lr/10/(2.6**4),lr/10), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.441612 | 0.432574 | 0.965887 | 0.965280 | 00:21
1 | 0.422067 | 0.429031 | 0.967836 | 0.967265 | 00:20
2 | 0.430873 | 0.432398 | 0.970760 | 0.970493 | 00:22
3 | 0.427765 | 0.433643 | 0.966862 | 0.966608 | 00:22
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(4, slice(lr/100/(2.6**4),lr/100), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 110, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.424508 | 0.429726 | 0.968811 | 0.968703 | 00:21
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')\n", "learn_c.fit_one_cycle(1, slice(lr/1000/(2.6**4),lr/1000), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | accuracy | f1 | time
0 | 0.428017 | 0.431108 | 0.970760 | 0.970493 | 00:20
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.fit_one_cycle(1, slice(lr/1000/(2.6**4),lr/1000), wd=wd, moms=(0.8,0.7))" ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [], "source": [ "learn_c.save(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2');\n", "learn_c.to_fp32().export(f'{lang}_classifier_tcu_jurisp_reduzido_sp15_multifit_bwd_v2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Confusion matrix" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 1.83 s, sys: 176 ms, total: 2.01 s\n", "Wall time: 907 ms\n" ] } ], "source": [ "%%time\n", "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, num_workers=1, backwards=True)\n", "\n", "config = awd_lstm_clas_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3\n", "\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, config=config)" ] }, { "cell_type": "code", "execution_count": 120, "metadata": {}, "outputs": [], "source": [ "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', purge=False);" ] }, { "cell_type": "code", "execution_count": 121, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "preds,y,losses = learn_c.get_preds(with_loss=True)\n", "predictions = np.argmax(preds, axis = 1)\n", "\n", "interp = ClassificationInterpretation(learn_c, preds, y, losses)\n", "interp.plot_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[329 2 0 1]\n", " [ 1 278 5 0]\n", " [ 1 2 206 4]\n", " [ 1 1 12 183]]\n", "accuracy global: 0.9707602339181286\n", "accuracy on class 0: 99.09638554216868\n", "accuracy on class 1: 97.88732394366197\n", "accuracy on class 2: 96.71361502347418\n", "accuracy on class 3: 92.89340101522842\n" ] } ], "source": [ "from sklearn.metrics import confusion_matrix\n", "cm = confusion_matrix(np.array(y), np.array(predictions))\n", "print(cm)\n", "\n", "## acc\n", "print(f'accuracy global: {(cm[0,0]+cm[1,1]+cm[2,2]+cm[3,3])/(cm.sum())}')\n", "\n", "# acc neg, acc pos\n", "print(f'accuracy on class 0: {cm[0,0]/(cm.sum(1)[0])*100}') \n", "print(f'accuracy on class 1: {cm[1,1]/(cm.sum(1)[1])*100}')\n", "print(f'accuracy on class 2: {cm[2,2]/(cm.sum(1)[2])*100}')\n", "print(f'accuracy on class 3: {cm[3,3]/(cm.sum(1)[3])*100}')" ] }, { "cell_type": "code", "execution_count": 123, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
text | target | prediction
. ▁ ) ▁1993 ▁/ 66 6 . ▁8 ▁lei ▁xxmaj ▁da o ▁3 . art ▁( ” ▁administração ▁xxmaj ▁a ▁para a jos ta ▁van ▁mais ▁proposta ▁a r ciona le se ▁“ ▁a ▁modo ▁de , s ▁interessado uais vent ▁e ▁os ▁entre ” mia no ▁iso ▁da ▁constitucional ▁princípio ▁do ância v er ▁obs ▁a r anti gar ▁“ ▁necessariamente ▁deverá ▁administração ▁xxmaj ▁a , ▁questão ▁em | 1 | 1
▁2002. ▁/ 20 5 . ▁10 ▁lei ▁xxmaj ▁da o ▁9 . ▁art ▁do ▁força ▁por gão ▁pre ▁xxmaj ▁ao mente aria idi s ▁sub ável c pli ▁a ▁1993, ▁/ 66 6 . ▁8 ▁lei ▁xxmaj ▁da , o ▁2 § ▁ , ▁45 . ▁art ▁no ▁disposto ▁o ▁observado , ▁público ▁interesse ▁ao á ▁atender ▁que ▁oferta ▁da ▁sorteio ▁ao r ▁procede ▁1991, ▁/ 48 2 . ▁8 | 1 | 1
. são ▁pen ▁de ▁benefício ▁do ▁econômica ▁dependência ▁a ize ter ac scar ▁de ▁que ▁renda r feri u ▁a ) ▁c ▁ou ; ▁ocupação ▁dessa corrente ▁de ▁aposentadoria ▁receber ▁ou , ▁indireta ▁xxmaj ▁ou ▁direta ▁xxmaj ▁pública ▁xxmaj ▁administração ▁xxmaj ▁na ▁efetivo ▁cargo ▁ocupar ) ▁b ; ▁estável ▁união ▁de ▁situação ▁na ▁encontrar ▁se ▁ou ▁casamento ído ▁contra ▁ter ) ▁a : ▁benefício ▁do ▁percepção ▁à ▁direito ▁do tivas | 0 | 0
. ▁escolhida ▁opção ▁a ▁para ção ▁motiva ▁adequada ▁a ▁observar ndo ▁deve ▁1998, ▁/ 36 6 . ▁9 ▁lei ▁xxmaj ▁da , o ▁2 § ▁ , ▁30 . ▁art ▁do ▁e ▁1993 ▁/ 66 6 . ▁8 ▁lei ▁xxmaj ▁da , ▁i ▁xxup so inci ▁ , ▁17 . ▁art ▁do ▁termos ▁nos , tório cita ▁li ▁procedimento ▁o ▁realizar ▁ou ▁1993, ▁/ 66 6 . ▁8 ▁lei ▁xxmaj | 1 | 1
. ▁vigor ▁em mente ida ▁val ▁estiver ▁alternativa ▁medida ▁a ▁enquanto ▁eficácia ▁sua e ▁suspend ▁mas , s ▁pagamento ▁de ção ten ▁re ▁de lar te cau ▁ ▁a ▁revoga ▁não ▁referência ▁em ▁alternativa ▁medida ▁da ▁adoção ▁a , ▁casos s ▁nesse ▁xxmaj . ▁valores ir titu res ▁ ▁a ▁contratada ▁empresa ▁a ne ▁conde ▁que ▁tribunal ▁xxmaj ▁deste dão r có ▁a ▁eventual ▁de ▁julgado ▁em ▁trânsito ▁o ▁após | 3 | 3
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn_c.show_results()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predictions some random sentences" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Category 3, tensor(3), tensor([0.0095, 0.1228, 0.0891, 0.7785]))\n" ] } ], "source": [ "# Get the prediction\n", "test_text = \"A medida cautelar do TCU que determina a suspensão de licitação por falhas no edital não impede o órgão ou a entidade de rever seu ato convocatório, valendo-se do poder de autotutela (art. 49 da Lei 8.666/1993 c/c o art. 9º da Lei 10.520/2002) , com o objetivo de, antecipando-se a eventual deliberação do Tribunal, promover de modo próprio a anulação da licitação e o refazimento do edital, livre dos vícios apontados.\"\n", "pred = learn_c.predict(test_text)\n", "print(pred)" ] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [ { "data": { "text/html": [ ". s ▁apontado s ício ▁v ▁dos ▁livre , tal di ▁e ▁do mento zi fa ▁re ▁o ▁e ção cita ▁li ▁da ulação ▁an ▁a ▁próprio ▁modo ▁de ▁promover , ▁tribunal ▁xxmaj ▁do ção libera ▁de ▁eventual ▁a se - ndo ▁antecipa , ▁de ▁objetivo ▁o ▁com , ) ▁2002 ▁/ 20 5 . ▁10 ▁lei ▁xxmaj ▁da o ▁9 . ▁art ▁o ▁c ▁/ ▁c ▁1993 ▁/ 66 6 . ▁8 ▁lei ▁xxmaj ▁da 49 . art ▁( la te tu ▁auto ▁de ▁poder ▁do se - ndo ▁vale , tório ▁convoca ▁ato ▁seu ▁rever ▁de ▁entidade ▁a ▁ou ▁órgão ▁o ▁impede ▁não tal di ▁e ▁no s ▁falha ▁por ção cita ▁li ▁de ▁suspensão ▁a ▁determina ▁que cu t ▁xxup ▁do lar te cau ▁medida ▁a ▁xxbos" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# The darker the word-shading in the below example, the more it contributes to the classification. 
\n", "txt_ci = TextClassificationInterpretation.from_learner(learn_c)\n", "txt_ci.show_intrinsic_attention(test_text,cmap=plt.cm.Purples)" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.0440, 0.0597, 0.1282, 0.1160, 0.2387, 0.2270, 0.0274, 0.0446, 0.0321,\n", " 0.0675, 0.0699, 0.0328, 0.0472, 0.1784, 0.3785, 0.1595, 0.0891, 0.0143,\n", " 0.0246, 0.0554, 0.0706, 0.0741, 0.0281, 0.0738, 0.0480, 0.0486, 0.1095,\n", " 0.0978, 0.0815, 0.1162, 0.1015, 0.2267, 0.1001, 0.0349, 0.1285, 0.2844,\n", " 0.0148, 0.0181, 0.0062, 0.0109, 0.0138, 0.0493, 0.0709, 0.0110, 0.0093,\n", " 0.0148, 0.0079, 0.0101, 0.0200, 0.0354, 0.0605, 0.0379, 0.0420, 0.0856,\n", " 0.0557, 0.0167, 0.0172, 0.0105, 0.0098, 0.0284, 0.0583, 0.0142, 0.0201,\n", " 0.0193, 0.0110, 0.0186, 0.0167, 0.0271, 0.0528, 0.0899, 0.1545, 0.0456,\n", " 0.0219, 0.0315, 0.0358, 0.0052, 0.0151, 0.0303, 0.0187, 0.0293, 0.0269,\n", " 0.0546, 0.1531, 0.1788, 0.1044, 0.1432, 0.0371, 0.0347, 0.0338, 0.0615,\n", " 0.0187, 0.0492, 0.0584, 0.0502, 0.1698, 0.3144, 0.1714, 0.0701, 0.1377,\n", " 0.0382, 0.0842, 0.0190, 0.0597, 0.1201, 0.0680, 0.1715, 0.0999, 0.2089,\n", " 0.1084, 0.0650, 0.0587, 0.1827, 0.4336, 0.0451, 0.0576, 0.0937, 0.0808,\n", " 0.1604, 0.3141, 0.0955, 0.1888, 0.0840, 0.1050, 0.0788, 0.0690, 0.1024,\n", " 0.2974, 1.0000, 0.9221, 0.4872, 0.3569, 0.2507, 0.0954, 0.1546],\n", " device='cuda:0')" ] }, "execution_count": 126, "metadata": {}, "output_type": "execute_result" } ], "source": [ "txt_ci.intrinsic_attention(test_text)[1]" ] }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", "
Text | Prediction | Actual | Loss | Probability
▁xxbos ▁xxmaj ▁consideram - se ▁i liquid áveis ▁as ▁conta s , ▁ordena ndo - se ▁o ▁seu ▁tran ca mento , ▁em ▁razão ▁da ▁imp os s ibilidade ▁do ▁exercício ▁de ▁ampla ▁defesa , ▁pelo ▁longo ▁de curso ▁de ▁tempo ▁entre ▁a ▁prática ▁do ▁ato ▁e ▁a ▁cita ção ▁do ▁responsável . | 3 | 3 | 3.10 | 0.91
▁xxbos ▁xxmaj ▁como ▁regra ▁geral , ▁sujeita ▁a ▁po nder ação ▁no ▁caso ▁concreto , ▁o ▁parcela mento ▁do ▁objeto ▁deve ▁ser ▁adotado ▁na ▁contratação ▁de ▁serviços ▁de ▁maior ▁especialização ▁técnica , ▁sendo ▁de s ne ces s ário ▁nos ▁serviços ▁de ▁menor ▁especialização . | 1 | 1 | 2.76 | 0.81
▁xxbos ▁xxmaj ▁falecido ▁o ▁responsável , ▁a ▁obriga ção ▁de ▁re para r ▁o ▁da no ▁re ca i ▁sobre ▁o ▁seu ▁ esp ólio ▁ou , ▁caso ▁concluída ▁a ▁partilha , ▁aos ▁sucesso res ▁até ▁o ▁limite ▁do ▁valor ▁do ▁patrimônio ▁transferido . ▁xxmaj ▁ ante ▁o ▁seu ▁caráter ▁ persona l íssimo , ▁a ▁multa ▁não ▁se ▁trans fer e ▁aos ▁sucesso res . | 2 | 2 | 2.64 | 0.86
▁xxbos ▁xxmaj ▁ao ▁firma r ▁termo ▁de ▁parceria ▁com ▁xxmaj ▁os cip ▁que ▁em ▁a ve nça ▁anterior ▁deixou ▁de ▁obedece r ▁normas ▁técnicas ▁na ▁execução ▁de ▁projeto ▁semelhante ▁e ▁de ▁mesma ▁natureza , ▁apresentando ▁erros ▁graves ▁na ▁presta ção ▁dos ▁serviços , ▁o ▁gesto r ▁assume ▁o ▁risco ▁de ▁in sucesso ▁e ▁de ▁prejuízo ▁ao ▁e r ário , ▁responde ndo ▁sol ida ria mente ▁pelo ▁da no . | 1 | 2 | 2.61 | 0.27
▁xxbos ▁a ▁apresentação ▁de ▁presta ção ▁conta s ▁fora ▁de ▁prazo ▁ ajusta do , ▁sem ▁justifica tiva ▁para ▁a ▁falta , ▁não ▁e li de ▁a ▁respectiva ▁o missão . | 3 | 2 | 2.61 | 0.07
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# tabulation showing the first k texts in top_losses along with their prediction, actual,loss, and probability of actual class.\n", "# max_len is the maximum number of tokens displayed. If max_len=None, it will display all tokens.\n", "txt_ci.show_top_losses(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Ensemble" ] }, { "cell_type": "code", "execution_count": 128, "metadata": {}, "outputs": [], "source": [ "bs = 18" ] }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_clas_config.copy()\n", "config['qrnn'] = True\n", "config['n_hid'] = 1550 #default 1152\n", "config['n_layers'] = 4 #default 3" ] }, { "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [], "source": [ "data_clas = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_v2', bs=bs, num_workers=1)\n", "learn_c = text_classifier_learner(data_clas, AWD_LSTM, config=config, drop_mult=0.3, metrics=[accuracy,f1]).to_fp16()\n", "learn_c.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_v2', purge=False);" ] }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9708), tensor(0.9707))" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds,targs = learn_c.get_preds(ordered=True)\n", "accuracy(preds,targs),f1(preds,targs)" ] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [], "source": [ "data_clas_bwd = load_data(path, f'{lang}_textlist_class_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', bs=bs, num_workers=1, backwards=True)\n", "learn_c_bwd = text_classifier_learner(data_clas_bwd, AWD_LSTM, config=config, drop_mult=0.3, metrics=[accuracy,f1]).to_fp16()\n", "learn_c_bwd.load(f'{lang}clas_tcu_jurisp_reduzido_sp15_multifit_bwd_v2', purge=False);" ] }, { "cell_type": 
"code", "execution_count": 133, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9708), tensor(0.9708))" ] }, "execution_count": 133, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds_b,targs_b = learn_c_bwd.get_preds(ordered=True)\n", "accuracy(preds_b,targs_b),f1(preds_b,targs_b)" ] }, { "cell_type": "code", "execution_count": 134, "metadata": {}, "outputs": [], "source": [ "preds_avg = (preds+preds_b)/2" ] }, { "cell_type": "code", "execution_count": 135, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.9795), tensor(0.9795))" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accuracy(preds_avg,targs_b),f1(preds_avg,targs_b)" ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[330 2 0 0]\n", " [ 0 282 2 0]\n", " [ 0 1 207 5]\n", " [ 0 2 9 186]]\n", "accuracy global: 0.97953216374269\n", "accuracy on class 0: 99.3975903614458\n", "accuracy on class 1: 99.29577464788733\n", "accuracy on class 2: 97.1830985915493\n", "accuracy on class 3: 94.41624365482234\n" ] } ], "source": [ "from sklearn.metrics import confusion_matrix\n", "\n", "predictions = np.argmax(preds_avg, axis = 1)\n", "cm = confusion_matrix(np.array(targs_b), np.array(predictions))\n", "print(cm)\n", "\n", "## acc\n", "print(f'accuracy global: {(cm[0,0]+cm[1,1]+cm[2,2]+cm[3,3])/(cm.sum())}')\n", "\n", "# acc neg, acc pos\n", "print(f'accuracy on class 0: {cm[0,0]/(cm.sum(1)[0])*100}') \n", "print(f'accuracy on class 1: {cm[1,1]/(cm.sum(1)[1])*100}')\n", "print(f'accuracy on class 2: {cm[2,2]/(cm.sum(1)[2])*100}')\n", "print(f'accuracy on class 3: {cm[3,3]/(cm.sum(1)[3])*100}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }