{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "taruma_udemy_boltzmann.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "n7Oy73RcZei5",
"colab_type": "text"
},
"source": [
"# Boltzmann Machines\n",
"\n",
"Notebook ini berdasarkan kursus __Deep Learning A-Z™: Hands-On Artificial Neural Networks__ di Udemy. [Lihat Kursus](https://www.udemy.com/deeplearning/).\n",
"\n",
"## Informasi Notebook\n",
"- __notebook name__: `taruma_udemy_boltzmann`\n",
"- __notebook version/date__: `1.0.0`/`20190730`\n",
"- __notebook server__: Google Colab\n",
"- __python version__: `3.6`\n",
"- __pytorch version__: `1.1.0`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0y9V9qyUZeiH",
"colab_type": "code",
"outputId": "06aa828f-f253-4a73-f66d-c8f53b5780bb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
}
},
"source": [
"#### NOTEBOOK DESCRIPTION\n",
"\n",
"from datetime import datetime\n",
"\n",
"NOTEBOOK_TITLE = 'taruma_udemy_boltzmann'\n",
"NOTEBOOK_VERSION = '1.0.0'\n",
"NOTEBOOK_DATE = 1  # Set to 1 to append a UTC date/time stamp to the project name\n",
"\n",
"# Version string with the dots replaced, so it can be embedded in a name.\n",
"version_slug = NOTEBOOK_VERSION.replace('.','_')\n",
"\n",
"# Optional \"_YYYYmmdd_HHMM\" suffix; empty when NOTEBOOK_DATE is falsy.\n",
"if NOTEBOOK_DATE:\n",
"    date_suffix = \"_\" + datetime.utcnow().strftime(\"%Y%m%d_%H%M\")\n",
"else:\n",
"    date_suffix = \"\"\n",
"\n",
"NOTEBOOK_NAME = \"{}_{}\".format(NOTEBOOK_TITLE, version_slug)\n",
"PROJECT_NAME = \"{}_{}{}\".format(NOTEBOOK_TITLE, version_slug, date_suffix)\n",
"\n",
"print(f\"Nama Notebook: {NOTEBOOK_NAME}\")\n",
"print(f\"Nama Proyek: {PROJECT_NAME}\")"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Nama Notebook: taruma_udemy_boltzmann_1_0_0\n",
"Nama Proyek: taruma_udemy_boltzmann_1_0_0_20190730_0822\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "1euCSTADZlh3",
"colab_type": "code",
"outputId": "79926104-9cd1-43d6-8ca6-35caf0003b2b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"#### System Version\n",
"import sys\n",
"import torch\n",
"\n",
"# Report the interpreter and PyTorch versions this run is using.\n",
"python_version = sys.version\n",
"pytorch_version = torch.__version__\n",
"print(\"versi python: {}\".format(python_version))\n",
"print(\"versi pytorch: {}\".format(pytorch_version))"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"versi python: 3.6.8 (default, Jan 14 2019, 11:02:34) \n",
"[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]\n",
"versi pytorch: 1.1.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "s0qrkxTVZj_P",
"colab_type": "code",
"colab": {}
},
"source": [
"#### Load Notebook Extensions\n",
"# IPython extension from the google.colab package (this notebook targets Google Colab).\n",
"%load_ext google.colab.data_table"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "S8l7WZ0_ZmCK",
"colab_type": "code",
"colab": {}
},
"source": [
"#### Download dataset\n",
"# ref: https://grouplens.org/datasets/movielens/\n",
"# -O fixes the local archive name so the unzip step below always finds it.\n",
"!wget -O boltzmann.zip \"https://sds-platform-private.s3-us-east-2.amazonaws.com/uploads/P16-Boltzmann-Machines.zip\"\n",
"# -o overwrites already-extracted files without asking, so re-running this cell\n",
"# does not stop at unzip's interactive \"replace?\" prompt; -q suppresses the file listing.\n",
"!unzip -o -q boltzmann.zip"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jXTtTiLAZmBu",
"colab_type": "code",
"colab": {}
},
"source": [
"#### Set dataset path\n",
"# Directory prefix (trailing slash included) that later cells prepend to the\n",
"# 'ml-1m/...' and 'ml-100k/...' file names.\n",
"DATASET_DIRECTORY = 'Boltzmann_Machines/'"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3cDBJSwAknug",
"colab_type": "code",
"colab": {}
},
"source": [
"def showdata(dataframe):\n",
"    \"\"\"Print the shape of `dataframe` and return the same object.\n",
"\n",
"    Handing the input back unchanged lets the caller chain a display call\n",
"    such as `.head()` directly after the size report.\n",
"    \"\"\"\n",
"    size_report = 'Dataframe Size: {}'.format(dataframe.shape)\n",
"    print(size_report)\n",
"    return dataframe"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hqE6ozW8e0ra",
"colab_type": "text"
},
"source": [
"# STEP 1-5 DATA PREPROCESSING"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fLvxd5pQdTQq",
"colab_type": "code",
"colab": {}
},
"source": [
"# Importing the libraries\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.parallel\n",
"import torch.optim as optim\n",
"import torch.utils.data\n",
"from torch.autograd import Variable"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lFQEACh4fJLp",
"colab_type": "code",
"outputId": "7aeb4f68-2294-439f-bc39-fd0b73e70de9",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
}
},
"source": [
"# Movie table: '::'-separated, no header row, so columns stay numbered 0..2.\n",
"movies = pd.read_csv(\n",
"    DATASET_DIRECTORY + 'ml-1m/movies.dat',\n",
"    sep='::',\n",
"    header=None,\n",
"    engine='python',  # multi-character separators need the python engine\n",
"    encoding='latin-1',\n",
")\n",
"showdata(movies).head(10)"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Dataframe Size: (3883, 3)\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"Toy Story (1995)\",\n\"Animation|Children's|Comedy\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"Jumanji (1995)\",\n\"Adventure|Children's|Fantasy\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"Grumpier Old Men (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"Waiting to Exhale (1995)\",\n\"Comedy|Drama\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"Father of the Bride Part II (1995)\",\n\"Comedy\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"Heat (1995)\",\n\"Action|Crime|Thriller\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"Sabrina (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"Tom and Huck (1995)\",\n\"Adventure|Children's\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"Sudden Death (1995)\",\n\"Action\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"GoldenEye (1995)\",\n\"Action|Adventure|Thriller\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"string\", \"2\"]],\n rowsPerPage: 25,\n });\n ",
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" Toy Story (1995) | \n",
" Animation|Children's|Comedy | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" Jumanji (1995) | \n",
" Adventure|Children's|Fantasy | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" Grumpier Old Men (1995) | \n",
" Comedy|Romance | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" Waiting to Exhale (1995) | \n",
" Comedy|Drama | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" Father of the Bride Part II (1995) | \n",
" Comedy | \n",
"
\n",
" \n",
" 5 | \n",
" 6 | \n",
" Heat (1995) | \n",
" Action|Crime|Thriller | \n",
"
\n",
" \n",
" 6 | \n",
" 7 | \n",
" Sabrina (1995) | \n",
" Comedy|Romance | \n",
"
\n",
" \n",
" 7 | \n",
" 8 | \n",
" Tom and Huck (1995) | \n",
" Adventure|Children's | \n",
"
\n",
" \n",
" 8 | \n",
" 9 | \n",
" Sudden Death (1995) | \n",
" Action | \n",
"
\n",
" \n",
" 9 | \n",
" 10 | \n",
" GoldenEye (1995) | \n",
" Action|Adventure|Thriller | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2\n",
"0 1 Toy Story (1995) Animation|Children's|Comedy\n",
"1 2 Jumanji (1995) Adventure|Children's|Fantasy\n",
"2 3 Grumpier Old Men (1995) Comedy|Romance\n",
"3 4 Waiting to Exhale (1995) Comedy|Drama\n",
"4 5 Father of the Bride Part II (1995) Comedy\n",
"5 6 Heat (1995) Action|Crime|Thriller\n",
"6 7 Sabrina (1995) Comedy|Romance\n",
"7 8 Tom and Huck (1995) Adventure|Children's\n",
"8 9 Sudden Death (1995) Action\n",
"9 10 GoldenEye (1995) Action|Adventure|Thriller"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dgllaEzifthq",
"colab_type": "code",
"outputId": "6e9eee4f-859c-44bf-d906-c72e9289ee93",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
}
},
"source": [
"# User table: '::'-separated, no header row, so columns stay numbered 0..4.\n",
"users = pd.read_csv(\n",
"    DATASET_DIRECTORY + 'ml-1m/users.dat',\n",
"    sep='::',\n",
"    header=None,\n",
"    engine='python',  # multi-character separators need the python engine\n",
"    encoding='latin-1',\n",
")\n",
"showdata(users).head(10)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"Dataframe Size: (6040, 5)\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"F\",\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"48067\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"M\",\n{\n 'v': 56,\n 'f': \"56\",\n },\n{\n 'v': 16,\n 'f': \"16\",\n },\n\"70072\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 15,\n 'f': \"15\",\n },\n\"55117\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"M\",\n{\n 'v': 45,\n 'f': \"45\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"02460\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n\"55455\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"F\",\n{\n 'v': 50,\n 'f': \"50\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"55117\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"M\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"06810\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 12,\n 'f': \"12\",\n },\n\"11413\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 17,\n 'f': \"17\",\n },\n\"61614\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"F\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"95370\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"], [\"string\", \"4\"]],\n rowsPerPage: 25,\n });\n ",
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
" 4 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" F | \n",
" 1 | \n",
" 10 | \n",
" 48067 | \n",
"
\n",
" \n",
" 1 | \n",
" 2 | \n",
" M | \n",
" 56 | \n",
" 16 | \n",
" 70072 | \n",
"
\n",
" \n",
" 2 | \n",
" 3 | \n",
" M | \n",
" 25 | \n",
" 15 | \n",
" 55117 | \n",
"
\n",
" \n",
" 3 | \n",
" 4 | \n",
" M | \n",
" 45 | \n",
" 7 | \n",
" 02460 | \n",
"
\n",
" \n",
" 4 | \n",
" 5 | \n",
" M | \n",
" 25 | \n",
" 20 | \n",
" 55455 | \n",
"
\n",
" \n",
" 5 | \n",
" 6 | \n",
" F | \n",
" 50 | \n",
" 9 | \n",
" 55117 | \n",
"
\n",
" \n",
" 6 | \n",
" 7 | \n",
" M | \n",
" 35 | \n",
" 1 | \n",
" 06810 | \n",
"
\n",
" \n",
" 7 | \n",
" 8 | \n",
" M | \n",
" 25 | \n",
" 12 | \n",
" 11413 | \n",
"
\n",
" \n",
" 8 | \n",
" 9 | \n",
" M | \n",
" 25 | \n",
" 17 | \n",
" 61614 | \n",
"
\n",
" \n",
" 9 | \n",
" 10 | \n",
" F | \n",
" 35 | \n",
" 1 | \n",
" 95370 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2 3 4\n",
"0 1 F 1 10 48067\n",
"1 2 M 56 16 70072\n",
"2 3 M 25 15 55117\n",
"3 4 M 45 7 02460\n",
"4 5 M 25 20 55455\n",
"5 6 F 50 9 55117\n",
"6 7 M 35 1 06810\n",
"7 8 M 25 12 11413\n",
"8 9 M 25 17 61614\n",
"9 10 F 35 1 95370"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "anc5Oi1sgDzc",
"colab_type": "code",
"outputId": "29b8bcdd-4562-40d0-ed3e-c179f86b9ed7",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
}
},
"source": [
"# Rating table: '::'-separated, no header row, so columns stay numbered 0..3.\n",
"ratings = pd.read_csv(\n",
"    DATASET_DIRECTORY + 'ml-1m/ratings.dat',\n",
"    sep='::',\n",
"    header=None,\n",
"    engine='python',  # multi-character separators need the python engine\n",
"    encoding='latin-1',\n",
")\n",
"showdata(ratings).head(10)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"Dataframe Size: (1000209, 4)\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1193,\n 'f': \"1193\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300760,\n 'f': \"978300760\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 661,\n 'f': \"661\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302109,\n 'f': \"978302109\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 914,\n 'f': \"914\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978301968,\n 'f': \"978301968\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 3408,\n 'f': \"3408\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978300275,\n 'f': \"978300275\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2355,\n 'f': \"2355\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978824291,\n 'f': \"978824291\",\n }],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1197,\n 'f': \"1197\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1287,\n 'f': \"1287\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978302039,\n 'f': \"978302039\",\n }],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2804,\n 'f': \"2804\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300719,\n 'f': \"978300719\",\n }],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 594,\n 'f': \"594\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 919,\n 'f': \"919\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 
978301368,\n 'f': \"978301368\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"number\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"]],\n rowsPerPage: 25,\n });\n ",
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" 0 | \n",
" 1 | \n",
" 2 | \n",
" 3 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1 | \n",
" 1193 | \n",
" 5 | \n",
" 978300760 | \n",
"
\n",
" \n",
" 1 | \n",
" 1 | \n",
" 661 | \n",
" 3 | \n",
" 978302109 | \n",
"
\n",
" \n",
" 2 | \n",
" 1 | \n",
" 914 | \n",
" 3 | \n",
" 978301968 | \n",
"
\n",
" \n",
" 3 | \n",
" 1 | \n",
" 3408 | \n",
" 4 | \n",
" 978300275 | \n",
"
\n",
" \n",
" 4 | \n",
" 1 | \n",
" 2355 | \n",
" 5 | \n",
" 978824291 | \n",
"
\n",
" \n",
" 5 | \n",
" 1 | \n",
" 1197 | \n",
" 3 | \n",
" 978302268 | \n",
"
\n",
" \n",
" 6 | \n",
" 1 | \n",
" 1287 | \n",
" 5 | \n",
" 978302039 | \n",
"
\n",
" \n",
" 7 | \n",
" 1 | \n",
" 2804 | \n",
" 5 | \n",
" 978300719 | \n",
"
\n",
" \n",
" 8 | \n",
" 1 | \n",
" 594 | \n",
" 4 | \n",
" 978302268 | \n",
"
\n",
" \n",
" 9 | \n",
" 1 | \n",
" 919 | \n",
" 4 | \n",
" 978301368 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" 0 1 2 3\n",
"0 1 1193 5 978300760\n",
"1 1 661 3 978302109\n",
"2 1 914 3 978301968\n",
"3 1 3408 4 978300275\n",
"4 1 2355 5 978824291\n",
"5 1 1197 3 978302268\n",
"6 1 1287 5 978302039\n",
"7 1 2804 5 978300719\n",
"8 1 594 4 978302268\n",
"9 1 919 4 978301368"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xU2S9y8NgRPW",
"colab_type": "code",
"colab": {}
},
"source": [
"# Preparing the training set and the test set\n",
"# u1.base / u1.test are tab-separated and start directly with data rows, so\n",
"# header=None is required: without it pandas uses the first rating of each\n",
"# file as column names and that rating is silently lost.\n",
"training_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.base', delimiter='\\t', header=None)\n",
"training_set = np.array(training_set, dtype='int')\n",
"test_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.test', delimiter='\\t', header=None)\n",
"test_set = np.array(test_set, dtype='int')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G7WbrddJl3Q9",
"colab_type": "code",
"colab": {}
},
"source": [
"# Getting the number of users and movies\n",
"nb_users = int(max(max(training_set[:, 0]), max(test_set[:, 0])))\n",
"nb_movies = int(max(max(training_set[:, 1]), max(test_set[:, 1])))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yRUTR_K3_rzP",
"colab_type": "code",
"colab": {}
},
"source": [
"# Converting the data into an array with users in lines and movies in columns\n",
"def convert(data, n_users=None, n_movies=None):\n",
"    \"\"\"Turn a rating log into one dense rating row per user.\n",
"\n",
"    Args:\n",
"        data: 2-D integer array whose column 0 is a 1-based user ID,\n",
"            column 1 a 1-based movie ID and column 2 the rating.\n",
"        n_users: number of rows to build; defaults to the notebook-level\n",
"            `nb_users` when omitted.\n",
"        n_movies: length of every row; defaults to the notebook-level\n",
"            `nb_movies` when omitted.\n",
"\n",
"    Returns:\n",
"        A list with `n_users` lists of `n_movies` floats. Entry [u][m] holds\n",
"        the rating user u+1 gave movie m+1, or 0.0 when there is none.\n",
"    \"\"\"\n",
"    if n_users is None:\n",
"        n_users = nb_users\n",
"    if n_movies is None:\n",
"        n_movies = nb_movies\n",
"    new_data = []\n",
"    for id_user in range(1, n_users + 1):\n",
"        # Select this user's rows once and reuse the mask for both columns.\n",
"        user_rows = data[:, 0] == id_user\n",
"        id_movies = data[user_rows, 1]\n",
"        id_ratings = data[user_rows, 2]\n",
"        user_ratings = np.zeros(n_movies)\n",
"        # Movie IDs are 1-based, array positions 0-based.\n",
"        user_ratings[id_movies - 1] = id_ratings\n",
"        new_data.append(list(user_ratings))\n",
"    return new_data\n",
"\n",
"# Rebinds both names from the raw rating arrays to per-user rating lists.\n",
"# NOTE: not re-runnable on its own, because convert() indexes its argument as a 2-D array.\n",
"training_set = convert(training_set)\n",
"test_set = convert(test_set)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "u0Fk8Q0YCNZr",
"colab_type": "code",
"colab": {}
},
"source": [
"# Converting the data into Torch tensors\n",
"# Both splits are turned into float tensors in one step (training first, then test).\n",
"training_set, test_set = map(torch.FloatTensor, (training_set, test_set))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "q2arIJufDBYd",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "6f90b00a-799a-4db0-f1ac-82241e375329"
},
"source": [
"# Show the rating tensor (0 marks a user/movie pair with no rating).\n",
"training_set"
],
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 3., 4., ..., 0., 0., 0.],\n",
" [4., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [5., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 5., 0., ..., 0., 0., 0.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 25
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wwWlLU5ODGgF",
"colab_type": "text"
},
"source": [
"# STEP 6"
]
},
{
"cell_type": "code",
"metadata": {
"id": "m0KbLrFZDCjK",
"colab_type": "code",
"colab": {}
},
"source": [
"# Converting the ratings into binary ratings 1 (Liked) or 0 (Not Liked)\n",
"# Missing ratings (stored as 0) are moved to -1 first so they are not mixed up\n",
"# with \"Not Liked\". The remapping is in-place and order-sensitive: running this\n",
"# cell a second time on already converted tensors would remap them again.\n",
"for rating_tensor in (training_set, test_set):\n",
"    rating_tensor[rating_tensor == 0] = -1\n",
"    rating_tensor[rating_tensor == 1] = 0\n",
"    rating_tensor[rating_tensor == 2] = 0\n",
"    rating_tensor[rating_tensor >= 3] = 1"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zHEwVvTlD-DK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "08d4e08c-6c99-4fa5-e528-ca995f016ed2"
},
"source": [
"# Show the tensor after the binary conversion (-1 = no rating, 0 = not liked, 1 = liked).\n",
"training_set"
],
"execution_count": 27,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[-1., 1., 1., ..., -1., -1., -1.],\n",
" [ 1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" ...,\n",
" [ 1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.],\n",
" [-1., 1., -1., ..., -1., -1., -1.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 27
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a4YrUGpDEFEV",
"colab_type": "text"
},
"source": [
"# STEP 7 - 10 Building RBM Object"
]
},
{
"cell_type": "code",
"metadata": {
"id": "S3i8jV-RD_MV",
"colab_type": "code",
"colab": {}
},
"source": [
"# Creating the architecture of the Neural Network\n",
"# nv = number visible nodes, nh = number hidden nodes\n",
"class RBM:\n",
"    \"\"\"Restricted Boltzmann Machine whose parameters are plain torch tensors.\n",
"\n",
"    Attributes:\n",
"        W: weights, shape (nh, nv).\n",
"        a: hidden-unit bias, shape (1, nh).\n",
"        b: visible-unit bias, shape (1, nv).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nv, nh):\n",
"        \"\"\"Draw W, a and b (in that order) from a standard normal distribution.\"\"\"\n",
"        self.W = torch.randn(nh, nv)\n",
"        self.a = torch.randn(1, nh)\n",
"        self.b = torch.randn(1, nv)\n",
"\n",
"    def sample_h(self, x):\n",
"        \"\"\"Return (hidden activation probabilities, Bernoulli sample) for visible rows `x`.\n",
"\n",
"        `x` must be 2-D with nv columns, since it is multiplied by W transposed.\n",
"        \"\"\"\n",
"        # The (1, nh) bias broadcasts over the batch dimension.\n",
"        hidden_input = x.mm(self.W.t()) + self.a\n",
"        p_h_given_v = torch.sigmoid(hidden_input)\n",
"        h_sample = torch.bernoulli(p_h_given_v)\n",
"        return p_h_given_v, h_sample\n",
"\n",
"    def sample_v(self, y):\n",
"        \"\"\"Return (visible activation probabilities, Bernoulli sample) for hidden rows `y`.\n",
"\n",
"        `y` must be 2-D with nh columns, since it is multiplied by W.\n",
"        \"\"\"\n",
"        # The (1, nv) bias broadcasts over the batch dimension.\n",
"        visible_input = y.mm(self.W) + self.b\n",
"        p_v_given_h = torch.sigmoid(visible_input)\n",
"        v_sample = torch.bernoulli(p_v_given_h)\n",
"        return p_v_given_h, v_sample\n",
"\n",
"    def train(self, v0, vk, ph0, phk):\n",
"        \"\"\"Update W, b and a in place.\n",
"\n",
"        Each parameter moves by the difference between the statistics of the\n",
"        first pair (v0, ph0) and the second pair (vk, phk), summed over the\n",
"        batch. No learning rate is applied.\n",
"        \"\"\"\n",
"        first_stats = torch.mm(v0.t(), ph0)\n",
"        second_stats = torch.mm(vk.t(), phk)\n",
"        # The products are (nv, nh); W is stored as (nh, nv), hence the transpose.\n",
"        self.W += (first_stats - second_stats).t()\n",
"        visible_delta = (v0 - vk).sum(0)\n",
"        hidden_delta = (ph0 - phk).sum(0)\n",
"        self.b += visible_delta\n",
"        self.a += hidden_delta"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "GVmKKbizKFV2",
"colab_type": "text"
},
"source": [
"# STEP 11"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Itwi6_KlKGmf",
"colab_type": "code",
"colab": {}
},
"source": [
"# One visible unit per column of the training tensor.\n",
"nv = training_set.shape[1]\n",
"# Hidden-unit count and mini-batch size used by the training loop.\n",
"nh, batch_size = 100, 100\n",
"rbm = RBM(nv, nh)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "45nKkm5QK3hx",
"colab_type": "text"
},
"source": [
"# STEP 12-13"
]
},
{
"cell_type": "code",
"metadata": {
"id": "NJ94UFahKrOw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
},
"outputId": "841eb9e8-91ba-4f7b-cae3-b6370b9181fc"
},
"source": [
"# Training the RBM\n",
"nb_epochs = 10\n",
"nb_gibbs_steps = 10  # sample_h / sample_v round trips per batch\n",
"for epoch in range(1, nb_epochs + 1):\n",
"    train_loss = 0\n",
"    s = 0.  # number of batches seen, used to average the loss\n",
"    # Step over every user. Slicing clamps at the end of the tensor, so the\n",
"    # final (possibly shorter) batch is trained on instead of being skipped.\n",
"    for id_user in range(0, nb_users, batch_size):\n",
"        vk = training_set[id_user:id_user+batch_size]\n",
"        v0 = training_set[id_user:id_user+batch_size]\n",
"        ph0,_ = rbm.sample_h(v0)\n",
"        for k in range(nb_gibbs_steps):\n",
"            _,hk = rbm.sample_h(vk)\n",
"            _,vk = rbm.sample_v(hk)\n",
"            # Entries without a rating (-1) are copied back so they are never sampled over.\n",
"            vk[v0<0] = v0[v0<0]\n",
"        phk,_ = rbm.sample_h(vk)\n",
"        rbm.train(v0, vk, ph0, phk)\n",
"        # Mean absolute difference, measured on rated entries only.\n",
"        train_loss += torch.mean(torch.abs(v0[v0>=0] - vk[v0>=0]))\n",
"        s += 1.\n",
"    print('epoch: '+str(epoch)+' loss: '+str(train_loss/s))"
],
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"text": [
"epoch: 1 loss: tensor(0.3424)\n",
"epoch: 2 loss: tensor(0.2527)\n",
"epoch: 3 loss: tensor(0.2509)\n",
"epoch: 4 loss: tensor(0.2483)\n",
"epoch: 5 loss: tensor(0.2474)\n",
"epoch: 6 loss: tensor(0.2478)\n",
"epoch: 7 loss: tensor(0.2467)\n",
"epoch: 8 loss: tensor(0.2461)\n",
"epoch: 9 loss: tensor(0.2482)\n",
"epoch: 10 loss: tensor(0.2491)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qTvbF-u9NBdl",
"colab_type": "text"
},
"source": [
"# STEP 14"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RSlbxB8ZLoy9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ee26a78e-47ac-4efc-d36d-9048dbae6cd8"
},
"source": [
"# Testing the RBM\n",
"test_loss = 0\n",
"s = 0.  # number of users that contributed to the loss\n",
"for id_user in range(nb_users):\n",
"    # The training row is the input; the test row holds the values to compare against.\n",
"    v_input = training_set[id_user:id_user+1]\n",
"    v_target = test_set[id_user:id_user+1]\n",
"    rated = v_target >= 0\n",
"    # Users without any rated test entry are skipped.\n",
"    if len(v_target[rated]) == 0:\n",
"        continue\n",
"    _, h_sample = rbm.sample_h(v_input)\n",
"    _, v_sample = rbm.sample_v(h_sample)\n",
"    test_loss += torch.mean(torch.abs(v_target[rated] - v_sample[rated]))\n",
"    s += 1.\n",
"print('test loss: '+str(test_loss/s))"
],
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"text": [
"test loss: tensor(0.2403)\n"
],
"name": "stdout"
}
]
}
]
}