{ "cells": [ { "cell_type": "code", "execution_count": 3, "id": "8d3ef3e3-5325-4ef8-8ccd-c4a5cf8a5a86", "metadata": {}, "outputs": [], "source": [ "# Copyright 2023 Shane Khalid. All Rights Reserved.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# http://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "# ==============================================================================" ] }, { "cell_type": "markdown", "id": "72c1415a-8f22-46ed-81d5-f72b6a28d209", "metadata": {}, "source": [ "# Stock Prediction Deep Learning Model\n", "\n", "##### Stock/asset prices are time-series data, so I am implementing an LSTM (Long Short-Term Memory) network, a type of RNN (Recurrent Neural Network) that can remember information over long periods of time. " ] }, { "cell_type": "markdown", "id": "1055917c", "metadata": { "papermill": { "duration": 0.01662, "end_time": "2023-07-30T10:00:49.783467", "exception": false, "start_time": "2023-07-30T10:00:49.766847", "status": "completed" }, "tags": [] }, "source": [ "##### Import libraries. 
I normally fetch the price data with yfinance, but here I am loading in a clean dataset instead." ] }, { "cell_type": "code", "execution_count": 1, "id": "6785e45e", "metadata": { "papermill": { "duration": 13.344944, "end_time": "2023-07-30T10:01:03.145307", "exception": false, "start_time": "2023-07-30T10:00:49.800363", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-10-30 17:36:22.078914: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2023-10-30 17:36:22.079005: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2023-10-30 17:36:22.079025: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2023-10-30 17:36:22.175099: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-10-30 17:36:27.494027: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:43:00.0/numa_node\n", "Your kernel may have been built without NUMA support.\n", "2023-10-30 17:36:27.650172: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:43:00.0/numa_node\n", "Your kernel may have been built without NUMA support.\n", "2023-10-30 17:36:27.650226: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file 
to read NUMA node: /sys/bus/pci/devices/0000:43:00.0/numa_node\n", "Your kernel may have been built without NUMA support.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Num GPUs Available: 1\n", "GPU is available and being used.\n" ] } ], "source": [ "import yfinance as yf\n", "\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "\n", "from keras.layers import GRU, Dropout, SimpleRNN, LSTM, Dense, SimpleRNN, GRU\n", "from keras.models import Sequential\n", "from sklearn.preprocessing import MinMaxScaler\n", "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))\n", "\n", "import torch\n", "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", " print(\"GPU is available and being used.\")\n", "else:\n", " device = torch.device(\"cpu\")\n", " print(\"GPU is not available, using CPU instead.\")\n", "\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import plotly.express as px\n", "import statsmodels.api as sm\n", "import matplotlib.pyplot as plt\n", "import plotly.graph_objects as go\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.dates as dates\n", "import seaborn as sns\n", "import math\n", "import datetime\n", "import keras\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "from datetime import date, timedelta\n", "from keras.models import Sequential\n", "from keras.layers import Dense\n", "from keras.layers import LSTM\n", "from keras.layers import Dropout\n", "from keras.layers import *\n", "from keras.callbacks import EarlyStopping\n", "\n", "from keras.metrics import Accuracy\n", "from keras.metrics import F1Score\n", "from keras.metrics import Precision\n", "\n", "\n", "from sklearn.preprocessing import MinMaxScaler\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.metrics import mean_absolute_error\n", "#from sklearn.metrics import accuracy_score\n", "#from sklearn.metrics import precision_score\n", "#from 
sklearn.metrics import f1_score\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "48b3cb15", "metadata": { "papermill": { "duration": 0.016758, "end_time": "2023-07-30T10:01:03.179177", "exception": false, "start_time": "2023-07-30T10:01:03.162419", "status": "completed" }, "tags": [] }, "source": [ "##### Load Data (Google stock)" ] }, { "cell_type": "code", "execution_count": 2, "id": "87a5801f", "metadata": { "papermill": { "duration": 0.065846, "end_time": "2023-07-30T10:01:03.261994", "exception": false, "start_time": "2023-07-30T10:01:03.196148", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", " | Date | \n", "Open | \n", "High | \n", "Low | \n", "Close | \n", "Adj Close | \n", "Volume | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "2010-01-04 | \n", "15.689439 | \n", "15.753504 | \n", "15.621622 | \n", "15.684434 | \n", "15.684434 | \n", "78169752 | \n", "
1 | \n", "2010-01-05 | \n", "15.695195 | \n", "15.711712 | \n", "15.554054 | \n", "15.615365 | \n", "15.615365 | \n", "120067812 | \n", "
2 | \n", "2010-01-06 | \n", "15.662162 | \n", "15.662162 | \n", "15.174174 | \n", "15.221722 | \n", "15.221722 | \n", "158988852 | \n", "
3 | \n", "2010-01-07 | \n", "15.250250 | \n", "15.265265 | \n", "14.831081 | \n", "14.867367 | \n", "14.867367 | \n", "256315428 | \n", "
4 | \n", "2010-01-08 | \n", "14.814815 | \n", "15.096346 | \n", "14.742492 | \n", "15.065566 | \n", "15.065566 | \n", "188783028 | \n", "
5 | \n", "2010-01-11 | \n", "15.126627 | \n", "15.126627 | \n", "14.865866 | \n", "15.042793 | \n", "15.042793 | \n", "288227484 | \n", "
6 | \n", "2010-01-12 | \n", "14.956206 | \n", "14.968969 | \n", "14.714715 | \n", "14.776777 | \n", "14.776777 | \n", "193937868 | \n", "
7 | \n", "2010-01-13 | \n", "14.426677 | \n", "14.724224 | \n", "14.361862 | \n", "14.691942 | \n", "14.691942 | \n", "259604136 | \n", "
8 | \n", "2010-01-14 | \n", "14.612112 | \n", "14.869870 | \n", "14.584835 | \n", "14.761011 | \n", "14.761011 | \n", "169434396 | \n", "
9 | \n", "2010-01-15 | \n", "14.848348 | \n", "14.853854 | \n", "14.465465 | \n", "14.514515 | \n", "14.514515 | \n", "217162620 | \n", "
\n", " | Date | \n", "Open | \n", "High | \n", "Low | \n", "Close | \n", "Adj Close | \n", "Volume | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "2023-01-03 | \n", "89.589996 | \n", "91.050003 | \n", "88.519997 | \n", "89.120003 | \n", "89.120003 | \n", "28131200 | \n", "
1 | \n", "2023-01-04 | \n", "90.349998 | \n", "90.650002 | \n", "87.269997 | \n", "88.080002 | \n", "88.080002 | \n", "34854800 | \n", "
2 | \n", "2023-01-05 | \n", "87.470001 | \n", "87.570000 | \n", "85.900002 | \n", "86.199997 | \n", "86.199997 | \n", "27194400 | \n", "
3 | \n", "2023-01-06 | \n", "86.790001 | \n", "87.690002 | \n", "84.860001 | \n", "87.339996 | \n", "87.339996 | \n", "41381500 | \n", "
4 | \n", "2023-01-09 | \n", "88.360001 | \n", "90.050003 | \n", "87.860001 | \n", "88.019997 | \n", "88.019997 | \n", "29003900 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
138 | \n", "2023-07-24 | \n", "121.660004 | \n", "123.000000 | \n", "120.980003 | \n", "121.529999 | \n", "121.529999 | \n", "29686100 | \n", "
139 | \n", "2023-07-25 | \n", "121.360001 | \n", "123.150002 | \n", "121.019997 | \n", "122.209999 | \n", "122.209999 | \n", "52509600 | \n", "
140 | \n", "2023-07-26 | \n", "130.070007 | \n", "130.979996 | \n", "128.320007 | \n", "129.270004 | \n", "129.270004 | \n", "61682100 | \n", "
141 | \n", "2023-07-27 | \n", "131.669998 | \n", "133.240005 | \n", "128.789993 | \n", "129.399994 | \n", "129.399994 | \n", "44952100 | \n", "
142 | \n", "2023-07-28 | \n", "130.779999 | \n", "133.740005 | \n", "130.570007 | \n", "132.580002 | \n", "132.580002 | \n", "36572900 | \n", "
143 rows × 7 columns
\n", "\n", " | Date | \n", "RNN_Open | \n", "LSTM_Open | \n", "GRU_Open | \n", "
---|---|---|---|---|
0 | \n", "2023-01-03 | \n", "89.761795 | \n", "82.481285 | \n", "87.833824 | \n", "
1 | \n", "2023-01-04 | \n", "89.380569 | \n", "82.145378 | \n", "88.074127 | \n", "
2 | \n", "2023-01-05 | \n", "89.087761 | \n", "81.918633 | \n", "88.152527 | \n", "
3 | \n", "2023-01-06 | \n", "88.819893 | \n", "81.767494 | \n", "87.754929 | \n", "
4 | \n", "2023-01-09 | \n", "88.490471 | \n", "81.667305 | \n", "87.387054 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
138 | \n", "2023-07-24 | \n", "131.086044 | \n", "111.596542 | \n", "120.487892 | \n", "
139 | \n", "2023-07-25 | \n", "131.083679 | \n", "111.854652 | \n", "120.331024 | \n", "
140 | \n", "2023-07-26 | \n", "131.231766 | \n", "111.973335 | \n", "120.570282 | \n", "
141 | \n", "2023-07-27 | \n", "131.598236 | \n", "112.064293 | \n", "121.998161 | \n", "
142 | \n", "2023-07-28 | \n", "132.154419 | \n", "112.244171 | \n", "123.765366 | \n", "
143 rows × 4 columns
\n", "