{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "ia8cA7knqvOV" }, "source": [ "# Fine-tuning a Japanese BERT model for sentiment analysis with Hugging Face Transformers (with Google Colab), part 01\n", "\n", "This article explains how to fine-tune a Japanese BERT model for sentiment analysis.\n", "\n", "A detailed explanation of BERT itself is outside the scope of this article.\n", "\n", "This is part 01.\n", "\n", "[part02](https://jupyterbook.hnishi.com/language-models/fine_tune_jp_bert_part02.html) trains and evaluates the model on a full dataset.\n", "\n", "## Notes\n", "\n", "**What is Hugging Face Transformers?**\n", "\n", "It is an [open-source library developed by Hugging Face for working with language models easily](https://huggingface.co/docs/transformers/index).\n", "\n", "Hugging Face also provides repositories that host [pre-trained models](https://huggingface.co/models) and [datasets](https://huggingface.co/datasets).\n", "\n", "**BERT**\n", "\n", "[BERT](https://arxiv.org/abs/1810.04805) stands for Bidirectional Encoder Representations from Transformers and is a natural language processing model developed by Google.\n", "\n", "It is pre-trained on unlabeled text to capture the meanings of, and relationships between, words and phrases; by adding a single output layer and fine-tuning, you can build models that perform well on a wide range of tasks.\n", "\n", "## References\n", "\n", "- [Hugging Face Transformers documentation](https://huggingface.co/transformers/)\n", "- [BERT paper](https://arxiv.org/abs/1810.04805)\n", "- [Fine-tuning a BERT model with transformers](https://towardsdatascience.com/fine-tuning-a-bert-model-with-transformers-c8e49c4e008b)" ] }, { "cell_type": "markdown", "metadata": { "id": "r5Z_zhsNqvOb" }, "source": [ "## Installing the required libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-QtQP6vfqvOb" }, "outputs": [], "source": [ "# fugashi and unidic-lite provide the MeCab-based tokenizer used by the Japanese BERT models\n", "!pip install -q transformers fugashi[unidic-lite]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C9UCrH2nWSoE" }, "outputs": [], "source": [ "import pandas as pd\n", "import torch\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline\n", "from torch.optim import AdamW" ] }, { "cell_type": "markdown", "metadata": { "id": "mv3oF4c_bRyf" }, "source": [ "## A quick tutorial on Japanese BERT\n", "\n", "First, let's look at how to use a pre-trained Japanese BERT model with Hugging Face Transformers and how to fine-tune it.\n", "\n", "As the pre-trained model, we use the one published by the Tohoku University group.\n", "\n", "**References**\n", "\n", "- https://huggingface.co/cl-tohoku\n", "- https://github.com/cl-tohoku/bert-japanese" ] }, { "cell_type": "markdown", "metadata": { "id": "CiontpcdjG6G" }, "source": [ "### Inference with a pre-trained model\n", "\n", "A BERT model is trained to predict masked tokens (`[MASK]`).\n", "\n", "The pre-trained model can therefore be used to fill in blanks in a sentence, i.e. to predict missing parts of the text.\n", "\n", "Let's run inference with the following two models and compare the results.\n", "\n", "- [cl-tohoku/bert-large-japanese](https://huggingface.co/cl-tohoku/bert-large-japanese)\n", "- [bert-base-multilingual-uncased](https://huggingface.co/bert-base-multilingual-uncased) (the multilingual BERT model)\n", "\n", "You can try this easily by passing the `fill-mask` task to [pipeline](https://huggingface.co/docs/transformers/main_classes/pipelines)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BAGc2FbAiawP", "outputId": "850b07e4-d258-4f9b-8b98-39e5252b6a20", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at cl-tohoku/bert-large-japanese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or 
with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[{'score': 0.04221104457974434,\n", " 'token': 32474,\n", " 'token_str': 'サラダ',\n", " 'sequence': '今日 の 昼食 は サラダ でし た 。'},\n", " {'score': 0.036806173622608185,\n", " 'token': 18526,\n", " 'token_str': 'カレー',\n", " 'sequence': '今日 の 昼食 は カレー でし た 。'},\n", " {'score': 0.0313434936106205,\n", " 'token': 31893,\n", " 'token_str': 'ご飯',\n", " 'sequence': '今日 の 昼食 は ご飯 でし た 。'},\n", " {'score': 0.021632177755236626,\n", " 'token': 17540,\n", " 'token_str': '元気',\n", " 'sequence': '今日 の 昼食 は 元気 でし た 。'},\n", " {'score': 0.020115602761507034,\n", " 'token': 23869,\n", " 'token_str': 'うどん',\n", " 'sequence': '今日 の 昼食 は うどん でし た 。'}]" ] }, "metadata": {}, "execution_count": 3 } ], "source": [ "model_name = \"cl-tohoku/bert-large-japanese\"\n", "\n", "unmasker = pipeline('fill-mask', model=model_name)\n", "unmasker(\"今日の昼食は[MASK]でした。\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WpEtuqseh5H_", "outputId": "dc9788fa-b063-4d8e-c717-83c572681b4a", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[{'score': 0.17987696826457977,\n", " 'token': 7753,\n", " 'token_str': '見',\n", " 'sequence': '今 日 の 昼 食 は 見 てした 。'},\n", " {'score': 0.06706605106592178,\n", " 'token': 4080,\n", " 'token_str': '捨',\n", " 'sequence': '今 日 の 昼 食 は 捨 てした 。'},\n", " {'score': 0.06436670571565628,\n", " 'token': 2073,\n", " 'token_str': '全',\n", " 'sequence': '今 日 の 昼 食 は 全 てした 。'},\n", " {'score': 0.060412339866161346,\n", " 'token': 5216,\n", " 'token_str': '満',\n", " 'sequence': '今 日 の 昼 食 は 満 てした 。'},\n", " {'score': 0.02542056515812874,\n", " 'token': 4518,\n", " 'token_str': '果',\n", " 'sequence': '今 日 の 昼 食 は 果 てした 。'}]" ] }, "metadata": {}, "execution_count": 4 } ], "source": [ "model_name = \"bert-base-multilingual-uncased\"\n", "\n", "unmasker = pipeline('fill-mask', model=model_name)\n", "unmasker(\"今日の昼食は[MASK]でした。\")" ] }, { "cell_type": "markdown", "source": [ "**Exercise**\n", "\n", "Try passing a different sentence to `unmasker` and see what it predicts." ], "metadata": { "id": "J2snK4rM7b6N" } }, { "cell_type": "markdown", "source": [ "## Sentiment analysis before fine-tuning\n", "\n", "The sentiment analysis task can be run by specifying [\"sentiment-analysis\"](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.TextClassificationPipeline) as the `pipeline` task.\n", "\n", "The list of tasks available with pipeline can be found in the [Hugging Face documentation](https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/pipelines).\n", "\n", "Note, however, that the model must support the specified task." ], "metadata": { "id": "dKsrZJem9LST" } },
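{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check, the cell below prints the task names registered in the installed version of the library. This is only a convenience sketch: `SUPPORTED_TASKS` is an internal constant of `transformers` rather than a public API, so its contents may change between versions, and documented names such as `sentiment-analysis` are aliases that map onto the canonical names listed here." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# List the pipeline task names known to the installed transformers version.\n", "# SUPPORTED_TASKS is an internal constant, so treat this as a convenience check only.\n", "from transformers.pipelines import SUPPORTED_TASKS\n", "\n", "print(sorted(SUPPORTED_TASKS.keys()))" ] },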
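{ "cell_type": "markdown", "metadata": {}, "source": [ "Before fine-tuning, let's see what a `sentiment-analysis` pipeline built directly from the pre-trained Japanese BERT returns. This is only a sketch: the classification head on top of the checkpoint has not been trained yet (it is randomly initialized, hence the warnings), so the labels and scores at this point are essentially meaningless." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A sentiment-analysis pipeline built directly from the pre-trained checkpoint.\n", "# The classification head is randomly initialized, so expect warnings and\n", "# essentially random predictions until the model has been fine-tuned.\n", "pre_ft_analyzer = pipeline(\"sentiment-analysis\", model=\"cl-tohoku/bert-large-japanese\")\n", "pre_ft_analyzer(\"私はこの映画をみることができて、とても嬉しい。\")" ] },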
{ "cell_type": "markdown", "metadata": { "id": "8wZQLyH7l97y" }, "source": [ "## Fine-tuning steps for text classification\n", "\n", "Normally you would prepare a much larger training dataset, but to keep the explanation simple we create a few samples by hand and walk through the procedure.\n", "\n", "We prepare data with three kinds of labels (positive: 2, neutral: 1, negative: 0), as follows." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gzvZ0w_MbRyg", "outputId": "fc6e784d-0d14-4bb2-9880-9cbbf2134ffc", "colab": { "base_uri": "https://localhost:8080/", "height": 143 } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text label\n", "0 私はこの映画をみることができて、とても嬉しい。 POSITIVE\n", "1 今日の晩御飯は何だろう。 NEUTRAL\n", "2 猫に足を噛まれて痛い。 NEGATIVE" ] }, "metadata": {}, "execution_count": 5 } ], "source": [ "# A toy dataset for checking the procedure\n", "df = pd.DataFrame(\n", "    [\n", "        {\"text\": \"私はこの映画をみることができて、とても嬉しい。\", \"label\": \"POSITIVE\"},\n", "        {\"text\": \"今日の晩御飯は何だろう。\", \"label\": \"NEUTRAL\"},\n", "        {\"text\": \"猫に足を噛まれて痛い。\", \"label\": \"NEGATIVE\"}\n", "    ]\n", ")\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "A1a3YPCtbRyi" }, "outputs": [], "source": [ "train_docs = df[\"text\"].tolist()\n", "train_labels = df[\"label\"].tolist()" ] }, { "cell_type": "markdown", "metadata": { "id": "0Ur522Vee2VN" }, "source": [ "## Training\n", "\n", "We encode the dataset texts with the tokenizer that is downloaded together with the model.\n", "\n", "**References**\n", "\n", "- https://huggingface.co/transformers/training.html#pytorch\n", "- https://huggingface.co/docs/transformers/tasks/sequence_classification\n", "- https://huggingface.co/transformers/v4.4.2/custom_datasets.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bflrko3NbRyj", "outputId": "36f65536-a91d-497d-b563-a59307d1318d", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Some weights of the model checkpoint at cl-tohoku/bert-large-japanese were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-large-japanese and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model_name = \"cl-tohoku/bert-large-japanese\"\n", "\n", "id2label = {0: \"NEGATIVE\", 1: \"NEUTRAL\", 2: \"POSITIVE\"}\n", "label2id = {\"NEGATIVE\": 0, \"NEUTRAL\": 1, \"POSITIVE\": 2}\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3, id2label=id2label, label2id=label2id)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B_EAHs-bmSJ3" }, "outputs": [], "source": [ "# Tokenize the training texts into padded tensors of token ids and attention masks\n", "encodings = tokenizer(train_docs, return_tensors='pt', padding=True, truncation=True, max_length=128)\n", "input_ids = encodings['input_ids']\n", "attention_mask = encodings['attention_mask']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0QqtCEn1bRyk" }, "outputs": [], "source": [ "# Fine-tuning in native PyTorch\n", "\n", "# AdamW implements gradient bias correction as well as weight decay.\n", "optimizer = AdamW(model.parameters(), lr=1e-5)\n", "\n", "# Convert the string labels to integer ids; shape (batch_size,)\n", "labels = torch.tensor([label2id[label] for label in train_labels])\n", "\n", "# A single forward/backward pass and one optimizer step, to illustrate the mechanics\n", "model.train()\n", "outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n", "loss = outputs.loss\n", "loss.backward()\n", "optimizer.step()" ] },
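{ "cell_type": "markdown", "metadata": {}, "source": [ "The cell above performs only a single optimization step, which is enough to illustrate the mechanics. As a rough sketch, a real fine-tuning run would iterate over mini-batches for several epochs; the cell below shows one way to do that with a plain PyTorch `DataLoader`. Names such as `num_epochs` and the batch size are arbitrary choices for this sketch, and with only the three toy examples above the loop is purely illustrative." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch of a multi-epoch training loop with mini-batches.\n", "from torch.utils.data import DataLoader, TensorDataset\n", "\n", "# Build a small dataset from the encoded inputs and integer label ids.\n", "label_ids = torch.tensor([label2id[label] for label in train_labels])\n", "dataset = TensorDataset(input_ids, attention_mask, label_ids)\n", "loader = DataLoader(dataset, batch_size=2, shuffle=True)\n", "\n", "num_epochs = 3  # arbitrary for this sketch\n", "model.train()\n", "for epoch in range(num_epochs):\n", "    for batch_input_ids, batch_attention_mask, batch_labels in loader:\n", "        optimizer.zero_grad()\n", "        outputs = model(batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)\n", "        outputs.loss.backward()\n", "        optimizer.step()\n", "    print(f\"epoch {epoch}: last batch loss {outputs.loss.item():.4f}\")" ] },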
{ "cell_type": "markdown", "metadata": { "id": "c4RT2SwnbRyk" }, "source": [ "## Inference with the fine-tuned model\n", "\n", "Since the amount of training data is tiny, we cannot expect good performance, but let's confirm the inference procedure.\n", "\n", "Inference can be run as follows." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cW2vLoZGbRyl" }, "outputs": [], "source": [ "sentiment_analyzer = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "d1q--QxrqPf4", "outputId": "672ee102-df84-4768-8390-ac60b9f229cf", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[{'label': 'POSITIVE', 'score': 0.41579246520996094}]" ] }, "metadata": {}, "execution_count": 11 } ], "source": [ "sentiment_analyzer(\"これは、テストのための文章です\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SDUlagMbbRym", "outputId": "ae6517ca-b733-4a13-8716-14633cc9f854", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "私はこの映画をみることができて、とても嬉しい。: [{'label': 'POSITIVE', 'score': 0.5623015761375427}]\n", "今日の晩御飯は何だろう。: [{'label': 'NEUTRAL', 'score': 0.4257284700870514}]\n", "猫に足を噛まれて痛い。: [{'label': 'NEGATIVE', 'score': 0.586097776889801}]\n" ] } ], "source": [ "# Inference on the training data\n", "_ = list(map(lambda x: print(f\"{x}: {sentiment_analyzer(x)}\"), train_docs))" ] }, { "cell_type": "markdown", "source": [ "**Exercise**\n", "\n", "Try passing your own sentences to `sentiment_analyzer` and running inference.\n", "\n" ], "metadata": { "id": "XF_EgNMgbiEj" } },
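{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, here is a rough sketch of what the `sentiment-analysis` pipeline does internally, using the `model`, `tokenizer`, and `id2label` objects defined above directly: tokenize the input, run a forward pass, apply softmax to the logits, and map the highest-scoring class id back to its label." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Roughly what the sentiment-analysis pipeline does under the hood.\n", "model.eval()  # disable dropout for inference\n", "inputs = tokenizer(\"猫に足を噛まれて痛い。\", return_tensors=\"pt\", truncation=True, max_length=128)\n", "with torch.no_grad():\n", "    logits = model(**inputs).logits\n", "probs = torch.softmax(logits, dim=-1)[0]\n", "pred_id = int(probs.argmax())\n", "print(id2label[pred_id], round(float(probs[pred_id]), 4))" ] },
{ "cell_type": "markdown", "metadata": { "id": 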
"Zj4HfgSIXYHu" }, "source": [ "## まとめ\n", "\n", "簡単な文章とラベルを用意して fine tuning する方法を記載しました。\n", "\n", "[次の記事](https://jupyterbook.hnishi.com/language-models/fine_tune_jp_bert_part02.html) では、より大きなデータセットを使って、より時間のかかる学習を試してみたいと思います。" ] } ], "metadata": { "colab": { "name": "fine-tune-jp-bert-part01.ipynb", "provenance": [], "toc_visible": true, "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 0 }