{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 3.17 Predicting House Prices on Kaggle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.1 Kaggle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.2 Accessing and Reading Data Sets" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import gluonbook as gb\n", "from mxnet import autograd, gluon, init, nd\n", "import mxnet as mx\n", "from mxnet.gluon import data as gdata, loss as gloss\n", "from mxnet.gluon import nn\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Weight initialization, dropout and minibatch shuffling are all random, so seed both\n", "# the MXNet and the NumPy generators to make repeated runs comparable.\n", "SEED = 42\n", "mx.random.seed(SEED)\n", "np.random.seed(SEED)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "train_data = pd.read_csv('./kaggle_house_dataset/train.csv')\n", "test_data = pd.read_csv('./kaggle_house_dataset/test.csv')" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1460, 81)\n", "(1459, 80)\n" ] } ], "source": [ "print(train_data.shape)\n", "print(test_data.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Label: SalePrice (the last column of the training set; the test set does not contain it)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IdMSSubClassMSZoningLotFrontageSaleTypeSaleConditionSalePrice
0160RL65.0WDNormal208500
1220RL80.0WDNormal181500
2360RL68.0WDNormal223500
3470RL60.0WDAbnorml140000
\n", "
" ], "text/plain": [ " Id MSSubClass MSZoning LotFrontage SaleType SaleCondition SalePrice\n", "0 1 60 RL 65.0 WD Normal 208500\n", "1 2 20 RL 80.0 WD Normal 181500\n", "2 3 60 RL 68.0 WD Normal 223500\n", "3 4 70 RL 60.0 WD Abnorml 140000" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Peek at the first four rows: the first four and the last three columns (SalePrice is last).\n", "train_data.iloc[:4, [0, 1, 2, 3, -3, -2, -1]]" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IdMSSubClassMSZoningLotFrontageYrSoldSaleTypeSaleCondition
0146120RH80.02010WDNormal
1146220RL81.02010WDNormal
2146360RL74.02010WDNormal
3146460RL78.02010WDNormal
\n", "
" ], "text/plain": [ " Id MSSubClass MSZoning LotFrontage YrSold SaleType SaleCondition\n", "0 1461 20 RH 80.0 2010 WD Normal\n", "1 1462 20 RL 81.0 2010 WD Normal\n", "2 1463 60 RL 74.0 2010 WD Normal\n", "3 1464 60 RL 78.0 2010 WD Normal" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Same peek for the test set, which has no SalePrice column.\n", "test_data.iloc[:4, [0, 1, 2, 3, -3, -2, -1]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We remove the 'Id' column from both datasets (and the 'SalePrice' label from the training set) before feeding the data into the network. Stacking the training and test features lets us preprocess them identically." ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "all_features = pd.concat([\n", "    train_data.iloc[:, 1:-1],  # drop 'Id' (first column) and 'SalePrice' (last column)\n", "    test_data.iloc[:, 1:],     # drop 'Id'\n", "])" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
MSSubClassMSZoningLotFrontageLotAreaYrSoldSaleTypeSaleCondition
060RL65.084502008WDNormal
120RL80.096002007WDNormal
260RL68.0112502008WDNormal
370RL60.095502006WDAbnorml
\n", "
" ], "text/plain": [ " MSSubClass MSZoning LotFrontage LotArea YrSold SaleType SaleCondition\n", "0 60 RL 65.0 8450 2008 WD Normal\n", "1 20 RL 80.0 9600 2007 WD Normal\n", "2 60 RL 68.0 11250 2008 WD Normal\n", "3 70 RL 60.0 9550 2006 WD Abnorml" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 'Id' and 'SalePrice' are gone: the combined frame starts at MSSubClass.\n", "all_features.iloc[:4, [0, 1, 2, 3, -3, -2, -1]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.3 Data Preprocessing" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MSSubClass int64\n", "MSZoning object\n", "LotFrontage float64\n", "LotArea int64\n", "Street object\n", "Alley object\n", "LotShape object\n", "LandContour object\n", "Utilities object\n", "LotConfig object\n", "LandSlope object\n", "Neighborhood object\n", "Condition1 object\n", "Condition2 object\n", "BldgType object\n", "HouseStyle object\n", "OverallQual int64\n", "OverallCond int64\n", "YearBuilt int64\n", "YearRemodAdd int64\n", "RoofStyle object\n", "RoofMatl object\n", "Exterior1st object\n", "Exterior2nd object\n", "MasVnrType object\n", "MasVnrArea float64\n", "ExterQual object\n", "ExterCond object\n", "Foundation object\n", "BsmtQual object\n", " ... 
\n", "HalfBath int64\n", "BedroomAbvGr int64\n", "KitchenAbvGr int64\n", "KitchenQual object\n", "TotRmsAbvGrd int64\n", "Functional object\n", "Fireplaces int64\n", "FireplaceQu object\n", "GarageType object\n", "GarageYrBlt float64\n", "GarageFinish object\n", "GarageCars float64\n", "GarageArea float64\n", "GarageQual object\n", "GarageCond object\n", "PavedDrive object\n", "WoodDeckSF int64\n", "OpenPorchSF int64\n", "EnclosedPorch int64\n", "3SsnPorch int64\n", "ScreenPorch int64\n", "PoolArea int64\n", "PoolQC object\n", "Fence object\n", "MiscFeature object\n", "MiscVal int64\n", "MoSold int64\n", "YrSold int64\n", "SaleType object\n", "SaleCondition object\n", "Length: 79, dtype: object" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_features.dtypes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Numerical features\n", "  - 1) standardize each feature to zero mean and unit variance\n", "  - 2) fill the missing values with zero, i.e. with the feature mean after standardization" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index\n", "all_features[numeric_features] = all_features[numeric_features].apply(\n", "    lambda x: (x - x.mean()) / (x.std()))\n", "# after standardizing the data all means vanish, hence we can set missing values to 0.\n", "# Fill ONLY the numeric columns: missing categorical values must stay NaN so that\n", "# pd.get_dummies(..., dummy_na=True) below can turn them into an indicator feature.\n", "all_features[numeric_features] = all_features[numeric_features].fillna(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Discrete features\n", "  - transform each feature into one-hot vector\n", "  - this conversion increases the number of features from 79 to several hundred (see the shape printed below). 
" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2919, 354)" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# dummy_na=True treats a missing value as a category of its own and creates an indicator feature for it.\n", "all_features = pd.get_dummies(all_features, dummy_na=True)\n", "all_features.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We extract the underlying NumPy arrays from the pandas DataFrame via `.values`\n", "- Then, we convert them into MXNet’s native representation, NDArray, for training." ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "n_train = train_data.shape[0]\n", "# The dtype of the dummy columns may differ between pandas versions (possibly bool), which\n", "# could make `.values` an object array; cast explicitly so nd.array always gets float32 data.\n", "train_features = nd.array(all_features[:n_train].values.astype('float32'))\n", "test_features = nd.array(all_features[n_train:].values.astype('float32'))\n", "train_labels = nd.array(train_data.SalePrice.values).reshape((-1, 1))" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1460, 354)\n", "(1460, 1)\n", "(1459, 354)\n" ] } ], "source": [ "print(train_features.shape)\n", "print(train_labels.shape)\n", "print(test_features.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.4 Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We train with squared loss (`L2Loss`); `get_net` builds either a linear model or a small MLP with dropout." ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "loss = gloss.L2Loss()\n", "\n", "def get_net(model_type='linear'):\n", "    \"\"\"Build and initialize the regression network.\n", "\n", "    model_type='linear': a single Dense(1) layer with default initialization.\n", "    model_type='deep':   Dense(128, relu) -> Dropout(0.5) -> Dense(48, relu) -> Dropout(0.25) -> Dense(1),\n", "                         initialized with Xavier.\n", "    Any other value raises ValueError instead of silently building the MLP.\n", "    \"\"\"\n", "    net = nn.Sequential()\n", "    if model_type == 'linear':\n", "        net.add(nn.Dense(1))\n", "        net.initialize()\n", "    elif model_type == 'deep':\n", "        net.add(\n", "            nn.Dense(128, activation=\"relu\"),\n", "            nn.Dropout(0.5),\n", "            nn.Dense(48, activation=\"relu\"),\n", "            nn.Dropout(0.25),\n", "            nn.Dense(1)\n", "        )\n", "        net.initialize(mx.init.Xavier())\n", "    else:\n", "        raise ValueError(\"model_type must be 'linear' or 'deep', got %r\" % (model_type,))\n", "    return net" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "- Getting a house price wrong by USD 100,000 is terrible in rural Ohio,\n", "  - where a typical house is worth about USD 125,000.\n", "- The same error in Los Altos Hills, California, would be a remarkably accurate prediction,\n", "  - since the median house price there exceeds USD 4,000,000.\n", "- One way to address this problem is to measure the discrepancy in the logarithm of the price estimates.\n", "- $|\\log y - \\log \\hat{y}| \\leq \\delta$ translates into $e^{-\\delta} \\leq \\frac{\\hat{y}}{y} \\leq e^\\delta$, i.e. a bound on the *relative* error.\n", "- This leads to the following loss function: $$L = \\sqrt{\\frac{1}{n}\\sum_{i=1}^n\\left(\\log y_i -\\log \\hat{y}_i\\right)^2}$$" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "def log_rmse(net, features, labels):\n", "    \"\"\"Return the RMSE between log(prediction) and log(label) as a Python scalar.\"\"\"\n", "    preds = net(features)\n", "    # To further stabilize the value when the logarithm is taken, predictions below 1 are raised to 1.\n", "    preds = nd.clip(data=preds, a_min=1, a_max=float('inf'))\n", "    # `loss` is the global L2Loss (half the squared error), so twice its mean is the MSE.\n", "    mean_sq_log_err = 2 * loss(preds.log(), labels.log()).mean()\n", "    return nd.sqrt(mean_sq_log_err).asscalar()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We use the Adam optimizer, which is relatively insensitive to the choice of learning rate." 
] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "def train(net, train_features, train_labels, test_features, test_labels,\n", "          num_epochs, learning_rate, weight_decay, batch_size):\n", "    \"\"\"Train `net` with Adam on the global squared `loss`, tracking log-RMSE per epoch.\n", "\n", "    Returns (train_ls, test_ls): the log-RMSE on the training data after every epoch and,\n", "    when test_labels is not None, on the held-out data (otherwise test_ls stays empty).\n", "    \"\"\"\n", "    train_ls, test_ls = [], []\n", "    train_iter = gdata.DataLoader(\n", "        dataset=gdata.ArrayDataset(train_features, train_labels),\n", "        batch_size=batch_size,\n", "        shuffle=True\n", "    )\n", "\n", "    # The Adam optimization algorithm is used here; 'wd' is the weight-decay coefficient.\n", "    trainer = gluon.Trainer(\n", "        params=net.collect_params(),\n", "        optimizer='adam',\n", "        optimizer_params={'learning_rate': learning_rate, 'wd': weight_decay}\n", "    )\n", "\n", "    for _ in range(num_epochs):\n", "        for X, y in train_iter:\n", "            with autograd.record():\n", "                l = loss(net(X), y)\n", "            l.backward()\n", "            # Normalize by the ACTUAL minibatch size: the last batch of an epoch can be smaller\n", "            # than batch_size, and dividing by batch_size would shrink its gradient.\n", "            trainer.step(X.shape[0])\n", "        train_ls.append(log_rmse(net, train_features, train_labels))\n", "        if test_labels is not None:\n", "            test_ls.append(log_rmse(net, test_features, test_labels))\n", "    return train_ls, test_ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.5 k-Fold Cross-Validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We first need a function that returns the i-th fold of the data in a k-fold cross-validation procedure.\n", "  - It proceeds by slicing out the i-th segment as validation data and returning the rest as training data. 
" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [], "source": [ "def get_k_fold_data(k, i, X, y):\n", "    \"\"\"Split (X, y) into k contiguous folds and hold out fold i for validation.\n", "\n", "    Returns (X_train, y_train, X_valid, y_valid); the training part is the concatenation\n", "    of the other k - 1 folds. When the number of rows is not divisible by k, the leftover\n", "    rows are appended to the last fold instead of being silently dropped.\n", "    \"\"\"\n", "    assert k > 1\n", "    if not 0 <= i < k:\n", "        raise ValueError('fold index i must be in [0, %d), got %r' % (k, i))\n", "    n = X.shape[0]\n", "    fold_size = n // k  # floor division\n", "    X_train, y_train = None, None\n", "    for j in range(k):\n", "        # The last fold also takes the n % k remainder rows.\n", "        stop = n if j == k - 1 else (j + 1) * fold_size\n", "        idx = slice(j * fold_size, stop)\n", "        X_part, y_part = X[idx, :], y[idx]\n", "        if j == i:\n", "            X_valid, y_valid = X_part, y_part\n", "        elif X_train is None:\n", "            X_train, y_train = X_part, y_part\n", "        else:\n", "            X_train = nd.concat(X_train, X_part, dim=0)\n", "            y_train = nd.concat(y_train, y_part, dim=0)\n", "    return X_train, y_train, X_valid, y_valid" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- `k_fold` trains $k$ times, once per fold, and returns the training and validation errors averaged over the $k$ folds." ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay, batch_size,\n", "           model_type='deep'):\n", "    \"\"\"Run k-fold cross-validation; return the average final (train, valid) log-RMSE.\n", "\n", "    A fresh network from get_net(model_type) is trained for every fold. The learning\n", "    curves of the first fold are plotted and each fold's final errors are printed.\n", "    \"\"\"\n", "    train_l_sum, valid_l_sum = 0, 0\n", "    for i in range(k):\n", "        data = get_k_fold_data(k, i, X_train, y_train)\n", "        net = get_net(model_type=model_type)\n", "        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate, weight_decay, batch_size)\n", "        train_l_sum += train_ls[-1]\n", "        valid_l_sum += valid_ls[-1]\n", "        if i == 0:\n", "            gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse',\n", "                        range(1, num_epochs + 1), valid_ls,\n", "                        ['train', 'valid'])\n", "        print('fold %d, train rmse: %f, valid rmse: %f' % (i, train_ls[-1], valid_ls[-1]))\n", "    return train_l_sum / k, valid_l_sum / k" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.6 Model Selection\n", "- We pick a rather untuned set of hyperparameters and leave it up to the reader to improve the model considerably." 
] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "fold 0, train rmse: 0.148676, valid rmse: 0.193381\n", "fold 1, train rmse: 0.198471, valid rmse: 0.242408\n", "fold 2, train rmse: 0.257584, valid rmse: 0.299181\n", "fold 3, train rmse: 0.274117, valid rmse: 0.271172\n", "fold 4, train rmse: 0.215440, valid rmse: 0.250864\n", "5-fold validation: avg train rmse: 0.218858, avg valid rmse: 0.251401\n" ] } ], "source": [ "k = 5\n", "num_epochs = 100\n", "lr = 0.5\n", "weight_decay = 5\n", "batch_size = 64\n", "\n", "train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr, weight_decay, batch_size)\n", "print('%d-fold validation: avg train rmse: %f, avg valid rmse: %f' % (k, train_l, valid_l))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.17.7 Predict and Submit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- With the hyperparameters chosen above, we retrain on all of the training data and then predict the test set." ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "def train_and_pred(train_features, test_feature, train_labels, test_data,\n", "                   num_epochs, lr, weight_decay, batch_size):\n", " \"\"\"Train a fresh 'deep' network on all training data and export test-set predictions.\n", "\n", " test_feature: preprocessed features of the test set (one row per row of test_data).\n", " test_data: raw test DataFrame; it must contain 'Id'. A 'SalePrice' column holding the\n", " predictions is added to it, and ('Id', 'SalePrice') is written out as a CSV submission.\n", " \"\"\"\n", " net = get_net(model_type=\"deep\")\n", " train_ls, _ = train(\n", "     net=net,\n", "     train_features=train_features,\n", "     train_labels=train_labels,\n", "     test_features=None,\n", "     test_labels=None,\n", "     num_epochs=num_epochs,\n", "     learning_rate=lr,\n", "     weight_decay=weight_decay,\n", "     batch_size=batch_size\n", " )\n", " gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'rmse')\n", " print('train rmse %f' % train_ls[-1])\n", " \n", " # apply the network to the test features passed in (previously the global `test_features`\n", " # was used here, silently ignoring this argument)\n", " preds = net(test_feature).asnumpy()\n", " \n", " # reformat it for export to Kaggle; assigning the flat array positionally keeps the\n", " # predictions aligned with test_data's rows regardless of its index\n", " test_data['SalePrice'] = preds.reshape(-1)\n", " submission = pd.concat([test_data['Id'], 
test_data['SalePrice']], axis=1)\n", " submission.to_csv('submission2.csv', index=False)" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "train rmse 0.489604\n" ] } ], "source": [ "# Retrain on the full training set and write the Kaggle submission file.\n", "train_and_pred(\n", "    train_features, test_features, train_labels, test_data,\n", "    num_epochs, lr, weight_decay, batch_size,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }