{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "import matplotlib.pyplot as plt\n", "import torch\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from torch.utils.data import Dataset, DataLoader" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class DiabetesDataset(Dataset):\n", "    \"\"\"Map-style dataset over the leading rows of a headerless CSV file.\n", "\n", "    Every column except the last becomes the feature tensor ``x_data``;\n", "    the last column becomes the target tensor ``y_data``. No dtype\n", "    conversion is applied to the values read by pandas.\n", "    \"\"\"\n", "\n", "    def __init__(self, csv_path=\"diabetes.csv\", n_rows=500):\n", "        \"\"\"Read ``csv_path`` and keep at most ``n_rows`` leading rows.\n", "\n", "        Args:\n", "            csv_path: Path of the headerless CSV file to read.\n", "            n_rows: Maximum number of leading rows to keep; ``None``\n", "                keeps every row.\n", "        \"\"\"\n", "        data = pd.read_csv(csv_path, header=None)\n", "        self.x_data = torch.from_numpy(data.iloc[:n_rows, :-1].values)\n", "        self.y_data = torch.from_numpy(data.iloc[:n_rows, -1].values)\n", "        # Use the number of rows actually loaded: with a file shorter than\n", "        # n_rows a hard-coded length would make __len__ report rows that\n", "        # cannot be indexed.\n", "        self.len = len(self.x_data)\n", "\n", "    def __getitem__(self, index):\n", "        \"\"\"Return the ``(features, target)`` pair stored at ``index``.\"\"\"\n", "        return self.x_data[index], self.y_data[index]\n", "\n", "    def __len__(self):\n", "        \"\"\"Return the number of rows held by the dataset.\"\"\"\n", "        return self.len" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Shuffled mini-batches of up to 128 samples, loaded by 2 worker processes.\n", "dataset = DiabetesDataset()\n", "train_loader = DataLoader(\n", "    dataset=dataset,\n", "    batch_size=128,\n", "    shuffle=True,\n", "    num_workers=2,\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "
---|---|---|---|---|---|---|---|---|---|
0 | \n", "-0.294118 | \n", "0.487437 | \n", "0.180328 | \n", "-0.292929 | \n", "0.000000 | \n", "0.001490 | \n", "-0.531170 | \n", "-0.033333 | \n", "0 | \n", "
1 | \n", "-0.882353 | \n", "-0.145729 | \n", "0.081967 | \n", "-0.414141 | \n", "0.000000 | \n", "-0.207153 | \n", "-0.766866 | \n", "-0.666667 | \n", "1 | \n", "
2 | \n", "-0.058824 | \n", "0.839196 | \n", "0.049180 | \n", "0.000000 | \n", "0.000000 | \n", "-0.305514 | \n", "-0.492741 | \n", "-0.633333 | \n", "0 | \n", "
3 | \n", "-0.882353 | \n", "-0.105528 | \n", "0.081967 | \n", "-0.535354 | \n", "-0.777778 | \n", "-0.162444 | \n", "-0.923997 | \n", "0.000000 | \n", "1 | \n", "
4 | \n", "0.000000 | \n", "0.376884 | \n", "-0.344262 | \n", "-0.292929 | \n", "-0.602837 | \n", "0.284650 | \n", "0.887276 | \n", "-0.600000 | \n", "0 | \n", "