{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "TensorFlow 2.3 on Python 3.6 (CUDA 10.1)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "colab": { "name": "11-9.conv_lstm_stack_sentiment_classifier.ipynb", "provenance": [] }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "kupWRkk_uJgd" }, "source": [ "# 합성곱-LSTM 적층 감성 분류기" ] }, { "cell_type": "markdown", "metadata": { "id": "wD-wiuO1uJgh" }, "source": [ "이 노트북에서 합성곱 층 위에 LSTM을 쌓아 감성에 따라 IMDB 영화 리뷰를 분류합니다." ] }, { "cell_type": "markdown", "metadata": { "id": "Mov3sSiUuJgi" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rickiepark/dl-illustrated/blob/master/notebooks/11-9.conv_lstm_stack_sentiment_classifier.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "1TEuVdEluJgi" }, "source": [ "#### 라이브러리 적재" ] }, { "cell_type": "code", "metadata": { "id": "uNpYmfwkuJgi" }, "source": [ "from tensorflow import keras\n", "from tensorflow.keras.datasets import imdb\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Dropout, Embedding, SpatialDropout1D, LSTM\n", "from tensorflow.keras.layers import Bidirectional \n", "from tensorflow.keras.layers import Conv1D, MaxPooling1D \n", "from tensorflow.keras.callbacks import ModelCheckpoint\n", "import os\n", "from sklearn.metrics import roc_auc_score \n", "import matplotlib.pyplot as plt \n", "%matplotlib inline" ], "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "S1X9IoReuJgj" }, "source": [ "#### 하이퍼파라미터 설정" ] }, { "cell_type": "code", 
"metadata": { "id": "rypG-_FhuJgj" }, "source": [ "# 출력 디렉토리\n", "output_dir = 'model_output/cnnLSTM'\n", "\n", "# 훈련\n", "epochs = 4\n", "batch_size = 128\n", "\n", "# 벡터 공간 임베딩\n", "n_dim = 64 \n", "n_unique_words = 10000 \n", "max_review_length = 200 \n", "pad_type = trunc_type = 'pre'\n", "drop_embed = 0.2 \n", "\n", "# 합성곱 층 구조\n", "n_conv = 64 \n", "k_conv = 3 \n", "mp_size = 4\n", "\n", "# LSTM 층 구조\n", "n_lstm = 64 \n", "drop_lstm = 0.2" ], "execution_count": 2, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "b3Ws_sYyuJgj" }, "source": [ "#### 데이터 적재" ] }, { "cell_type": "code", "metadata": { "id": "0MuxR-QxuJgj", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "86a5feb4-716e-45a9-9f68-5124857bfeb7" }, "source": [ "(x_train, y_train), (x_valid, y_valid) = imdb.load_data(num_words=n_unique_words)" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\n", "17464789/17464789 [==============================] - 1s 0us/step\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "e9BEEnBXuJgj" }, "source": [ "#### 데이터 전처리" ] }, { "cell_type": "code", "metadata": { "id": "YCJ2ShfiuJgk" }, "source": [ "x_train = pad_sequences(x_train, maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)\n", "x_valid = pad_sequences(x_valid, maxlen=max_review_length, padding=pad_type, truncating=trunc_type, value=0)" ], "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "id": "hCe_L-cmuJgk" }, "source": [ "#### 신경망 만들기" ] }, { "cell_type": "code", "metadata": { "id": "zZ3_xHRDuJgk" }, "source": [ "model = Sequential()\n", "model.add(Embedding(n_unique_words, n_dim, input_length=max_review_length)) \n", "model.add(SpatialDropout1D(drop_embed))\n", "model.add(Conv1D(n_conv, k_conv, activation='relu'))\n", "model.add(MaxPooling1D(mp_size))\n", 
"model.add(Bidirectional(LSTM(n_lstm, dropout=drop_lstm)))\n", "model.add(Dense(1, activation='sigmoid'))" ], "execution_count": 5, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "c9RIyfrnuJgk", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7548d743-572e-49b0-c791-6ad0b24971b1" }, "source": [ "# 시퀀스를 정방향과 역방향으로 각각 읽는 LSTM이 따로 있기 때문에 LSTM 층의 파라미터가 두 배가 됩니다.\n", "model.summary() " ], "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " embedding (Embedding) (None, 200, 64) 640000 \n", " \n", " spatial_dropout1d (SpatialD (None, 200, 64) 0 \n", " ropout1D) \n", " \n", " conv1d (Conv1D) (None, 198, 64) 12352 \n", " \n", " max_pooling1d (MaxPooling1D (None, 49, 64) 0 \n", " ) \n", " \n", " bidirectional (Bidirectiona (None, 128) 66048 \n", " l) \n", " \n", " dense (Dense) (None, 1) 129 \n", " \n", "=================================================================\n", "Total params: 718,529\n", "Trainable params: 718,529\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "dlUQj4NDuJgl" }, "source": [ "#### 모델 설정" ] }, { "cell_type": "code", "metadata": { "id": "8yO5lWjNuJgl" }, "source": [ "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])" ], "execution_count": 7, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "VVyaZluxuJgl" }, "source": [ "modelcheckpoint = ModelCheckpoint(filepath=output_dir+\"/weights.{epoch:02d}.hdf5\")\n", "if not os.path.exists(output_dir):\n", " os.makedirs(output_dir)" ], "execution_count": 8, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "pBpKgUGvuJgl" }, "source": [ "#### 훈련!" 
] }, { "cell_type": "code", "metadata": { "id": "adlDynPquJgl", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7952d287-9a66-43fe-8160-e407ab9845ec" }, "source": [ "model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(x_valid, y_valid), callbacks=[modelcheckpoint])" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/4\n", "196/196 [==============================] - 17s 24ms/step - loss: 0.4815 - accuracy: 0.7380 - val_loss: 0.3091 - val_accuracy: 0.8699\n", "Epoch 2/4\n", "196/196 [==============================] - 3s 18ms/step - loss: 0.2402 - accuracy: 0.9070 - val_loss: 0.3150 - val_accuracy: 0.8688\n", "Epoch 3/4\n", "196/196 [==============================] - 3s 18ms/step - loss: 0.1689 - accuracy: 0.9381 - val_loss: 0.3434 - val_accuracy: 0.8631\n", "Epoch 4/4\n", "196/196 [==============================] - 3s 18ms/step - loss: 0.1177 - accuracy: 0.9602 - val_loss: 0.4166 - val_accuracy: 0.8555\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 9 } ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "id": "BS0oX5k-uJgm" }, "source": [ "#### 평가" ] }, { "cell_type": "code", "metadata": { "id": "EjvI5a3vuJgm" }, "source": [ "model.load_weights(output_dir+\"/weights.02.hdf5\") " ], "execution_count": 10, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "XtY209PwuJgm", "outputId": "4de56391-431e-411a-fb1a-65b3c77a090b", "colab": { "base_uri": "https://localhost:8080/" } }, "source": [ "y_hat = model.predict(x_valid)" ], "execution_count": 11, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "782/782 [==============================] - 4s 4ms/step\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "baye2UN4uJgm", "colab": { "base_uri": "https://localhost:8080/", "height": 267 }, "outputId": "9f248055-3fa4-48b1-9a21-7f976bd9b339" }, "source": [ 
"plt.hist(y_hat)\n", "_ = plt.axvline(x=0.5, color='orange')" ], "execution_count": 12, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD6CAYAAABDPiuvAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAQPUlEQVR4nO3df6xfdX3H8edLKv4Wir0jrO12MVa3yrKIN1Bj4pw1UMFQkimpmaOSZk2UqXNmE7c/uoAksB8ySdStk85inMCYGc3AkY4fIVvWykUc8mOMO362A7nSUueIP6rv/fH9wL7ovdwf33u/3957n4+k+Z7zOZ9zzvvTW/rq+ZzzPaSqkCQtbS8adAGSpMEzDCRJhoEkyTCQJGEYSJIwDCRJTCMMkuxI8mSSu7vajkuyO8kD7XN5a0+Sy5OMJbkrycld+2xu/R9Isrmr/c1JvtX2uTxJ5nqQkqQXlqm+Z5DkbcD3gCur6qTW9ifAgaq6JMkFwPKq+kSSM4APA2cApwKfqapTkxwHjAIjQAF3AG+uqoNJvg58BNgL3ABcXlVfm6rwFStW1PDw8KwGLc2b797f+Xz1GwZbhzSBO+644ztVNTTRtmVT7VxVtyUZ/qnmjcDb2/JO4FbgE639yuokzJ4kxyY5ofXdXVUHAJLsBjYkuRV4dVXtae1XAmcDU4bB8PAwo6OjU3WT+uuf3975fOetg6xCmlCSRybbNtt7BsdX1eNt+Qng+La8Enisq9++1vZC7fsmaJck9VHPN5DbVUBf3mmRZGuS0SSj4+Pj/TilJC0Jsw2Db7fpH9rnk619P7C6q9+q1vZC7asmaJ9QVW2vqpGqGhkamnDaS5I0C7MNg13As08EbQau62o/tz1VtA441KaTbgROS7K8PXl0GnBj2/bdJOvaU0Tndh1LktQnU95ATvIVOjeAVyTZB2wDLgGuSbIFeAQ4p3W/gc6TRGPAM8B5AFV1IMlFwO2t34XP3kwGPgR8EXgZnRvHU948liTNrek8TfS+STatn6BvAedPcpwdwI4J2keBk6aqQ5I0f/wGsiTJMJAkGQaSJKZxz2AxGr7g+oGc9+FLzhzIeSVpKl4ZSJIMA0mSYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJ9BgGST6W5J4kdyf5SpKXJjkxyd4kY0muTnJ06/uStj7Wtg93HeeTrf3+JKf3NiRJ0kzNOgySrAQ+AoxU1UnAUcAm4FLgsqp6HXAQ2NJ22QIcbO2XtX4kWdv2eyOwAfhckqNmW5ckaeZ6nSZaBrwsyTLg5cDjwDuAa9v2ncDZbXljW6dtX58krf2qqvpBVT0EjAGn9FiXJGkGZh0GVbUf+DPgUTohcAi4A3i6qg63bvuAlW15JfBY2/dw6/+a7vYJ9pEk9UEv00TL6fyr/kTg54FX0JnmmTdJtiYZTTI6Pj4+n6eSpCWll2midwIPVdV4Vf0I+CrwVuDYNm0EsArY35b3A6sB2vZjgKe62yfY53mqantVjVTVyNDQUA+lS5K69RIGjwLrkry8zf2vB+4FbgHe0/psBq5ry7vaOm37zVVVrX1Te9roRGAN8PUe6pIkzdCyqbtMrKr2JrkW+AZwGLgT2A5cD1yV5FOt7Yq2yxXAl5KMAQfoPEFEVd2T5Bo6QXIYOL+qfjzbuiRJMzfrMACoqm3Atp9qfpAJngaqqu8D753kOBcDF/dSiyRp9vwGsiTJMJAkGQaSJAwDSRKGgSSJHp8mkqSlaviC6wdy3ocvOXNejuuVgSTJMJAkGQaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIk
jAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRI9hkGSY5Ncm+Q/ktyX5C1JjkuyO8kD7XN565sklycZS3JXkpO7jrO59X8gyeZeByVJmplerww+A/xTVf0S8KvAfcAFwE1VtQa4qa0DvAtY035tBT4PkOQ4YBtwKnAKsO3ZAJEk9ceswyDJMcDbgCsAquqHVfU0sBHY2brtBM5uyxuBK6tjD3BskhOA04HdVXWgqg4Cu4ENs61LkjRzvVwZnAiMA3+T5M4kX0jyCuD4qnq89XkCOL4trwQe69p/X2ubrF2S1Ce9hMEy4GTg81X1JuB/+f8pIQCqqoDq4RzPk2RrktEko+Pj43N1WEla8noJg33Avqra29avpRMO327TP7TPJ9v2/cDqrv1XtbbJ2n9GVW2vqpGqGhkaGuqhdElSt1mHQVU9ATyW5A2taT1wL7ALePaJoM3AdW15F3Bue6poHXCoTSfdCJyWZHm7cXxaa5Mk9cmyHvf/MPDlJEcDDwLn0QmYa5JsAR4Bzml9bwDOAMaAZ1pfqupAkouA21u/C6vqQI91SZJmoKcwqKpvAiMTbFo/Qd8Czp/kODuAHb3UIkmaPb+BLEkyDCRJhoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEnMQBkmOSnJnkn9s6ycm2ZtkLMnVSY5u7S9p62Nt+3DXMT7Z2u9PcnqvNUmSZmYurgw+CtzXtX4pcFlVvQ44CGxp7VuAg639staPJGuBTcAbgQ3A55IcNQd1SZKmqacwSLIKOBP4QlsP8A7g2tZlJ3B2W97Y1mnb17f+G4GrquoHVfUQMAac0ktdkqSZ6fXK4C+APwB+0tZfAzxdVYfb+j5gZVteCTwG0LYfav2fa59gH0lSH8w6DJK8G3iyqu6Yw3qmOufWJKNJRsfHx/t1Wkla9Hq5MngrcFaSh4Gr6EwPfQY4Nsmy1mcVsL8t7wdWA7TtxwBPdbdPsM/zVNX2qhqpqpGhoaEeSpckdZt1GFTVJ6tqVVUN07kBfHNV/SZwC/Ce1m0zcF1b3tXWadtvrqpq7Zva00YnAmuAr8+2LknSzC2busuMfQK4KsmngDuBK1r7FcCXkowBB+gECFV1T5JrgHuBw8D5VfXjeahLkjSJOQmDqroVuLUtP8gETwNV1feB906y/8XAxXNRiyRp5vwGsiTJMJAkGQaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSfQQBklWJ7klyb1J7kny0dZ+XJLdSR5on8tbe5JcnmQsyV1JTu461ubW/4Ekm3sfliRpJnq5MjgMfLyq1gLrgPOTrAUuAG6qqjXATW0d4F3AmvZrK/B56IQHsA04FTgF2PZsgEiS+mPWYVBVj1fVN9ry/wD3ASuBjcDO1m0ncHZb3ghcWR17gGOTnACcDuyuqgNVdRDYDWyYbV2SpJmbk3sGSYaBNwF7geOr6vG26Qng+La8Enisa7d9rW2ydklSn/QcBkleCfw98LtV9d3ubVVVQPV6jq5zbU0ymmR0fHx8rg4rSUteT2GQ5MV0guDLVfXV1vztNv1D+3yyte8HVnftvqq1Tdb+M6pqe1WNVNXI0NBQL6VLkrr08jRRgCuA+6rq012bdgHPPhG0Gbiuq/3c9lTROuBQm066ETgtyfJ24/i01iZJ6pNlPez7VuC3gG8l+WZr+0PgEuCaJFuAR4Bz2rYbgDOAMeAZ4DyAqjqQ5CLg9
tbvwqo60ENdkqQZmnUYVNW/AJlk8/oJ+hdw/iTH2gHsmG0tkqTe9HJlIEkDNXzB9YMuYdHwdRSSJMNAkmQYSJIwDCRJGAaSJAwDSRKGgSQJw0CShF8666tBfkHm4UvOHNi5JR35vDKQJBkGkiTDQJKEYSBJwjCQJGEYSJLw0VJJc8D/r8DC55WBJMkwkCQ5TbRkDOoy3m8+SwuDVwaSJMNAkuQ0keaZL+frH5/oUS8MAy1ag/jL8arXPgXAJv9i1gLjNJEkyTCQJBkGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkcQWGQZEOS+5OMJblg0PVI0lJyRIRBkqOAzwLvAtYC70uydrBVSdLScUSEAXAKMFZVD1bVD4GrgI0DrkmSlowjJQxWAo91re9rbZKkPlhQr7BOshXY2la/l+T+WR5qBfCdualqwXDMffCW55be3c/TdvPnvMjlUmD2Y/7FyTYcKWGwH1jdtb6qtT1PVW0Htvd6siSjVTXS63EWEse8NDjmpWE+xnykTBPdDqxJcmKSo4FNwK4B1yRJS8YRcWVQVYeT/A5wI3AUsKOq7hlwWZK0ZBwRYQBQVTcAN/TpdD1PNS1AjnlpcMxLw5yPOVU118eUJC0wR8o9A0nSAC3aMJjq9RZJXpLk6rZ9b5Lh/lc5t6Yx5t9Lcm+Su5LclGTSx8wWium+xiTJbySpJAv+qZPpjDnJOe1nfU+Sv+13jfNhGn++fyHJLUnubH/GzxhEnXMlyY4kTya5e5LtSXJ5+/24K8nJPZ2wqhbdLzo3of8LeC1wNPDvwNqf6vMh4C/b8ibg6kHX3Ycx/zrw8rb8waUw5tbvVcBtwB5gZNB19+HnvAa4E1je1n9u0HX3adzbgQ+25bXAw4Ouu8cxvw04Gbh7ku1nAF8DAqwD9vZyvsV6ZTCd11tsBHa25WuB9UnSxxrn2pRjrqpbquqZtrqHzvc5FrLpvsbkIuBS4Pv9LG6eTGfMvw18tqoOAlTVk32ucT5MZ9wFvLotHwP8dx/rm3NVdRtw4AW6bASurI49wLFJTpjt+RZrGEzn9RbP9amqw8Ah4DV9qW5+zPSVHlvo/KtiIZtyzO3SeXVVXd/PwubRdH7Orwden+Rfk+xJsqFv1c2f6Yz7j4H3J9lH58nED/entIGZ09f4HDGPlqp/krwfGAF+bdC1zKckLwI+DXxgwKX02zI6U0Vvp3P1d1uSX6mqpwda1fx7H/DFqvrzJG8BvpTkpKr6yaALWwgW65XBdF5v8VyfJMvoXFY+1Zfq5se0XumR5J3AHwFnVdUP+lTbfJlqzK8CTgJuTfIwnXnVXQv8JvJ0fs77gF1V9aOqegj4TzrhsJBNZ9xbgGsAqurfgJfSeYfPYjWt/+ana7GGwXReb7EL2NyW3wPcXO2uzAI15ZiTvAn4KzpBsBjmkV9wzFV1qKpWVNVwVQ3TuU9yVlWNDqbcOTGdP9v/QOeqgCQr6EwbPdjPIufBdMb9KLAeIMkv0wmD8b5W2V+7gHPbU0XrgENV9fhsD7Yop4lqktdbJLkQGK2qXcAVdC4jx+jcpNk0uIp7N80x/ynwSuDv2r3yR6vqrIEV3aNpjnlRmeaYbwROS3Iv8GPg96tqIV/1TnfcHwf+OsnH6NxM/sBC/gdekq/QCfUV7T7INuDFAFX1l3Tui5wBjAHPAOf1dL4F/HslSZoji3WaSJI0A4aBJMkwkCQZBpIkDANJEoaBJAnDQJKEYSBJAv4PoMq9SBcbglMAAAAASUVORK5CYII=\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "KA25ZGspuJgm", "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "outputId": 
"0b5efa44-56d9-45cb-a56e-c98194bab5d9" }, "source": [ "\"{:0.2f}\".format(roc_auc_score(y_valid, y_hat)*100.0)" ], "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'94.49'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 13 } ] } ] }