{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ML Pipeline Preparation\n", "Follow the instructions below to help you create your ML pipeline.\n", "### 1. Import libraries and load data from database.\n", "- Import Python libraries\n", "- Load dataset from database with [`read_sql_table`](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.read_sql_table.html)\n", "- Define feature and target variables X and Y" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# import libraries\n", "import pandas as pd\n", "import numpy as np\n", "import pickle\n", "from sqlalchemy import create_engine\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# import NLP libraries\n", "import re\n", "import nltk \n", "from nltk.corpus import stopwords\n", "from nltk.tokenize import word_tokenize\n", "from nltk.stem.wordnet import WordNetLemmatizer\n", "# nltk.download('punkt')\n", "# nltk.download('stopwords')\n", "# nltk.download('wordnet') # download for lemmatization" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# import sklearn\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n", "from sklearn.model_selection import train_test_split, GridSearchCV\n", "from sklearn.multioutput import MultiOutputClassifier\n", "from sklearn.metrics import precision_score, recall_score, f1_score\n", "from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# load data from database\n", "engine = create_engine('sqlite:///data/DisasterResponse.db')\n", "df = pd.read_sql_table('DisasterResponse', engine)\n", "X = 
df['message']\n", "Y = df.drop(['id', 'message', 'original', 'genre'], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Write a tokenization function to process your text data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [
"def tokenize(text):\n",
"    # Define url pattern (no space in the character class, so a match stops at whitespace)\n",
"    url_re = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'\n",
"    \n",
"    # Detect and replace urls\n",
"    detected_urls = re.findall(url_re, text)\n",
"    for url in detected_urls:\n",
"        text = text.replace(url, \"urlplaceholder\")\n",
"    \n",
"    # split text into word tokens\n",
"    tokens = word_tokenize(text)\n",
"    lemmatizer = WordNetLemmatizer()\n",
"    \n",
"    # lemmatize, lowercase and strip each token\n",
"    clean_tokens = [lemmatizer.lemmatize(tok).lower().strip() for tok in tokens]\n",
"    \n",
"    # remove stopwords (a set gives constant-time membership tests)\n",
"    STOPWORDS = set(stopwords.words('english'))\n",
"    clean_tokens = [token for token in clean_tokens if token not in STOPWORDS]\n",
"    \n",
"    return clean_tokens" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Build a machine learning pipeline\n", "- You'll find the [MultiOutputClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.multioutput.MultiOutputClassifier.html) helpful for predicting multiple target variables." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [
"def build_pipeline():\n",
"    \n",
"    # build NLP pipeline - count words, tf-idf, multiple output classifier\n",
"    pipeline = Pipeline([\n",
"        ('vec', CountVectorizer(tokenizer=tokenize)),\n",
"        ('tfidf', TfidfTransformer()),\n",
"        ('clf', MultiOutputClassifier(RandomForestClassifier(n_estimators = 100, n_jobs = 6)))\n",
"    ])\n",
"    \n",
"    return pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 
Train pipeline\n", "- Split data into train and test sets\n", "- Train pipeline" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [
"Pipeline(memory=None,\n",
" steps=[('vec', CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
" dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',\n",
" lowercase=True, max_df=1.0, max_features=None, min_df=1,\n",
" ngram_range=(1, 1), preprocessor=None, stop_words=None,\n",
" strip_..._score=False, random_state=None, verbose=0,\n",
" warm_start=False),\n",
" n_jobs=None))])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
"X_train, X_test, y_train, y_test = train_test_split(X, Y)\n",
"pipeline = build_pipeline()\n",
"pipeline.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Test your model\n", "Report the f1 score, precision and recall for each output category of the dataset. You can do this by iterating through the columns and calling sklearn's `classification_report` on each." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [
"def build_report(pipeline, X_test, y_test):\n",
"    # predict on the X_test\n",
"    y_pred = pipeline.predict(X_test)\n",
"    \n",
"    # score every target column separately\n",
"    # NOTE: with average='micro' on a single-label column, f1, precision and\n",
"    # recall all equal accuracy, which is why the three columns in the report match\n",
"    performances = []\n",
"    for i in range(len(y_test.columns)):\n",
"        performances.append([f1_score(y_test.iloc[:, i].values, y_pred[:, i], average='micro'),\n",
"                             precision_score(y_test.iloc[:, i].values, y_pred[:, i], average='micro'),\n",
"                             recall_score(y_test.iloc[:, i].values, y_pred[:, i], average='micro')])\n",
"    # build dataframe\n",
"    performances = pd.DataFrame(performances, columns=['f1 score', 'precision', 'recall'],\n",
"                                index = y_test.columns)\n",
"    return performances" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>f1 score</th><th>precision</th><th>recall</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>related</th><td>0.801648</td><td>0.801648</td><td>0.801648</td></tr>\n",
"    <tr><th>request</th><td>0.894873</td><td>0.894873</td><td>0.894873</td></tr>\n",
"    <tr><th>offer</th><td>0.995880</td><td>0.995880</td><td>0.995880</td></tr>\n",
"    <tr><th>aid_related</th><td>0.777388</td><td>0.777388</td><td>0.777388</td></tr>\n",
"    <tr><th>medical_help</th><td>0.920507</td><td>0.920507</td><td>0.920507</td></tr>\n",
"    <tr><th>medical_products</th><td>0.956057</td><td>0.956057</td><td>0.956057</td></tr>\n",
"    <tr><th>search_and_rescue</th><td>0.971773</td><td>0.971773</td><td>0.971773</td></tr>\n",
"    <tr><th>security</th><td>0.982301</td><td>0.982301</td><td>0.982301</td></tr>\n",
"    <tr><th>military</th><td>0.968111</td><td>0.968111</td><td>0.968111</td></tr>\n",
"    <tr><th>child_alone</th><td>1.000000</td><td>1.000000</td><td>1.000000</td></tr>\n",
"    <tr><th>water</th><td>0.958193</td><td>0.958193</td><td>0.958193</td></tr>\n",
"    <tr><th>food</th><td>0.940189</td><td>0.940189</td><td>0.940189</td></tr>\n",
"    <tr><th>shelter</th><td>0.935002</td><td>0.935002</td><td>0.935002</td></tr>\n",
"    <tr><th>clothing</th><td>0.986726</td><td>0.986726</td><td>0.986726</td></tr>\n",
"    <tr><th>money</th><td>0.978639</td><td>0.978639</td><td>0.978639</td></tr>\n",
"    <tr><th>missing_people</th><td>0.989930</td><td>0.989930</td><td>0.989930</td></tr>\n",
"    <tr><th>refugees</th><td>0.968721</td><td>0.968721</td><td>0.968721</td></tr>\n",
"    <tr><th>death</th><td>0.962161</td><td>0.962161</td><td>0.962161</td></tr>\n",
"    <tr><th>other_aid</th><td>0.871224</td><td>0.871224</td><td>0.871224</td></tr>\n",
"    <tr><th>infrastructure_related</th><td>0.934696</td><td>0.934696</td><td>0.934696</td></tr>\n",
"    <tr><th>transport</th><td>0.953769</td><td>0.953769</td><td>0.953769</td></tr>\n",
"    <tr><th>buildings</th><td>0.951785</td><td>0.951785</td><td>0.951785</td></tr>\n",
"    <tr><th>electricity</th><td>0.981233</td><td>0.981233</td><td>0.981233</td></tr>\n",
"    <tr><th>tools</th><td>0.994812</td><td>0.994812</td><td>0.994812</td></tr>\n",
"    <tr><th>hospitals</th><td>0.990540</td><td>0.990540</td><td>0.990540</td></tr>\n",
"    <tr><th>shops</th><td>0.995270</td><td>0.995270</td><td>0.995270</td></tr>\n",
"    <tr><th>aid_centers</th><td>0.987946</td><td>0.987946</td><td>0.987946</td></tr>\n",
"    <tr><th>other_infrastructure</th><td>0.955600</td><td>0.955600</td><td>0.955600</td></tr>\n",
"    <tr><th>weather_related</th><td>0.878700</td><td>0.878700</td><td>0.878700</td></tr>\n",
"    <tr><th>floods</th><td>0.949954</td><td>0.949954</td><td>0.949954</td></tr>\n",
"    <tr><th>storm</th><td>0.937290</td><td>0.937290</td><td>0.937290</td></tr>\n",
"    <tr><th>fire</th><td>0.991456</td><td>0.991456</td><td>0.991456</td></tr>\n",
"    <tr><th>earthquake</th><td>0.971010</td><td>0.971010</td><td>0.971010</td></tr>\n",
"    <tr><th>cold</th><td>0.980623</td><td>0.980623</td><td>0.980623</td></tr>\n",
"    <tr><th>other_weather</th><td>0.946903</td><td>0.946903</td><td>0.946903</td></tr>\n",
"    <tr><th>direct_report</th><td>0.867562</td><td>0.867562</td><td>0.867562</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " f1 score precision recall\n", "related 0.801648 0.801648 0.801648\n", "request 0.894873 0.894873 0.894873\n", "offer 0.995880 0.995880 0.995880\n", "aid_related 0.777388 0.777388 0.777388\n", "medical_help 0.920507 0.920507 0.920507\n", "medical_products 0.956057 0.956057 0.956057\n", "search_and_rescue 0.971773 0.971773 0.971773\n", "security 0.982301 0.982301 0.982301\n", "military 0.968111 0.968111 0.968111\n", "child_alone 1.000000 1.000000 1.000000\n", "water 0.958193 0.958193 0.958193\n", "food 0.940189 0.940189 0.940189\n", "shelter 0.935002 0.935002 0.935002\n", "clothing 0.986726 0.986726 0.986726\n", "money 0.978639 0.978639 0.978639\n", "missing_people 0.989930 0.989930 0.989930\n", "refugees 0.968721 0.968721 0.968721\n", "death 0.962161 0.962161 0.962161\n", "other_aid 0.871224 0.871224 0.871224\n", "infrastructure_related 0.934696 0.934696 0.934696\n", "transport 0.953769 0.953769 0.953769\n", "buildings 0.951785 0.951785 0.951785\n", "electricity 0.981233 0.981233 0.981233\n", "tools 0.994812 0.994812 0.994812\n", "hospitals 0.990540 0.990540 0.990540\n", "shops 0.995270 0.995270 0.995270\n", "aid_centers 0.987946 0.987946 0.987946\n", "other_infrastructure 0.955600 0.955600 0.955600\n", "weather_related 0.878700 0.878700 0.878700\n", "floods 0.949954 0.949954 0.949954\n", "storm 0.937290 0.937290 0.937290\n", "fire 0.991456 0.991456 0.991456\n", "earthquake 0.971010 0.971010 0.971010\n", "cold 0.980623 0.980623 0.980623\n", "other_weather 0.946903 0.946903 0.946903\n", "direct_report 0.867562 0.867562 0.867562" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "build_report(pipeline, X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. Improve your model\n", "Use grid search to find better parameters. 
" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=5, error_score='raise-deprecating',\n", " estimator=Pipeline(memory=None,\n", " steps=[('vec', CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n", " dtype=, encoding='utf-8', input='content',\n", " lowercase=True, max_df=1.0, max_features=None, min_df=1,\n", " ngram_range=(1, 1), preprocessor=None, stop_words=None,\n", " strip_..._score=False, random_state=None, verbose=0,\n", " warm_start=False),\n", " n_jobs=None))]),\n", " fit_params=None, iid='warn', n_jobs=6,\n", " param_grid={'clf__estimator__max_features': ['sqrt', 0.5], 'clf__estimator__n_estimators': [50, 100]},\n", " pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n", " scoring=None, verbose=0)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "parameters = {'clf__estimator__max_features':['sqrt', 0.5],\n", " 'clf__estimator__n_estimators':[50, 100]}\n", "\n", "cv = GridSearchCV(estimator=pipeline, param_grid = parameters, cv = 5, n_jobs = 6)\n", "cv.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. Test your model\n", "Show the accuracy, precision, and recall of the tuned model. \n", "\n", "Since this project focuses on code quality, process, and pipelines, there is no minimum performance metric needed to pass. However, make sure to fine tune your models for accuracy, precision and recall to make your project stand out - especially for your portfolio!" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>f1 score</th><th>precision</th><th>recall</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>related</th><td>0.801953</td><td>0.801953</td><td>0.801953</td></tr>\n",
"    <tr><th>request</th><td>0.888007</td><td>0.888007</td><td>0.888007</td></tr>\n",
"    <tr><th>offer</th><td>0.995270</td><td>0.995270</td><td>0.995270</td></tr>\n",
"    <tr><th>aid_related</th><td>0.765334</td><td>0.765334</td><td>0.765334</td></tr>\n",
"    <tr><th>medical_help</th><td>0.920659</td><td>0.920659</td><td>0.920659</td></tr>\n",
"    <tr><th>medical_products</th><td>0.962313</td><td>0.962313</td><td>0.962313</td></tr>\n",
"    <tr><th>search_and_rescue</th><td>0.970552</td><td>0.970552</td><td>0.970552</td></tr>\n",
"    <tr><th>security</th><td>0.978944</td><td>0.978944</td><td>0.978944</td></tr>\n",
"    <tr><th>military</th><td>0.966890</td><td>0.966890</td><td>0.966890</td></tr>\n",
"    <tr><th>child_alone</th><td>1.000000</td><td>1.000000</td><td>1.000000</td></tr>\n",
"    <tr><th>water</th><td>0.966280</td><td>0.966280</td><td>0.966280</td></tr>\n",
"    <tr><th>food</th><td>0.951480</td><td>0.951480</td><td>0.951480</td></tr>\n",
"    <tr><th>shelter</th><td>0.948581</td><td>0.948581</td><td>0.948581</td></tr>\n",
"    <tr><th>clothing</th><td>0.989014</td><td>0.989014</td><td>0.989014</td></tr>\n",
"    <tr><th>money</th><td>0.978486</td><td>0.978486</td><td>0.978486</td></tr>\n",
"    <tr><th>missing_people</th><td>0.990388</td><td>0.990388</td><td>0.990388</td></tr>\n",
"    <tr><th>refugees</th><td>0.971620</td><td>0.971620</td><td>0.971620</td></tr>\n",
"    <tr><th>death</th><td>0.973604</td><td>0.973604</td><td>0.973604</td></tr>\n",
"    <tr><th>other_aid</th><td>0.868630</td><td>0.868630</td><td>0.868630</td></tr>\n",
"    <tr><th>infrastructure_related</th><td>0.929661</td><td>0.929661</td><td>0.929661</td></tr>\n",
"    <tr><th>transport</th><td>0.954379</td><td>0.954379</td><td>0.954379</td></tr>\n",
"    <tr><th>buildings</th><td>0.956363</td><td>0.956363</td><td>0.956363</td></tr>\n",
"    <tr><th>electricity</th><td>0.979707</td><td>0.979707</td><td>0.979707</td></tr>\n",
"    <tr><th>tools</th><td>0.993592</td><td>0.993592</td><td>0.993592</td></tr>\n",
"    <tr><th>hospitals</th><td>0.988404</td><td>0.988404</td><td>0.988404</td></tr>\n",
"    <tr><th>shops</th><td>0.994355</td><td>0.994355</td><td>0.994355</td></tr>\n",
"    <tr><th>aid_centers</th><td>0.987794</td><td>0.987794</td><td>0.987794</td></tr>\n",
"    <tr><th>other_infrastructure</th><td>0.952395</td><td>0.952395</td><td>0.952395</td></tr>\n",
"    <tr><th>weather_related</th><td>0.880684</td><td>0.880684</td><td>0.880684</td></tr>\n",
"    <tr><th>floods</th><td>0.955752</td><td>0.955752</td><td>0.955752</td></tr>\n",
"    <tr><th>storm</th><td>0.945682</td><td>0.945682</td><td>0.945682</td></tr>\n",
"    <tr><th>fire</th><td>0.991913</td><td>0.991913</td><td>0.991913</td></tr>\n",
"    <tr><th>earthquake</th><td>0.973146</td><td>0.973146</td><td>0.973146</td></tr>\n",
"    <tr><th>cold</th><td>0.983064</td><td>0.983064</td><td>0.983064</td></tr>\n",
"    <tr><th>other_weather</th><td>0.941105</td><td>0.941105</td><td>0.941105</td></tr>\n",
"    <tr><th>direct_report</th><td>0.856424</td><td>0.856424</td><td>0.856424</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " f1 score precision recall\n", "related 0.801953 0.801953 0.801953\n", "request 0.888007 0.888007 0.888007\n", "offer 0.995270 0.995270 0.995270\n", "aid_related 0.765334 0.765334 0.765334\n", "medical_help 0.920659 0.920659 0.920659\n", "medical_products 0.962313 0.962313 0.962313\n", "search_and_rescue 0.970552 0.970552 0.970552\n", "security 0.978944 0.978944 0.978944\n", "military 0.966890 0.966890 0.966890\n", "child_alone 1.000000 1.000000 1.000000\n", "water 0.966280 0.966280 0.966280\n", "food 0.951480 0.951480 0.951480\n", "shelter 0.948581 0.948581 0.948581\n", "clothing 0.989014 0.989014 0.989014\n", "money 0.978486 0.978486 0.978486\n", "missing_people 0.990388 0.990388 0.990388\n", "refugees 0.971620 0.971620 0.971620\n", "death 0.973604 0.973604 0.973604\n", "other_aid 0.868630 0.868630 0.868630\n", "infrastructure_related 0.929661 0.929661 0.929661\n", "transport 0.954379 0.954379 0.954379\n", "buildings 0.956363 0.956363 0.956363\n", "electricity 0.979707 0.979707 0.979707\n", "tools 0.993592 0.993592 0.993592\n", "hospitals 0.988404 0.988404 0.988404\n", "shops 0.994355 0.994355 0.994355\n", "aid_centers 0.987794 0.987794 0.987794\n", "other_infrastructure 0.952395 0.952395 0.952395\n", "weather_related 0.880684 0.880684 0.880684\n", "floods 0.955752 0.955752 0.955752\n", "storm 0.945682 0.945682 0.945682\n", "fire 0.991913 0.991913 0.991913\n", "earthquake 0.973146 0.973146 0.973146\n", "cold 0.983064 0.983064 0.983064\n", "other_weather 0.941105 0.941105 0.941105\n", "direct_report 0.856424 0.856424 0.856424" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "build_report(cv, X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'clf__estimator__max_features': 0.5, 'clf__estimator__n_estimators': 100}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv.best_params_" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8. Try improving your model further. Here are a few ideas:\n", "* try other machine learning algorithms\n", "* add other features besides the TF-IDF" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>f1 score</th><th>precision</th><th>recall</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>related</th><td>0.762893</td><td>0.762893</td><td>0.762893</td></tr>\n",
"    <tr><th>request</th><td>0.892127</td><td>0.892127</td><td>0.892127</td></tr>\n",
"    <tr><th>offer</th><td>0.994049</td><td>0.994049</td><td>0.994049</td></tr>\n",
"    <tr><th>aid_related</th><td>0.767318</td><td>0.767318</td><td>0.767318</td></tr>\n",
"    <tr><th>medical_help</th><td>0.923711</td><td>0.923711</td><td>0.923711</td></tr>\n",
"    <tr><th>medical_products</th><td>0.961398</td><td>0.961398</td><td>0.961398</td></tr>\n",
"    <tr><th>search_and_rescue</th><td>0.970705</td><td>0.970705</td><td>0.970705</td></tr>\n",
"    <tr><th>security</th><td>0.977876</td><td>0.977876</td><td>0.977876</td></tr>\n",
"    <tr><th>military</th><td>0.971468</td><td>0.971468</td><td>0.971468</td></tr>\n",
"    <tr><th>child_alone</th><td>1.000000</td><td>1.000000</td><td>1.000000</td></tr>\n",
"    <tr><th>water</th><td>0.963381</td><td>0.963381</td><td>0.963381</td></tr>\n",
"    <tr><th>food</th><td>0.946903</td><td>0.946903</td><td>0.946903</td></tr>\n",
"    <tr><th>shelter</th><td>0.941868</td><td>0.941868</td><td>0.941868</td></tr>\n",
"    <tr><th>clothing</th><td>0.987946</td><td>0.987946</td><td>0.987946</td></tr>\n",
"    <tr><th>money</th><td>0.977418</td><td>0.977418</td><td>0.977418</td></tr>\n",
"    <tr><th>missing_people</th><td>0.989319</td><td>0.989319</td><td>0.989319</td></tr>\n",
"    <tr><th>refugees</th><td>0.969484</td><td>0.969484</td><td>0.969484</td></tr>\n",
"    <tr><th>death</th><td>0.968721</td><td>0.968721</td><td>0.968721</td></tr>\n",
"    <tr><th>other_aid</th><td>0.868782</td><td>0.868782</td><td>0.868782</td></tr>\n",
"    <tr><th>infrastructure_related</th><td>0.928746</td><td>0.928746</td><td>0.928746</td></tr>\n",
"    <tr><th>transport</th><td>0.955447</td><td>0.955447</td><td>0.955447</td></tr>\n",
"    <tr><th>buildings</th><td>0.954837</td><td>0.954837</td><td>0.954837</td></tr>\n",
"    <tr><th>electricity</th><td>0.980928</td><td>0.980928</td><td>0.980928</td></tr>\n",
"    <tr><th>tools</th><td>0.993592</td><td>0.993592</td><td>0.993592</td></tr>\n",
"    <tr><th>hospitals</th><td>0.987489</td><td>0.987489</td><td>0.987489</td></tr>\n",
"    <tr><th>shops</th><td>0.994049</td><td>0.994049</td><td>0.994049</td></tr>\n",
"    <tr><th>aid_centers</th><td>0.986726</td><td>0.986726</td><td>0.986726</td></tr>\n",
"    <tr><th>other_infrastructure</th><td>0.951938</td><td>0.951938</td><td>0.951938</td></tr>\n",
"    <tr><th>weather_related</th><td>0.876259</td><td>0.876259</td><td>0.876259</td></tr>\n",
"    <tr><th>floods</th><td>0.953616</td><td>0.953616</td><td>0.953616</td></tr>\n",
"    <tr><th>storm</th><td>0.938969</td><td>0.938969</td><td>0.938969</td></tr>\n",
"    <tr><th>fire</th><td>0.991150</td><td>0.991150</td><td>0.991150</td></tr>\n",
"    <tr><th>earthquake</th><td>0.970705</td><td>0.970705</td><td>0.970705</td></tr>\n",
"    <tr><th>cold</th><td>0.981843</td><td>0.981843</td><td>0.981843</td></tr>\n",
"    <tr><th>other_weather</th><td>0.942630</td><td>0.942630</td><td>0.942630</td></tr>\n",
"    <tr><th>direct_report</th><td>0.858865</td><td>0.858865</td><td>0.858865</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " f1 score precision recall\n", "related 0.762893 0.762893 0.762893\n", "request 0.892127 0.892127 0.892127\n", "offer 0.994049 0.994049 0.994049\n", "aid_related 0.767318 0.767318 0.767318\n", "medical_help 0.923711 0.923711 0.923711\n", "medical_products 0.961398 0.961398 0.961398\n", "search_and_rescue 0.970705 0.970705 0.970705\n", "security 0.977876 0.977876 0.977876\n", "military 0.971468 0.971468 0.971468\n", "child_alone 1.000000 1.000000 1.000000\n", "water 0.963381 0.963381 0.963381\n", "food 0.946903 0.946903 0.946903\n", "shelter 0.941868 0.941868 0.941868\n", "clothing 0.987946 0.987946 0.987946\n", "money 0.977418 0.977418 0.977418\n", "missing_people 0.989319 0.989319 0.989319\n", "refugees 0.969484 0.969484 0.969484\n", "death 0.968721 0.968721 0.968721\n", "other_aid 0.868782 0.868782 0.868782\n", "infrastructure_related 0.928746 0.928746 0.928746\n", "transport 0.955447 0.955447 0.955447\n", "buildings 0.954837 0.954837 0.954837\n", "electricity 0.980928 0.980928 0.980928\n", "tools 0.993592 0.993592 0.993592\n", "hospitals 0.987489 0.987489 0.987489\n", "shops 0.994049 0.994049 0.994049\n", "aid_centers 0.986726 0.986726 0.986726\n", "other_infrastructure 0.951938 0.951938 0.951938\n", "weather_related 0.876259 0.876259 0.876259\n", "floods 0.953616 0.953616 0.953616\n", "storm 0.938969 0.938969 0.938969\n", "fire 0.991150 0.991150 0.991150\n", "earthquake 0.970705 0.970705 0.970705\n", "cold 0.981843 0.981843 0.981843\n", "other_weather 0.942630 0.942630 0.942630\n", "direct_report 0.858865 0.858865 0.858865" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline_improved = Pipeline([\n", " ('vect', CountVectorizer(tokenizer=tokenize)),\n", " ('tfidf', TfidfTransformer()),\n", " ('clf', MultiOutputClassifier(AdaBoostClassifier(n_estimators = 100)))\n", " ])\n", "pipeline_improved.fit(X_train, y_train)\n", "y_pred_improved = pipeline_improved.predict(X_test)\n", 
"build_report(pipeline_improved, X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 9. Export your model as a pickle file" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "pickle.dump(pipeline, open('rf_model.pkl', 'wb'))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "pickle.dump(pipeline_improved, open('adaboost_model.pkl', 'wb'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 10. Use this notebook to complete `train.py`\n", "Use the template file attached in the Resources folder to write a script that runs the steps above to create a database and export a model based on a new dataset specified by the user." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }