{ "cells": [ { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import mxnet as mx\n", "from mxnet import gluon, autograd, ndarray\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import mean_squared_error\n", "from sklearn.utils import Bunch" ] }, { "cell_type": "code", "execution_count": 102, "metadata": {}, "outputs": [], "source": [ "# sklearn.datasets.load_boston was removed in scikit-learn 1.2, so load the same data\n", "# from its original StatLib source (the replacement scikit-learn itself recommends).\n", "BOSTON_URL = \"http://lib.stat.cmu.edu/datasets/boston\"\n", "BOSTON_FEATURES = [\"CRIM\", \"ZN\", \"INDUS\", \"CHAS\", \"NOX\", \"RM\", \"AGE\",\n", "                   \"DIS\", \"RAD\", \"TAX\", \"PTRATIO\", \"B\", \"LSTAT\"]\n", "boston_raw = pd.read_csv(BOSTON_URL, sep=r\"\\s+\", skiprows=22, header=None)\n", "# Each record spans two physical lines: 11 feature values, then 2 more features and the target.\n", "boston_data = np.hstack([boston_raw.values[::2, :], boston_raw.values[1::2, :2]])\n", "y = boston_raw.values[1::2, 2]\n", "# Keep the Bunch interface (data.data / data.target / data.feature_names) of the old loader.\n", "data = Bunch(data=boston_data, target=y, feature_names=np.array(BOSTON_FEATURES))\n", "df = pd.DataFrame(data.data, columns=data.feature_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Normalize features for the NN\n", "\n", "Mean-normalize each feature column: subtract its mean and divide by its range (max - min)." ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): the mean/max/min are computed over all rows, so if a train/test split\n", "# is made later, test-set statistics leak into training - consider fitting them on train only.\n", "df_norm = (df - df.mean()) / (df.max() - df.min())" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | CRIM | \n", "ZN | \n", "INDUS | \n", "CHAS | \n", "NOX | \n", "RM | \n", "AGE | \n", "DIS | \n", "RAD | \n", "TAX | \n", "PTRATIO | \n", "B | \n", "LSTAT | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "-0.040322 | \n", "0.066364 | \n", "-0.323562 | \n", "-0.06917 | \n", "-0.034352 | \n", "0.055636 | \n", "-0.034757 | \n", "0.026822 | \n", "-0.371713 | \n", "-0.214193 | \n", "-0.335695 | \n", "0.101432 | \n", "-0.211729 | \n", "
1 | \n", "-0.040086 | \n", "-0.113636 | \n", "-0.149075 | \n", "-0.06917 | \n", "-0.176327 | \n", "0.026129 | \n", "0.106335 | \n", "0.106581 | \n", "-0.328235 | \n", "-0.317246 | \n", "-0.069738 | \n", "0.101432 | \n", "-0.096939 | \n", "
2 | \n", "-0.040086 | \n", "-0.113636 | \n", "-0.149075 | \n", "-0.06917 | \n", "-0.176327 | \n", "0.172517 | \n", "-0.076981 | \n", "0.106581 | \n", "-0.328235 | \n", "-0.317246 | \n", "-0.069738 | \n", "0.091169 | \n", "-0.237943 | \n", "
3 | \n", "-0.040029 | \n", "-0.113636 | \n", "-0.328328 | \n", "-0.06917 | \n", "-0.198961 | \n", "0.136686 | \n", "-0.234551 | \n", "0.206163 | \n", "-0.284757 | \n", "-0.355414 | \n", "0.026007 | \n", "0.095708 | \n", "-0.268021 | \n", "
4 | \n", "-0.039617 | \n", "-0.113636 | \n", "-0.328328 | \n", "-0.06917 | \n", "-0.198961 | \n", "0.165236 | \n", "-0.148042 | \n", "0.206163 | \n", "-0.284757 | \n", "-0.355414 | \n", "0.026007 | \n", "0.101432 | \n", "-0.202071 | \n", "