{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Disaster Watcher \n", "\n", "## Disaster Identification using Twitter Data and Deep Learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The sole purpose of this notebook is to present and outline the steps that were taken to train the model.\n", "\n", "Due to the long run time, the model training section was not run in this notebook. The actual production Jupyter notebook, which was trained using Google Colab, can be found in the [github repo](https://github.com/khordoo/disaster-watch-classifier)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Introduction\n", "Social media is increasingly being used to broadcast useful information during local crisis situations (e.g. hurricanes, earthquakes, explosions, bombings, etc.). Identifying disaster-related information in social media is challenging due to the low signal-to-noise ratio. In this work we use NLP to address this challenge. \n", "\n", "Some of the tweets sent from mobile devices are geotagged with precise\n", "location coordinates. However, only about 1% to 3% of all tweets are geotagged. Identifying disaster-related tweets along with their locations is highly valuable for first responders in disaster and crisis situations.\n", "In this project we first identify the disaster-related tweets with a deep learning model and then use a Named Entity Recognition library to identify and map the locations mentioned in the data.\n", "\n", "## 2. Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Natural disaster events generally generate a massive and dispersed reaction in social media channels. Users usually express their thoughts and the actions they took before, during, and after the event. We \n", "used the classified crisis-related tweet collection from CrisisLex.org, which is a repository of crisis-related social media data. 
\n", "We used the CrisisLexT6 dataset, which includes tweets from 6 crises, labeled by relatedness.\n", "- Contents: ~60K tweets posted during 6 crisis events in 2012 and 2013.\n", "- Labels: ~60,000 tweets (10,000 in each collection) were labeled by crowdsourcing workers according to relatedness (as \"on-topic\" or \"off-topic\").\n", "\n", "The data from the following crisis events were used in this analysis: \n", " - Flood \n", " - Earthquake\n", " - Hurricane\n", " - Tornado\n", " - Explosion\n", " - Bombing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Preprocessing \n", "Preprocessing the text data is an essential step in any NLP and text classification pipeline. The objective of this step is to remove noise that carries little information for classifying the tweets, such as punctuation, special characters, numbers, and terms that do not carry much weight in the context of the text.\n", "Let's first import the required packages." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%matplotlib inline\n", "\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import tensorflow as tf\n", "from sklearn.base import TransformerMixin ,BaseEstimator\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.model_selection import train_test_split\n", "from tensorflow.keras.utils import to_categorical\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow.keras.models import Sequential ,model_from_json\n", "from tensorflow.keras.layers import Embedding,Dense,Dropout ,GlobalMaxPool1D\n", "\n", "from IPython.display import clear_output\n", "from tensorflow.keras.wrappers.scikit_learn import KerasClassifier\n", "from sklearn.model_selection import RandomizedSearchCV" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.1 Loading\n", "The data are stored in 11 CSV files. We first load the data and then save it into a single combined file for further analysis. Each file is named after its crisis. The tweets in each file have been labeled as \"on-topic\" or \"off-topic\" and do not contain information about the type of the crisis. However, the type of the crisis is encoded in the file name, so we use the file names to assign the proper category to each tweet.\n", "First, let's load the data and have a quick look at it.\n", "\n", "We use the `Pipeline` class from the scikit-learn library to streamline the preprocessing routine. As a result, the classes in this analysis must be compatible with the scikit-learn pipeline architecture." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "COMBINDED_DATASET='combined.csv'\n", "DATA_DIRECTORY='../datasets'\n", "class DatasetExtractor(BaseEstimator,TransformerMixin):\n", "    \"\"\"Extractor class that loads multiple tweet files and creates a single unified file.\"\"\"\n", "\n", "    def transform(self,X,y=None):\n", "        return self.hot_load()\n", "\n", "    def hot_load(self):\n", "        \"\"\"Loads the pre-combined file if it exists, otherwise loads all the files.\"\"\"\n", "        combined_file_path=f'{DATA_DIRECTORY}/{COMBINDED_DATASET}'\n", "        if os.path.isfile(combined_file_path):\n", "            print('File Exists.Reloaded.')\n", "            return pd.read_csv(combined_file_path, index_col=0)\n", "        print('Loading Files..')\n", "        combined_dataset=self.load_data()\n", "        combined_dataset.to_csv(combined_file_path)\n", "        return combined_dataset\n", "\n", "    def load_data(self):\n", "        \"\"\"Loads multiple disaster-related tweet files and returns a single pandas DataFrame.\"\"\"\n", "        frames=[]\n", "        for file_name in os.listdir(path=DATA_DIRECTORY):\n", "            category=self.extract_category_name(file_name)\n", "            df=pd.read_csv(f'{DATA_DIRECTORY}/{file_name}')\n", "            df['category']=category\n", "            frames.append(df)\n", "        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0\n", "        return pd.concat(frames, ignore_index=True)\n", "\n", "    def extract_category_name(self,file_name):\n", "        \"\"\"Helper method that extracts the disaster category from the file name.\"\"\"\n", "        category=file_name.split('.')[0]\n", "        if '_' in category:\n", "            category=category.split('_')[0]\n", "        return category" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration purposes, we run each part of the pipeline separately and explain it step by step. Ultimately, we chain all of these steps into a single pipeline for the final modeling. 
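As a minimal sketch of that chaining, assuming scikit-learn's `Pipeline` API (the two step classes below are hypothetical stand-ins, not the notebook's own transformers): `Pipeline.fit_transform()` ends up calling `fit()` on every step, so a step that only defines `transform()` still needs a no-op `fit()` that returns `self`.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline


class StatelessStep(TransformerMixin, BaseEstimator):
    # Hypothetical base class: Pipeline.fit_transform calls fit() on every step,
    # so steps that only implement transform() get a no-op fit()
    def fit(self, X, y=None):
        return self


class LowercaseTweets(StatelessStep):
    # Hypothetical cleaning step: lowercases the tweet column
    def transform(self, X, y=None):
        X = X.copy()
        X['tweet'] = X['tweet'].str.lower()
        return X


class DropBlankTweets(StatelessStep):
    # Hypothetical filtering step: drops tweets that are empty after stripping
    def transform(self, X, y=None):
        keep = X['tweet'].str.strip() != ''
        return X[keep].reset_index(drop=True)


pipeline = Pipeline([
    ('clean', LowercaseTweets()),
    ('filter', DropBlankTweets()),
])

toy = pd.DataFrame({'tweet': ['Flood Warning', '   ', 'Quake NEAR town']})
result = pipeline.fit_transform(toy)
print(result['tweet'].tolist())  # ['flood warning', 'quake near town']
```

Any further cleaning or sampling step would simply be another `(name, transformer)` tuple in the list passed to `Pipeline`.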
\n", "Let's load the data and see what it looks like:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "File Exists.Reloaded.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " tweet id tweet \\\n", "0 '348351442404376578' @Jay1972Jay Nope. Mid 80's. It's off Metallica... \n", "1 '348167215536803841' Nothing like a :16 second downpour to give us ... \n", "2 '348644655786778624' @NelsonTagoona so glad that you missed the flo... \n", "3 '350519668815036416' Party hard , suns down , still warm , lovin li... \n", "4 '351446519733432320' @Exclusionzone if you compare yourself to wate... \n", "\n", " label category \n", "0 off-topic floods \n", "1 off-topic floods \n", "2 on-topic floods \n", "3 off-topic floods \n", "4 off-topic floods " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset=DatasetExtractor().transform(None)\n", "dataset.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned, the data consist of the following features: \n", " - tweet id\n", " - tweet\n", " - label \n", " - category\n", " \n", "The category feature was inferred from the file name and added to the data during loading. The labels were assigned by human annotators for each crisis. We only use the \"on-topic\" tweets from each category as examples of that category. All the \"off-topic\" tweets are combined and assigned to the \"unrelated\" class.\n", " \n", "## 3.2 Data Cleaning\n", "\n", "#### 3.2.1 Text cleaning\n", "Tweets can contain many different kinds of noise that can negatively affect the performance of machine learning algorithms, so we need to remove them carefully. We use regular expressions and the string replace functionality in pandas to remove the unwanted noise in the data.\n", "\n", "#### 3.2.2 Re-Tweets\n", "They add no real value to the data and can sometimes lead to overfitting.\n", "\n", "#### 3.2.3 URLs\n", "They carry no predictive power: the topic of a tweet cannot be judged by reading a URL. 
In the worst case, they might lead to overfitting.\n", "\n", "`df['tweet']=df['tweet'].str.replace('http\\S+', '',regex=True)`\n", "#### 3.2.4 Symbols\n", "Hashtags, commas, periods and other punctuation symbols are removed.\n", "\n", "`df['tweet']=df['tweet'].str.replace('[^a-zA-Z\\s]', '',regex=True)`\n", "\n", "#### 3.2.5 White Spaces\n", "We also get rid of any additional white space in the texts that might be created by the previous steps, collapsing runs of whitespace into a single space.\n", "\n", "`df['tweet']=df['tweet'].str.strip()`\n", "\n", "`df['tweet']=df['tweet'].str.replace('\\s+', ' ',regex=True)`\n", "\n", "#### 3.2.6 Lower case\n", "All texts are transformed to lowercase.\n", "\n", "#### 3.2.7 Location Names\n", "The names of the locations where the disasters happened are repeated in many tweets. We want to prevent the model from associating these location names with the crisis, so we remove the most frequent ones from the tweets.\n", "The following list of words was removed from the tweets:\n", "\n", "*[\"Boston\", \"Oklahoma\",\"Texas\",\"Nepal\",\"California\",\"Calgary\",\"Chile\",\"Alberta\",\"Pakistan\" ,\"WestTX\",\"Canada\",\"yycflood\",\"USA\",\"'S\",]*\n", " " ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "STOP_WORDS=[\"Boston\", \"Oklahoma\",\"Texas\",\"Nepal\",\"California\",\"Calgary\",\"Chile\",\"Alberta\",\"Pakistan\" ,\"WestTX\",\"Canada\",\"yycflood\",\"USA\",\"'S\",]\n", "class DatasetCleaner(BaseEstimator,TransformerMixin):\n", "    \"\"\"Removes redundant features, rows with missing values and noisy tokens.\"\"\"\n", "    def transform(self,X,y=None):\n", "        columns=X.columns.tolist()\n", "        X.columns=[column.strip() for column in columns]\n", "        X=X.drop('tweet id',axis=1)\n", "        X=X.dropna()\n", "        # Remove URLs and user mentions before stripping the individual symbols\n", "        X['tweet']=X['tweet'].str.replace(r'http\\S+', '',regex=True)\n", "        X['tweet']=X['tweet'].str.replace(r'@\\w+', '',regex=True)\n", "        for symbol in ['@', '#', '.', ',']:\n", "            X['tweet']=X['tweet'].str.replace(symbol, '',regex=False)\n", "        X['tweet']=X['tweet'].str.lower()\n", "        for word in STOP_WORDS:\n", "            X['tweet']=X['tweet'].str.replace(word.lower(), '',regex=False)\n", "        # Collapse runs of whitespace into one space; replacing them with '' would glue all words together\n", "        X['tweet']=X['tweet'].str.replace(r'\\s+', ' ',regex=True)\n", "        X['tweet']=X['tweet'].str.strip()\n", "        return X" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " tweet label category\n", "0 jay1972jaynopemid80itoffmetallica2ndalbumridet... off-topic floods\n", "1 nothinglikea:16seconddownpourtogiveussomemuchn... off-topic floods\n", "2 nelsontagoonasogladthatyoumissedthefloodsandsa... on-topic floods\n", "3 partyhardsunsdownstillwarmlovinlifesmileharddo... off-topic floods\n", "4 exclusionzoneifyoucompareyourselftowaterdoesth... off-topic floods" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset_cleaned=DatasetCleaner().transform(dataset)\n", "dataset_cleaned.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.3 Re-Sampling\n", "We want to make sure that the number of tweets in each category is comparable, so that we have a balanced dataset. We also shuffle the tweets to make sure that they have no particular order.\n", "Let's first see how many tweets we have in each category, regardless of whether they are on-topic or off-topic. This is the total number of tweets per category. Each file contains both **on-topic** and **off-topic** tweets, as labeled by the human annotators." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Crisis Tweet Count\n", "0 floods 20064\n", "1 bombing 10012\n", "2 hurricane 10008\n", "3 explosion 10006\n", "4 tornado 9992\n", "5 earthquake 9057" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Crisis=pd.DataFrame(dataset['category'].value_counts())\n", "Crisis.reset_index(inplace=True)\n", "Crisis.rename(columns={'index':'Crisis',\"category\":'Tweet Count'} ,inplace=True)\n", "Crisis" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, ' ')" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f,ax =plt.subplots(figsize=(15,7))\n", "sns.barplot(x='Crisis',y='Tweet Count',data=Crisis ,palette=sns.light_palette((210, 90, 60),10, input=\"husl\" ,reverse=True),ax=ax)\n", "ax.set_xlabel(' ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3.1 All Tweets (On-Topic and Off-Topic) for Each Category \n", "\n", "As you can see, we have roughly 10,000 tweets for each crisis, except floods, which have about 20,000. As a next step, let's see how many related (on-topic) tweets we have in each category. This is more important, since we only use the on-topic tweets from each category during the classification." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Crisis Tweet Count\n", "0 on-topic_floods 10603\n", "1 off-topic_floods 9461\n", "2 on-topic_hurricane 6138\n", "3 on-topic_bombing 5648\n", "4 on-topic_explosion 5246\n", "5 off-topic_tornado 5165\n", "6 on-topic_tornado 4827\n", "7 off-topic_explosion 4760\n", "8 on-topic_earthquake 4580\n", "9 Off-topic_earthquake 4475\n", "10 off-topic_bombing 4364\n", "11 off-topic_hurricane 3870" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset['label_full']=dataset['label']+'_'+dataset['category']\n", "Crisis_topics=pd.DataFrame(dataset['label_full'].value_counts())\n", "Crisis_topics.drop('On-topic_earthquake',axis=0,inplace=True)\n", "Crisis_topics.reset_index(inplace=True)\n", "Crisis_topics.rename(columns={'index':'Crisis',\"label_full\":'Tweet Count'} ,inplace=True)\n", "\n", "Crisis_topics" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Numer of on-topic and off-topic Tweets in each crisis Category ')" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f,ax =plt.subplots(figsize=(15,7))\n", "sns.barplot(y='Crisis',x='Tweet Count',data=Crisis_topics ,palette=sns.light_palette((210, 90, 60),20, input=\"husl\" ,reverse=True),ax=ax)\n", "ax.set_xlabel(' ')\n", "ax.set_title('Number of on-topic and off-topic Tweets in each crisis Category ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3.2 On-Topic Tweets\n", "\n", "Let's take a look at only the on-topic tweets in each category.\n", "The labels are roughly balanced: we have about 5,000 on-topic tweets in each category (except floods, with about 10,600)." ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/adminn/.local/lib/python3.6/site-packages/ipykernel_launcher.py:2: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", " \n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Crisis Tweet Count Tweet_pct\n", "0 on-topic_floods 10603 28.624264\n", "2 on-topic_hurricane 6138 16.570380\n", "3 on-topic_bombing 5648 15.247557\n", "4 on-topic_explosion 5246 14.162302\n", "6 on-topic_tornado 4827 13.031154\n", "8 on-topic_earthquake 4580 12.364343" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Crisis_topics_on_topic= Crisis_topics[Crisis_topics['Crisis'].str.contains(\"on-topic\")]\n", "Crisis_topics_on_topic['Tweet_pct']=Crisis_topics_on_topic['Tweet Count']*100/Crisis_topics_on_topic['Tweet Count'].sum()\n", "Crisis_topics_on_topic" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Percentage of on-topic Tweets in Each Category')" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f,ax =plt.subplots(figsize=(15,7))\n", "sns.barplot(x='Crisis',y='Tweet_pct',data=Crisis_topics_on_topic ,palette=sns.light_palette((216, 100, 40), input=\"husl\" ,reverse=True),ax=ax)\n", "ax.set_xlabel(' ')\n", "ax.set_ylabel('Percentage of on-topic Tweets in Each Category')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 3.3.3 Off-Topic Tweets\n", "To avoid overfitting, we also use a random set of off-topic tweets from each of the categories. We label all of these tweets as **unrelated**. This additional label lets the model learn to better distinguish between related and unrelated tweets in each category.\n", "\n" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total Number of 'Off-Topic' Tweets 32095\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Crisis Tweet Count\n", "1 off-topic_floods 9461\n", "5 off-topic_tornado 5165\n", "7 off-topic_explosion 4760\n", "9 Off-topic_earthquake 4475\n", "10 off-topic_bombing 4364" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Crisis_topics_off_topic= Crisis_topics[~Crisis_topics['Crisis'].str.contains(\"on-topic\")]\n", "total_off_topic_tweets=Crisis_topics_off_topic['Tweet Count'].sum()\n", "print(\"Total Number of 'Off-Topic' Tweets\",total_off_topic_tweets)\n", "Crisis_topics_off_topic.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### 3.3.3.1 Imbalanced Data\n", "Imbalanced data generally refers to classification problems where the classes are not represented equally. In our case, since each category has its own off-topic tweets, the total number of off-topic tweets across all categories (32,095) is far higher than the number of on-topic tweets in any single category. This would make our dataset highly imbalanced.\n", "\n", "Let's plot the total number of off-topic tweets along with the on-topic tweets. Note that the off-topic tweets also form one of our prediction categories; as a result, this category should have roughly the same number of tweets as the other categories. \n", "\n", "Let's label all of these tweets with an *unrelated* label." ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/adminn/.local/lib/python3.6/site-packages/ipykernel_launcher.py:1: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", " \"\"\"Entry point for launching an IPython kernel.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " Crisis Tweet Count\n", "0 unrelated 32095\n", "0 on-topic_floods 10603\n", "2 on-topic_hurricane 6138\n", "3 on-topic_bombing 5648\n", "4 on-topic_explosion 5246\n", "6 on-topic_tornado 4827\n", "8 on-topic_earthquake 4580" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Crisis_topics_off_topic['Crisis']='unrelated'\n", "Crisis_topics_off_topic_g=Crisis_topics_off_topic.groupby(by='Crisis').sum()\n", "Crisis_topics_off_topic_g.reset_index(inplace=True)\n", "# pd.concat replaces DataFrame.append, which was removed in pandas 2.0\n", "all_topics =pd.concat([Crisis_topics_off_topic_g, Crisis_topics_on_topic[['Crisis','Tweet Count']]])\n", "all_topics" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Imbalanced Dataset: Total Number of Tweets in each category')" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f,ax =plt.subplots(figsize=(15,7))\n", "blues=sns.light_palette((216, 100, 40),all_topics.shape[0], input=\"husl\" ,reverse=True)\n", "blues[0]=sns.color_palette(\"RdBu\", 10)[0]\n", "sns.barplot(x='Crisis',y='Tweet Count',data=all_topics ,palette=blues, ax=ax)\n", "ax.set_xlabel(' ')\n", "ax.set_ylabel('Number of Tweets')\n", "ax.set_title( 'Imbalanced Dataset: Total Number of Tweets in each category')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### 3.3.3.2 Balancing the Data\n", "As we can see, the number of *unrelated* tweets is far higher than the number of on-topic tweets in any category. To solve the problem, we re-sample a subset of these unrelated tweets. The number of unrelated tweets we keep is equal to the average number of on-topic tweets per category." ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "class DistributionValidSampler(BaseEstimator,TransformerMixin):\n", "    \"\"\"Samples the related and unrelated tweets in comparable proportions.\"\"\"\n", "    def __init__(self,unrelated_size=None,ignore_unrelated_proportion=True):\n", "        self._unrelated_size=unrelated_size\n", "        self._ignore_unrelated_proportion=ignore_unrelated_proportion\n", "\n", "    def transform(self,X,y=None):\n", "        # Shuffle the tweets\n", "        X_=X.sample(frac=1).reset_index(drop=True)\n", "        X_=self._label_categories(X_)\n", "        related,unrelated=self._equal_split(X_)\n", "        X_=self._merge(related,unrelated)\n", "        X_=X_.drop('category',axis=1)\n", "        return X_\n", "\n", "    def _label_categories(self,X):\n", "        \"\"\"Assigns the category name to on-topic tweets and 'unrelated' to the\n", "        off-topic tweets in each category.\n", "        \"\"\"\n", "        if self._ignore_unrelated_proportion:\n", "            X['label']=X.apply(lambda row: row['category'] if 'on-topic' in row['label'] else 'unrelated',axis=1)\n", "        else:\n", "            X['label']=X.apply(lambda row: row['category'] if 'on-topic' in row['label'] else 'unrelated_'+row['category'],axis=1)\n", "        return X\n", "\n", "    def _equal_split(self,X):\n", "        \"\"\"Splits the dataset into related and unrelated tweets and caps the\n", "        number of unrelated tweets so that it stays in a reasonable range.\n", "        \"\"\"\n", "        is_unrelated=X['label'].str.contains('unrelated')\n", "        related=X[~is_unrelated]\n", "        unrelated=X[is_unrelated]\n", "        ave_tweets=self._average_tweet_per_category(X)\n", "        unrelated=self._slice(unrelated,size=self._unrelated_size,ave_size=ave_tweets)\n", "        return related,unrelated\n", "\n", "    def _merge(self,X1,X2):\n", "        \"\"\"Merges the dataframes together.\"\"\"\n", "        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0\n", "        return pd.concat([X1,X2])\n", "\n", "    def _slice(self,X,size,ave_size):\n", "        \"\"\"Returns the first `size` rows of a dataframe (all rows if it is shorter).\"\"\"\n", "        if size is None:\n", "            size=ave_size\n", "        return X.iloc[:size]\n", "\n", "    def _average_tweet_per_category(self,X):\n", "        \"\"\"Calculates the average number of tweets across the related categories.\"\"\"\n", "        counts=X['label'].value_counts()\n", "        related_counts=counts[~counts.index.str.contains('unrelated')]\n", "        return int(related_counts.mean())" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " label\n", "floods 10603\n", "unrelated 6173\n", "hurricane 6138\n", "bombing 5648\n", "explosion 5246\n", "tornado 4827\n", "earthquake 4580" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset_resampled=DistributionValidSampler().transform(dataset_cleaned)\n", "dataset_resampled_topics=pd.DataFrame(dataset_resampled['label'].value_counts())\n", "display(dataset_resampled_topics)\n", "dataset_resampled_topics.reset_index(inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the number of tweets in each category of the re-sampled dataset:" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Balanced Dataset: Total Number of Tweets in each category')" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "f,ax =plt.subplots(figsize=(15,7))\n", "blues=sns.light_palette((216, 100, 40),dataset_resampled_topics.shape[0], input=\"husl\" ,reverse=True)\n", "blues[1]=sns.color_palette(\"RdBu\", 10)[0]\n", "sns.barplot(x='index',y='label',data=dataset_resampled_topics ,palette=blues, ax=ax)\n", "ax.set_xlabel(' ')\n", "ax.set_ylabel('Number of Tweets')\n", "ax.set_title( 'Balanced Dataset: Total Number of Tweets in each category')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Data Transformation\n", "### 4.1 Tokenization\n", "One of the common preprocessing tasks in NLP (Natural Language Processing) is tokenization. Given a character sequence and a defined document unit, tokenization is the task of chopping it up into pieces, called tokens [[1]](https://nlp.stanford.edu/IR-book/html/htmledition/tokenization-1.html).\n", "We use the `Tokenizer` class from Keras preprocessing to vectorize our text data. It turns our sentences into sequences of integers. We limit the vocabulary to the most frequent words (`num_words=10000`). \n", "### 4.2 Padding\n", "We pad all the vectorized text sequences with zeros at the end (`padding='post'`) so that all sequences have the same length. We use a maximum length of 100." 
] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [], "source": [ "class TextTokenizer(BaseEstimator,TransformerMixin):\n", " \"\"\"This is a simple Wrapper class for Keras Tokenizer.\"\"\"\n", " def __init__(self,pad_sequences,num_words=10000,max_length=100,max_pad_length=100 ):\n", " self._num_words=num_words\n", " self.max_length=max_length\n", " self._tokenizer=None\n", " self._pad_sequences=pad_sequences\n", " self._max_pad_length=max_pad_length\n", " self.vocab_size=None\n", " self.tokenizer=None\n", " \n", " def transform(self,X,y=None):\n", " self.tokenizer,self.vocab_size=self._get_tokenizer(X['tweet'])\n", " X['tweet_encoded']=self.tokenizer.texts_to_sequences(X['tweet'])\n", " X['tweet_encoded']= X['tweet_encoded'].apply(lambda x: self._pad_sequences([x],maxlen=self._max_pad_length ,padding='post')[0])\n", " \n", " return X\n", " def _get_tokenizer(self,X):\n", " tokenizer=tf.keras.preprocessing.text.Tokenizer(num_words=self._num_words)\n", " tokenizer.fit_on_texts(X)\n", " vocab_size=len(tokenizer.word_index)+1\n", " return tokenizer,vocab_size" ] }, { "cell_type": "code", "execution_count": 118, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Vocab Size: 65246\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ " tweet label \\\n", "1 zooduringfloodmtnatstechysonstaffmemberspentwe... floods \n", "4 findthelatestlocalfloodinformation:assoutheast... floods \n", "5 floodvictimslookingtogovernmentforhelp-mostins... floods \n", "7 rt911buff::massiveexplosionu/d-localhospitalsn... explosion \n", "8 caughtoncamera:fertilizerplantexplosionnearwac... explosion \n", "\n", " tweet_encoded label_encoded \\\n", "1 [7554, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 3 \n", "4 [634, 824, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... 3 \n", "5 [2366, 7555, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 3 \n", "7 [41, 1743, 86, 2367, 0, 0, 0, 0, 0, 0, 0, 0, 0... 2 \n", "8 [29, 27, 7556, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... 2 \n", "\n", " label_one_hot \n", "1 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "4 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "5 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "7 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] \n", "8 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] " ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenization=TextTokenizer(pad_sequences)\n", "dataset_tokenized=tokenization.transform(dataset_resampled)\n", "vocab_size=tokenization.vocab_size\n", "print('Vocab Size:',vocab_size)\n", "dataset_tokenized.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 Label Encoding\n", "In this step all the target labels are converted to integer values.We use the LabelEncoder class from the Sklean package.\n", "\n", "### 4.4 One Hot Encoding \n", "In the next step we use the integer values for labels and create a one hot vector to be used for the machine learning analysis." 
] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [], "source": [ "class LabelOneHotEncoder(BaseEstimator,TransformerMixin):\n", "    \"\"\"Transforms the categorical 'label' column into integer labels and one-hot vectors.\"\"\"\n", "    def __init__(self):\n", "        self.label_encoder=None\n", "        self.one_hot_encoder=None\n", "\n", "    def fit(self,X,y=None):\n", "        # No-op so that scikit-learn's fit_transform() works; fitting happens in transform()\n", "        return self\n", "\n", "    def transform(self,X,y=None):\n", "        # LabelEncoder sorts the class names and numbers them from 0\n", "        self.label_encoder=LabelEncoder().fit(X['label'])\n", "        self.one_hot_encoder=to_categorical\n", "        num_classes=len(set(X['label']))\n", "        X['label_encoded']=self.label_encoder.transform(X['label'].values)\n", "        X['label_one_hot']=X['label_encoded'].apply(lambda x: self.one_hot_encoder([x],num_classes=num_classes)[0])\n", "        return X" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " tweet label \\\n", "1 zooduringfloodmtnatstechysonstaffmemberspentwe... floods \n", "4 findthelatestlocalfloodinformation:assoutheast... floods \n", "5 floodvictimslookingtogovernmentforhelp-mostins... floods \n", "7 rt911buff::massiveexplosionu/d-localhospitalsn... explosion \n", "8 caughtoncamera:fertilizerplantexplosionnearwac... explosion \n", "\n", " tweet_encoded label_encoded \\\n", "1 [7554, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 3 \n", "4 [634, 824, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... 3 \n", "5 [2366, 7555, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 3 \n", "7 [41, 1743, 86, 2367, 0, 0, 0, 0, 0, 0, 0, 0, 0... 2 \n", "8 [29, 27, 7556, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... 2 \n", "\n", " label_one_hot \n", "1 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "4 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "5 [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] \n", "7 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] \n", "8 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0] " ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder=LabelOneHotEncoder()\n", "dataset_encoded=encoder.transform(dataset_resampled)\n", "dataset_encoded.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.3 Word embeddings\n", "\n", "Word Embedding is a representation of text where words that have the same meaning have a similar representation. In other words it represents words in a coordinate system where related words, based on a corpus of relationships, are placed closer together. In the deep learning frameworks such as TensorFlow, Keras, this part is usually handled by an embedding layer which stores a lookup table to map the words represented by numeric indexes to their dense vector representations.[[2]](https://towardsdatascience.com/machine-learning-word-embedding-sentiment-classification-using-keras-b83c28087456)\n", "\n", "Word embeddings can be generated using pre-trained word embeddings such as Glove and Word2Vec. 
Either of these can be downloaded and used for transfer learning. In this work, however, the embeddings are learned from scratch. The Keras `Embedding` layer maps each integer index produced by the tokenizer to a dense 50-dimensional vector, which is trained together with the rest of the model.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4.6 Train/Test Split\n", "In this section, we split the data into a training set (70%) and a test set (30%). It is important to use a splitting strategy that preserves the proportion of samples in each class. We therefore pass the encoded labels to the `stratify` argument of scikit-learn's `train_test_split`.\n", "\n" ] }, { "cell_type": "code", "execution_count": 115, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of Tweets in Training set: 30250\n", "Number of Tweets in Test set: 12965\n" ] } ], "source": [ "X_train,X_test,y_train,y_test =train_test_split(dataset_encoded['tweet_encoded'],dataset_encoded['label_one_hot'],test_size=0.3,stratify=dataset_encoded['label_encoded'])\n", "X_train=np.array(X_train.values.tolist())\n", "X_test=np.array(X_test.values.tolist())\n", "y_train=np.array(y_train.values.tolist())\n", "y_test=np.array(y_test.values.tolist())\n", "print('Number of Tweets in Training set: ',X_train.shape[0])\n", "print('Number of Tweets in Test set: ',X_test.shape[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 
Modeling\n", "\n", "### 5.1 Model Architecture\n", "The model is built with the Keras `Sequential` API, which is essentially a linear stack of layers. It has four parts: a trainable embedding layer that maps each of the 100 token indices to a 50-dimensional vector, a global max-pooling layer, a 10-unit ReLU dense layer with dropout (rate 0.3) before and after it, and a softmax output layer with one unit per class.\n" ] }, { "cell_type": "code", "execution_count": 119, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_1\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "embedding (Embedding) (None, 100, 50) 3262300 \n", "_________________________________________________________________\n", "global_max_pooling1d (Global (None, 50) 0 \n", "_________________________________________________________________\n", "dropout (Dropout) (None, 50) 0 \n", "_________________________________________________________________\n", "dense (Dense) (None, 10) 510 \n", "_________________________________________________________________\n", "dropout_1 (Dropout) (None, 10) 0 \n", "_________________________________________________________________\n", "dense_1 (Dense) (None, 7) 77 \n", "=================================================================\n", "Total params: 3,262,887\n", "Trainable params: 3,262,887\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "max_length=100\n", "embeding_dim=50\n", "num_classes=y_train[0].shape[0]\n", "model=Sequential()\n", "model.add(Embedding(input_dim=vocab_size,output_dim=embeding_dim,input_length=max_length))\n", "model.add(GlobalMaxPool1D())\n", "model.add(Dropout(0.3))\n", "model.add(Dense(10,activation='relu'))\n", "model.add(Dropout(0.3))\n", "model.add(Dense(num_classes,activation='softmax'))\n", "model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'] )\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 121, 
"metadata": {}, "outputs": [], "source": [ "class PlotLosses(tf.keras.callbacks.Callback):\n", " \"\"\"Simple utility function to plot the model losses during training\"\"\"\n", " def on_train_begin(self, logs={}):\n", " self.i = 0\n", " self.x = []\n", " self.losses = []\n", " self.val_losses = []\n", " \n", " self.fig = plt.figure()\n", " \n", " self.logs = []\n", "\n", " def on_epoch_end(self, epoch, logs={}):\n", " \n", " self.logs.append(logs)\n", " self.x.append(self.i)\n", " self.losses.append(logs.get('loss'))\n", " self.val_losses.append(logs.get('val_loss'))\n", " self.i += 1\n", " \n", " clear_output(wait=True)\n", " plt.plot(self.x, self.losses, label=\"loss\")\n", " plt.plot(self.x, self.val_losses, label=\"val_loss\")\n", " plt.legend()\n", " plt.show();\n", "plot_losses = PlotLosses() \n", "\n", "def save_model(model,save_name):\n", " with open(save_name,'w+') as f:\n", " f.write(model.to_json())\n", " model.save_weights(save_name+'.h5') " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.2 Training " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.fit(X_train,y_train,epochs=2,batch_size=10,verbose=0,validation_data=(X_test,y_test),callbacks=[plot_losses])\n", "save_model(model,'model')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.3 Evaluation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# load json and create model\n", "json_file = open('model', 'r')\n", "loaded_model_json = json_file.read()\n", "json_file.close()\n", "loaded_model = model_from_json(loaded_model_json)\n", "# load weights into new model\n", "loaded_model.load_weights(\"model.h5\")\n", "print(\"Loaded model from disk\")\n", " \n", "# evaluate loaded model on test data\n", "loaded_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", "score = loaded_model.evaluate(X_test, y_test, verbose=0)\n", 
"print(\"%s: %.2f%%\" % (loaded_model.metrics_names[1], score[1]*100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5.5 Hyperparameters Optimization" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [], "source": [ "def create_model(dropout, dense_size, vocab_size, embedding_dim, maxlen):\n", " model=Sequential()\n", " model.add(Embedding(input_dim=vocab_size,output_dim=embeding_dim,input_length=max_length))\n", " model.add(GlobalMaxPool1D())\n", " model.add(Dropout(dropout))\n", " model.add(Dense(dense_size,activation='relu'))\n", " model.add(Dropout(dropout))\n", " model.add(Dense(num_classes,activation='softmax'))\n", " model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])\n", " return model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Main settings\n", "epochs = 5\n", "embedding_dim = 50\n", "maxlen = 100\n", "vocab_size=10000\n", "output_file = 'output.txt'\n", "dense_size=[10, 50,100],\n", "# Parameter grid for grid search\n", "param_grid = dict(dropout=[0.1],\n", " dense_size=[10, 50,100],\n", " vocab_size=[vocab_size],\n", " embedding_dim=[embedding_dim],\n", " maxlen=[maxlen])\n", "model = KerasClassifier(build_fn=create_model,\n", " epochs=epochs, batch_size=10,\n", " verbose=False)\n", "grid = RandomizedSearchCV(estimator=model, param_distributions=param_grid,\n", " cv=4, verbose=1, n_iter=5 ,n_jobs=2)\n", "grid_result = grid.fit(X_train, y_train)\n", "\n", "# Evaluate testing set\n", "test_accuracy = grid.score(X_test, y_test)\n", "# Save and evaluate results\n", "with open(output_file, 'a') as f:\n", " s = ('Best Accuracy : '\n", " '{:.4f}\\n{}\\nTest Accuracy : {:.4f}\\n\\n')\n", " output_string = s.format(\n", " grid_result.best_score_,\n", " grid_result.best_params_,\n", " test_accuracy)\n", " print(output_string)\n", " f.write(output_string) \n", "print('Done')" ] } ], "metadata": { "kernelspec": { "display_name": 
"Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }