{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "TensorFlow 2.3 on Python 3.6 (CUDA 10.1)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "colab": { "name": "11-7.stacked_bi_lstm_sentiment_classifier.ipynb", "provenance": [] }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "gFZPTq-jB5M8" }, "source": [ "# 적층 양방향 LSTM 감성 분류기" ] }, { "cell_type": "markdown", "metadata": { "id": "vVrrV-nwB5NA" }, "source": [ "이 노트북에서 *적층* 양방향 LSTM을 사용해 감성에 따라 IMDB 영화 리뷰를 분류합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "lEBdfCPwB5NA" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rickiepark/dl-illustrated/blob/master/notebooks/11-7.stacked_bi_lstm_sentiment_classifier.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "EdwQKCR2B5NA" }, "source": [ "#### 라이브러리 적재" ] }, { "cell_type": "code", "metadata": { "id": "msyZFCQdB5NB" }, "source": [ "from tensorflow import keras\n", "from tensorflow.keras.datasets import imdb\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Dropout, Embedding, SpatialDropout1D, LSTM\n", "from tensorflow.keras.layers import Bidirectional \n", "from tensorflow.keras.callbacks import ModelCheckpoint\n", "import os\n", "from sklearn.metrics import roc_auc_score \n", "import matplotlib.pyplot as plt \n", "%matplotlib inline" ], "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "pVFZm5mUB5NB" }, "source": [ "#### 하이퍼파라미터 설정" ] }, { "cell_type": "code", "metadata": { "id": "b01xF481B5NB" }, "source": [ "# 출력 
디렉토리\n", "output_dir = 'model_output/stackedLSTM'\n", "\n", "# 훈련\n", "epochs = 4\n", "batch_size = 128\n", "\n", "# 벡터 공간 임베딩\n", "n_dim = 64 \n", "n_unique_words = 10000 \n", "max_review_length = 200 \n", "pad_type = trunc_type = 'pre'\n", "drop_embed = 0.2 \n", "\n", "# LSTM 층 구조\n", "n_lstm_1 = 64 # 줄임\n", "n_lstm_2 = 64 # new!\n", "drop_lstm = 0.2" ], "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "aknD-Z7xB5NB" }, "source": [ "#### 데이터 적재" ] }, { "cell_type": "code", "metadata": { "id": "IpWjiaryB5NC", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "63f6f5d4-8b96-45dc-a5bc-56091e12a16c" }, "source": [ "(x_train, y_train), (x_valid, y_valid) = imdb.load_data(num_words=n_unique_words)  # n_words_to_skip removed" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n", "17464789/17464789 [==============================] - 0s 0us/step\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "mv63kiv5B5NC" }, "source": [ "#### 데이터 전처리" ] }, { "cell_type": "code", "metadata": { "id": "OgonJNTbB5NC" }, "source": [ "# Both splits share one set of padding/truncation settings, so define it once.\n", "pad_kwargs = dict(maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)\n", "x_train = pad_sequences(x_train, **pad_kwargs)\n", "x_valid = pad_sequences(x_valid, **pad_kwargs)" ], "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "id": "Hcaw0Q3ZB5NC" }, "source": [ "#### 신경망 만들기" ] }, { "cell_type": "code", "metadata": { "id": "6j1skJeTB5NC" }, "source": [ "model = Sequential()\n", "model.add(Embedding(n_unique_words, n_dim, input_length=max_review_length))\n", "model.add(SpatialDropout1D(drop_embed))\n", "# return_sequences=True: the first LSTM outputs one vector per time step, which the second LSTM reads as its input sequence.\n", "model.add(Bidirectional(LSTM(n_lstm_1, dropout=drop_lstm, return_sequences=True)))\n", "model.add(Bidirectional(LSTM(n_lstm_2, dropout=drop_lstm)))\n", 
"model.add(Dense(1, activation='sigmoid'))" ], "execution_count": 5, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "vDHjAKLoB5ND", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e896bd76-97b5-4c69-8512-0cd598ab6a5b" }, "source": [ "# Each Bidirectional wrapper holds a forward and a backward LSTM, so its parameter count is double that of a single LSTM.\n", "model.summary() " ], "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " embedding (Embedding) (None, 200, 64) 640000 \n", " \n", " spatial_dropout1d (SpatialD (None, 200, 64) 0 \n", " ropout1D) \n", " \n", " bidirectional (Bidirectiona (None, 200, 128) 66048 \n", " l) \n", " \n", " bidirectional_1 (Bidirectio (None, 128) 98816 \n", " nal) \n", " \n", " dense (Dense) (None, 1) 129 \n", " \n", "=================================================================\n", "Total params: 804,993\n", "Trainable params: 804,993\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "gyRcUn74B5ND" }, "source": [ "#### 모델 설정" ] }, { "cell_type": "code", "metadata": { "id": "6VNcI0NgB5ND" }, "source": [ "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "3L1-XFRyB5NE" }, "source": [ "# Create the checkpoint directory up front; exist_ok=True keeps this cell safe to re-run.\n", "os.makedirs(output_dir, exist_ok=True)\n", "# Keras fills in {epoch:02d}, so a separate checkpoint file is written for each epoch.\n", "modelcheckpoint = ModelCheckpoint(filepath=output_dir+\"/weights.{epoch:02d}.hdf5\")" ], "execution_count": 8, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dYJhACI-B5NE" }, "source": [ "#### 훈련!" 
] }, { "cell_type": "code", "metadata": { "id": "RJAhv7MzB5NE", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "91c53a91-07e8-42f1-ba8e-e3acbfdeb3d2" }, "source": [ "# Keep the returned History object so the evaluation step can look up the best epoch.\n", "history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_valid, y_valid), callbacks=[modelcheckpoint])" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/4\n", "196/196 [==============================] - 23s 65ms/step - loss: 0.4572 - accuracy: 0.7633 - val_loss: 0.3050 - val_accuracy: 0.8723\n", "Epoch 2/4\n", "196/196 [==============================] - 12s 59ms/step - loss: 0.2476 - accuracy: 0.9032 - val_loss: 0.3475 - val_accuracy: 0.8474\n", "Epoch 3/4\n", "196/196 [==============================] - 12s 60ms/step - loss: 0.1906 - accuracy: 0.9277 - val_loss: 0.3224 - val_accuracy: 0.8672\n", "Epoch 4/4\n", "196/196 [==============================] - 12s 59ms/step - loss: 0.1422 - accuracy: 0.9490 - val_loss: 0.3546 - val_accuracy: 0.8650\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "id": "eb9PtAmzB5NE" }, "source": [ "#### 평가" ] }, { "cell_type": "code", "metadata": { "id": "3ctJSHylB5NE" }, "source": [ "# Restore the checkpoint from the epoch with the lowest validation loss instead of a\n", "# hard-coded epoch number: the best epoch is not known in advance and can differ between runs.\n", "val_losses = history.history['val_loss']\n", "best_epoch = val_losses.index(min(val_losses)) + 1  # checkpoint file names count epochs from 1\n", "model.load_weights(output_dir + \"/weights.{:02d}.hdf5\".format(best_epoch))" ], "execution_count": 10, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "kKghjYSJB5NF", "outputId": "4749f694-dbd0-4b33-ae44-8548587884d1", "colab": { "base_uri": "https://localhost:8080/" } }, "source": [ "y_hat = model.predict(x_valid)" ], "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "782/782 [==============================] - 11s 13ms/step\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "4k8wSGdKB5NF", "colab": { "base_uri": "https://localhost:8080/", "height": 265 }, "outputId": "b75ab811-0ebc-45d2-8baa-0b91731733a3" }, "source": [ 
"# Histogram of the predicted values; the orange line marks 0.5.\n", "plt.hist(y_hat)\n", "_ = plt.axvline(x=0.5, color='orange')" ], "execution_count": 12, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAPQ0lEQVR4nO3df6yeZX3H8fdHKjp/8cN2xLVlp8bqhi6L5AQwJs5ZAwiGkkxNzZzVNGvimHPObNbtDxaVBLJNpok/1lm2apzAmBnNcCMdP2K2jGoRxwTGOAOEdiiVlrqN+KP63R/PRXc0PT3PKc95Hk6v9ys5Ofd93dd939e35/B5rnM/93OTqkKS1IdnTHoAkqTxMfQlqSOGviR1xNCXpI4Y+pLUkWWTHsDRLF++vKampiY9DOnHfefewfcXvGyy45DmcPvtt3+7qlYcadvTOvSnpqbYvXv3pIch/bh/fO3g++tvneQopDkl+cZc27y8I0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHXlafyJXkiZpassNEzv3g5dfuCjHdaYvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SODBX6Sd6b5K4kX0/y+STPTrImya4kM0muSXJi6/ustj7Ttk/NOs4HWvu9Sc5bnJIkSXOZN/STrAR+C5iuqlcAJwAbgCuAK6vqJcABYFPbZRNwoLVf2fqR5Iy238uB84FPJDlhtOVIko5m2Ms7y4CfSrIMeA7wCPA64Lq2fTtwcVte39Zp29clSWu/uqq+V1UPADPAWU+9BEnSsOYN/araC/wx8BCDsD8I3A48XlWHWrc9wMq2vBJ4uO17qPV/4ez2I+wjSRqDYS7vnMJglr4G+BnguQwuzyyKJJuT7E6ye9++fYt1Gknq0jCXd14PPFBV+6rqB8AXgFcDJ7fLPQCrgL1teS+wGqBtPwl4bHb7EfY5rKq2VtV0VU2vWLHiGEqSJM1lmNB/CDgnyXPatfl1wN3ALcCbWp+NwPVteUdbp22/uaqqtW9od/esAdYCXx5NGZKkYSybr0NV7UpyHfBV4BBwB7AVuAG4OsmHW9u2tss24LNJZoD9DO7YoaruSnItgxeMQ8AlVfXDEdcjSTqKeUMfoKouBS79ieb7OcLdN1X1XeDNcxznMuCyBY5RkjQifiJXkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI0OFfpKTk1yX5N+T3JPkVUlOTbIzyX3t+ymtb5J8LMlMkjuTnDnrOBtb//uSbFysoiRJRzbsTP+jwD9U1c8BvwjcA2wBbqqqtcBNbR3gDcDa9rUZ+CRAklOBS4GzgbOAS598oZAkjce8oZ/kJOA1wDaAqvp+VT0OrAe2t27bgYvb8nrgMzVwG3BykhcB5wE7q2p/VR0AdgLnj7QaSdJRDTPTXwPsA/4iyR1JPp3kucBpVfVI6/NN4LS2vBJ4eNb+e1rbXO0/JsnmJLuT7N63b9/CqpEkHdUwob8MOBP4ZFW9Evhf/v9SDgBVVUCNYkBVtbWqpqtqesWKFaM4pCSpGSb09wB7qmpXW7+OwYvAt9plG9r3R9v2vcDqWfuvam1ztUuSxmTe0K+qbwIPJ3lZa1oH3A3sAJ68A2cjcH1b3gG8vd3Fcw5wsF0GuhE4N8kp7Q3cc1ubJGlMlg3Z7
93A55KcCNwPvJPBC8a1STYB3wDe0vp+EbgAmAGeaH2pqv1JPgR8pfX7YFXtH0kVkqShDBX6VfU1YPoIm9YdoW8Bl8xxnKuAqxYyQEnS6PiJXEnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR4b9P2ctSVNbbpjIeR+8/MKJnFeS5uNMX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHRk69JOckOSOJH/X1tck2ZVkJsk1SU5s7c9q6zNt+9SsY3ygtd+b5LxRFyNJOrqFzPTfA9wza/0K4MqqeglwANjU2jcBB1r7la0fSc4ANgAvB84HPpHkhKc2fEnSQgwV+klWARcCn27rAV4HXNe6bAcubsvr2zpt+7rWfz1wdVV9r6oeAGaAs0ZRhCRpOMPO9P8U+D3gR239hcDjVXWore8BVrbllcDDAG37wdb/cPsR9jksyeYku5Ps3rdv3wJKkSTNZ97QT/JG4NGqun0M46GqtlbVdFVNr1ixYhynlKRuLBuiz6uBi5JcADwbeAHwUeDkJMvabH4VsLf13wusBvYkWQacBDw2q/1Js/eRJI3BvDP9qvpAVa2qqikGb8TeXFW/CtwCvKl12whc35Z3tHXa9purqlr7hnZ3zxpgLfDlkVUiSZrXMDP9ubwfuDrJh4E7gG2tfRvw2SQzwH4GLxRU1V1JrgXuBg4Bl1TVD5/C+SVJC7Sg0K+qW4Fb2/L9HOHum6r6LvDmOfa/DLhsoYOUJI2Gn8iVpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkfmDf0kq5PckuTuJHcleU9rPzXJziT3te+ntPYk+ViSmSR3Jjlz1rE2tv73Jdm4eGVJko5k2RB9DgHvq6qvJnk+cHuSncA7gJuq6vIkW4AtwPuBNwBr29fZwCeBs5OcClwKTAPVjrOjqg6MuihJx5epLTdMegjHjXln+lX1SFV9tS3/N3APsBJYD2xv3bYDF7fl9cBnauA24OQkLwLOA3ZW1f4W9DuB80dajSTpqIaZ6R+WZAp4JbALOK2qHmmbvgmc1pZXAg/P2m1Pa5ur/SfPsRnYDHD66acvZHhPG5OalTx4+YUTOa+kpWPoN3KTPA/4G+C3q+o7s7dVVTG4ZPOUVdXWqpququkVK1aM4pCSpGao0E/yTAaB/7mq+kJr/la7bEP7/mhr3wusnrX7qtY2V7skaUyGuXsnwDbgnqr6yKxNO4An78DZCFw/q/3t7S6ec4CD7TLQjcC5SU5pd/qc29okSWMyzDX9VwO/Bvxbkq+1tt8HLgeuTbIJ+Abwlrbti8AFwAzwBPBOgKran+RDwFdavw9W1f6RVCFJGsq8oV9V/wRkjs3rjtC/gEvmONZVwFULGaAkaXT8RK4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpIwt6yqakvvlc+6XPmb4kdcSZ/nFkkrMwn+UvLQ3O9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOuLdOxqJnu7fvvrFjwFwzoTO39O/tUbPmb4kdcSZvnSMnHFrKXKmL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI2MP/STnJ7k3yUySL
eM+vyT1bKyhn+QE4OPAG4AzgLcmOWOcY5Ckno17pn8WMFNV91fV94GrgfVjHoMkdWvZmM+3Enh41voe4OzZHZJsBja31f9Jcu8Cjr8c+PZTGuHSZN1j9KrDS28c96mf5M+7A7ni8OKx1P2zc20Yd+jPq6q2AluPZd8ku6tqesRDetqz7r5Yd19GXfe4L+/sBVbPWl/V2iRJYzDu0P8KsDbJmiQnAhuAHWMegyR1a6yXd6rqUJLfBG4ETgCuqqq7RniKY7osdByw7r5Yd19GWneqapTHkyQ9jfmJXEnqiKEvSR1ZkqE/36MckjwryTVt+64kU+Mf5egNUffvJLk7yZ1Jbkoy5726S8mwj+5I8itJKslxcVvfMHUneUv7md+V5K/GPcbFMMTv+elJbklyR/tdv2AS4xylJFcleTTJ1+fYniQfa/8mdyY585hPVlVL6ovBG8D/CbwYOBH4V+CMn+jzG8Cn2vIG4JpJj3tMdf8y8Jy2/K5e6m79ng98CbgNmJ70uMf0814L3AGc0tZ/etLjHlPdW4F3teUzgAcnPe4R1P0a4Ezg63NsvwD4eyDAOcCuYz3XUpzpD/Moh/XA9rZ8HbAuScY4xsUwb91VdUtVPdFWb2PwOYilbthHd3wIuAL47jgHt4iGqfvXgY9X1QGAqnp0zGNcDMPUXcAL2vJJwH+NcXyLoqq+BOw/Spf1wGdq4Dbg5CQvOpZzLcXQP9KjHFbO1aeqDgEHgReOZXSLZ5i6Z9vEYGaw1M1bd/tTd3VV3TDOgS2yYX7eLwVemuSfk9yW5PyxjW7xDFP3HwJvS7IH+CLw7vEMbaIW+t//nJ52j2HQU5fkbcA08EuTHstiS/IM4CPAOyY8lElYxuASz2sZ/FX3pSS/UFWPT3RUi++twF9W1Z8keRXw2SSvqKofTXpgS8FSnOkP8yiHw32SLGPwJ+BjYxnd4hnqERZJXg/8AXBRVX1vTGNbTPPV/XzgFcCtSR5kcL1zx3HwZu4wP+89wI6q+kFVPQD8B4MXgaVsmLo3AdcCVNW/AM9m8FCy49nIHmGzFEN/mEc57AA2tuU3ATdXezdkCZu37iSvBP6MQeAfD9d3YZ66q+pgVS2vqqmqmmLwXsZFVbV7MsMdmWF+z/+WwSyfJMsZXO65f5yDXATD1P0QsA4gyc8zCP19Yx3l+O0A3t7u4jkHOFhVjxzLgZbc5Z2a41EOST4I7K6qHcA2Bn/yzTB4c2TD5EY8GkPW/UfA84C/bu9bP1RVF01s0CMwZN3HnSHrvhE4N8ndwA+B362qJf0X7ZB1vw/48yTvZfCm7juW+qQuyecZvIAvb+9VXAo8E6CqPsXgvYsLgBngCeCdx3yuJf5vJUlagKV4eUeSdIwMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSR/wOGTwZhrtWtXQAAAABJRU5ErkJggg==\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "X3miUGLjB5NF", "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "outputId": "9fa0eaeb-dba9-495c-9e25-ba51dc9c499a" }, "source": [ "\"{:0.2f}\".format(roc_auc_score(y_valid, y_hat)*100.0)" ], "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'94.00'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 13 } ] } ] }