{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 305 Batch Train\n", "\n", "View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n", "My Youtube Channel: https://www.youtube.com/user/MorvanZhou\n", "\n", "Dependencies:\n", "* torch: 0.1.11" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.utils.data as Data\n", "\n", "torch.manual_seed(1) # reproducible" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "BATCH_SIZE = 5\n", "# BATCH_SIZE = 8" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = torch.linspace(1, 10, 10) # this is x data (torch tensor)\n", "y = torch.linspace(10, 1, 10) # this is y data (torch tensor)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# pass the tensors positionally: newer torch releases take TensorDataset(*tensors)\n", "# and reject the old data_tensor=/target_tensor= keywords\n", "torch_dataset = Data.TensorDataset(x, y)\n", "loader = Data.DataLoader(\n", " dataset=torch_dataset, # torch TensorDataset format\n", " batch_size=BATCH_SIZE, # mini batch size\n", " shuffle=True, # random shuffle for training\n", " num_workers=2, # subprocesses for loading data\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0 | Step: 0 | batch x: [ 6. 7. 2. 3. 1.] | batch y: [ 5. 4. 9. 8. 10.]\n", "Epoch: 0 | Step: 1 | batch x: [ 9. 10. 4. 8. 5.] | batch y: [ 2. 1. 7. 3. 6.]\n", "Epoch: 1 | Step: 0 | batch x: [ 3. 4. 2. 9. 10.] | batch y: [ 8. 7. 9. 2. 1.]\n", "Epoch: 1 | Step: 1 | batch x: [ 1. 7. 8. 5. 6.] | batch y: [ 10. 4. 3. 6. 5.]\n", "Epoch: 2 | Step: 0 | batch x: [ 3. 9. 2. 6. 7.] | batch y: [ 8. 2. 9. 5. 4.]\n", "Epoch: 2 | Step: 1 | batch x: [ 10. 4. 8. 
1. 5.] | batch y: [ 1. 7. 3. 10. 6.]\n" ] } ], "source": [ "for epoch in range(3): # train entire dataset 3 times\n", " for step, (batch_x, batch_y) in enumerate(loader): # for each training step\n", " # train your data...\n", " print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',\n", " batch_x.numpy(), '| batch y: ', batch_y.numpy())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Suppose a batch size that does not evenly divide the number of data entries (the last batch of each epoch is smaller):" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0 | Step: 0 | batch x: [ 3. 10. 9. 4. 7. 8. 2. 1.] | batch y: [ 8. 1. 2. 7. 4. 3. 9. 10.]\n", "Epoch: 0 | Step: 1 | batch x: [ 5. 6.] | batch y: [ 6. 5.]\n", "Epoch: 1 | Step: 0 | batch x: [ 4. 8. 3. 2. 1. 10. 5. 6.] | batch y: [ 7. 3. 8. 9. 10. 1. 6. 5.]\n", "Epoch: 1 | Step: 1 | batch x: [ 7. 9.] | batch y: [ 4. 2.]\n", "Epoch: 2 | Step: 0 | batch x: [ 6. 2. 4. 10. 9. 3. 8. 5.] | batch y: [ 5. 9. 7. 1. 2. 8. 3. 6.]\n", "Epoch: 2 | Step: 1 | batch x: [ 7. 1.] | batch y: [ 4. 
10.]\n" ] } ], "source": [ "# use a separate name instead of reassigning the BATCH_SIZE config constant,\n", "# so re-running earlier cells afterwards still uses batch size 5\n", "UNEVEN_BATCH_SIZE = 8 # does not evenly divide the 10 samples\n", "loader = Data.DataLoader(\n", " dataset=torch_dataset, # torch TensorDataset format\n", " batch_size=UNEVEN_BATCH_SIZE, # mini batch size\n", " shuffle=True, # random shuffle for training\n", " num_workers=2, # subprocesses for loading data\n", ")\n", "for epoch in range(3): # train entire dataset 3 times\n", " for step, (batch_x, batch_y) in enumerate(loader): # for each training step\n", " # train your data...\n", " print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',\n", " batch_x.numpy(), '| batch y: ', batch_y.numpy())" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }