{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "taruma_udemy_boltzmann.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "n7Oy73RcZei5", "colab_type": "text" }, "source": [ "# Boltzmann Machines\n", "\n", "Notebook ini berdasarkan kursus __Deep Learning A-Z™: Hands-On Artificial Neural Networks__ di Udemy. [Lihat Kursus](https://www.udemy.com/deeplearning/).\n", "\n", "## Informasi Notebook\n", "- __notebook name__: `taruma_udemy_boltzmann`\n", "- __notebook version/date__: `1.0.0`/`20190730`\n", "- __notebook server__: Google Colab\n", "- __python version__: `3.6`\n", "- __pytorch version__: `1.1.0`" ] },
{ "cell_type": "code", "metadata": { "id": "0y9V9qyUZeiH", "colab_type": "code", "outputId": "06aa828f-f253-4a73-f66d-c8f53b5780bb", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "source": [ "#### NOTEBOOK DESCRIPTION\n", "\n", "from datetime import datetime, timezone\n", "\n", "NOTEBOOK_TITLE = 'taruma_udemy_boltzmann'\n", "NOTEBOOK_VERSION = '1.0.0'\n", "NOTEBOOK_DATE = 1  # 1: append a UTC date/time suffix to PROJECT_NAME; 0: omit it\n", "\n", "NOTEBOOK_NAME = \"{}_{}\".format(\n", "    NOTEBOOK_TITLE,\n", "    NOTEBOOK_VERSION.replace('.', '_')\n", ")\n", "# datetime.now(timezone.utc) yields the same UTC wall-clock fields as the\n", "# deprecated datetime.utcnow() (DeprecationWarning since Python 3.12).\n", "PROJECT_NAME = \"{}_{}{}\".format(\n", "    NOTEBOOK_TITLE,\n", "    NOTEBOOK_VERSION.replace('.', '_'),\n", "    \"_\" + datetime.now(timezone.utc).strftime(\"%Y%m%d_%H%M\") if NOTEBOOK_DATE else \"\"\n", ")\n", "\n", "print(f\"Nama Notebook: {NOTEBOOK_NAME}\")\n", "print(f\"Nama Proyek: {PROJECT_NAME}\")" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Nama Notebook: taruma_udemy_boltzmann_1_0_0\n", "Nama Proyek: taruma_udemy_boltzmann_1_0_0_20190730_0822\n" ], "name": "stdout" } ] },
{ "cell_type": "code", "metadata": { "id": "1euCSTADZlh3", "colab_type": "code", "outputId": "79926104-9cd1-43d6-8ca6-35caf0003b2b", 
"colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "#### System Version\n", "import sys, torch\n", "print(\"versi python: {}\".format(sys.version))\n", "print(\"versi pytorch: {}\".format(torch.__version__))" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "versi python: 3.6.8 (default, Jan 14 2019, 11:02:34) \n", "[GCC 8.0.1 20180414 (experimental) [trunk revision 259383]]\n", "versi pytorch: 1.1.0\n" ], "name": "stdout" } ] },
{ "cell_type": "code", "metadata": { "id": "s0qrkxTVZj_P", "colab_type": "code", "colab": {} }, "source": [ "#### Load Notebook Extensions\n", "%load_ext google.colab.data_table" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "S8l7WZ0_ZmCK", "colab_type": "code", "colab": {} }, "source": [ "#### Download dataset\n", "# ref: https://grouplens.org/datasets/movielens/\n", "!wget -O boltzmann.zip \"https://sds-platform-private.s3-us-east-2.amazonaws.com/uploads/P16-Boltzmann-Machines.zip\"\n", "# -o: overwrite existing files without prompting, so re-running the cell does not hang on a prompt\n", "# -q: suppress the per-file extraction listing\n", "!unzip -o -q boltzmann.zip" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "jXTtTiLAZmBu", "colab_type": "code", "colab": {} }, "source": [ "#### Set dataset path (directory created by the unzip above)\n", "DATASET_DIRECTORY = 'Boltzmann_Machines/'" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "3cDBJSwAknug", "colab_type": "code", "colab": {} }, "source": [ "def showdata(dataframe):\n", "    \"\"\"Print the shape of `dataframe` and return it unchanged, so `.head()` can be chained.\"\"\"\n", "    print('Dataframe Size: {}'.format(dataframe.shape))\n", "    return dataframe" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "hqE6ozW8e0ra", "colab_type": "text" }, "source": [ "# STEP 1-5 DATA PREPROCESSING" ] },
{ "cell_type": "code", "metadata": { "id": "fLvxd5pQdTQq", "colab_type": "code", "colab": {} }, "source": [ "# Importing the libraries\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.parallel\n", "import torch.optim as optim\n", 
"import torch.utils.data\n", "from torch.autograd import Variable" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lFQEACh4fJLp", "colab_type": "code", "outputId": "7aeb4f68-2294-439f-bc39-fd0b73e70de9", "colab": { "base_uri": "https://localhost:8080/", "height": 309 } }, "source": [ "movies = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/movies.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", "showdata(movies).head(10)" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "Dataframe Size: (3883, 3)\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"Toy Story (1995)\",\n\"Animation|Children's|Comedy\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"Jumanji (1995)\",\n\"Adventure|Children's|Fantasy\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"Grumpier Old Men (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"Waiting to Exhale (1995)\",\n\"Comedy|Drama\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"Father of the Bride Part II (1995)\",\n\"Comedy\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"Heat (1995)\",\n\"Action|Crime|Thriller\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"Sabrina (1995)\",\n\"Comedy|Romance\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"Tom and Huck (1995)\",\n\"Adventure|Children's\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"Sudden Death (1995)\",\n\"Action\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"GoldenEye 
(1995)\",\n\"Action|Adventure|Thriller\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"string\", \"2\"]],\n rowsPerPage: 25,\n });\n ", "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
012
01Toy Story (1995)Animation|Children's|Comedy
12Jumanji (1995)Adventure|Children's|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama
45Father of the Bride Part II (1995)Comedy
56Heat (1995)Action|Crime|Thriller
67Sabrina (1995)Comedy|Romance
78Tom and Huck (1995)Adventure|Children's
89Sudden Death (1995)Action
910GoldenEye (1995)Action|Adventure|Thriller
\n", "
" ], "text/plain": [ " 0 1 2\n", "0 1 Toy Story (1995) Animation|Children's|Comedy\n", "1 2 Jumanji (1995) Adventure|Children's|Fantasy\n", "2 3 Grumpier Old Men (1995) Comedy|Romance\n", "3 4 Waiting to Exhale (1995) Comedy|Drama\n", "4 5 Father of the Bride Part II (1995) Comedy\n", "5 6 Heat (1995) Action|Crime|Thriller\n", "6 7 Sabrina (1995) Comedy|Romance\n", "7 8 Tom and Huck (1995) Adventure|Children's\n", "8 9 Sudden Death (1995) Action\n", "9 10 GoldenEye (1995) Action|Adventure|Thriller" ] }, "metadata": { "tags": [] }, "execution_count": 8 } ] }, { "cell_type": "code", "metadata": { "id": "dgllaEzifthq", "colab_type": "code", "outputId": "6e9eee4f-859c-44bf-d906-c72e9289ee93", "colab": { "base_uri": "https://localhost:8080/", "height": 309 } }, "source": [ "users = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/users.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", "showdata(users).head(10)" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "text": [ "Dataframe Size: (6040, 5)\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"F\",\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"48067\"],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2,\n 'f': \"2\",\n },\n\"M\",\n{\n 'v': 56,\n 'f': \"56\",\n },\n{\n 'v': 16,\n 'f': \"16\",\n },\n\"70072\"],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 15,\n 'f': \"15\",\n },\n\"55117\"],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n\"M\",\n{\n 'v': 45,\n 'f': \"45\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"02460\"],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n\"M\",\n{\n 'v': 25,\n 
'f': \"25\",\n },\n{\n 'v': 20,\n 'f': \"20\",\n },\n\"55455\"],\n [{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 6,\n 'f': \"6\",\n },\n\"F\",\n{\n 'v': 50,\n 'f': \"50\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"55117\"],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 7,\n 'f': \"7\",\n },\n\"M\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"06810\"],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 8,\n 'f': \"8\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 12,\n 'f': \"12\",\n },\n\"11413\"],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 9,\n 'f': \"9\",\n },\n\"M\",\n{\n 'v': 25,\n 'f': \"25\",\n },\n{\n 'v': 17,\n 'f': \"17\",\n },\n\"61614\"],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 10,\n 'f': \"10\",\n },\n\"F\",\n{\n 'v': 35,\n 'f': \"35\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n\"95370\"]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"string\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"], [\"string\", \"4\"]],\n rowsPerPage: 25,\n });\n ", "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
01234
01F11048067
12M561670072
23M251555117
34M45702460
45M252055455
56F50955117
67M35106810
78M251211413
89M251761614
910F35195370
\n", "
" ], "text/plain": [ " 0 1 2 3 4\n", "0 1 F 1 10 48067\n", "1 2 M 56 16 70072\n", "2 3 M 25 15 55117\n", "3 4 M 45 7 02460\n", "4 5 M 25 20 55455\n", "5 6 F 50 9 55117\n", "6 7 M 35 1 06810\n", "7 8 M 25 12 11413\n", "8 9 M 25 17 61614\n", "9 10 F 35 1 95370" ] }, "metadata": { "tags": [] }, "execution_count": 9 } ] }, { "cell_type": "code", "metadata": { "id": "anc5Oi1sgDzc", "colab_type": "code", "outputId": "29b8bcdd-4562-40d0-ed3e-c179f86b9ed7", "colab": { "base_uri": "https://localhost:8080/", "height": 309 } }, "source": [ "ratings = pd.read_csv(DATASET_DIRECTORY + 'ml-1m/ratings.dat', sep='::', header=None, engine='python', encoding='latin-1')\n", "showdata(ratings).head(10)" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "Dataframe Size: (1000209, 4)\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "application/vnd.google.colaboratory.module+javascript": "\n import \"https://ssl.gstatic.com/colaboratory/data_table/81868506e94e6988/data_table.js\";\n\n window.createDataTable({\n data: [[{\n 'v': 0,\n 'f': \"0\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1193,\n 'f': \"1193\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300760,\n 'f': \"978300760\",\n }],\n [{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 661,\n 'f': \"661\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302109,\n 'f': \"978302109\",\n }],\n [{\n 'v': 2,\n 'f': \"2\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 914,\n 'f': \"914\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978301968,\n 'f': \"978301968\",\n }],\n [{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 3408,\n 'f': \"3408\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978300275,\n 'f': \"978300275\",\n }],\n [{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2355,\n 'f': \"2355\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978824291,\n 'f': \"978824291\",\n }],\n [{\n 'v': 5,\n 'f': 
\"5\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1197,\n 'f': \"1197\",\n },\n{\n 'v': 3,\n 'f': \"3\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 6,\n 'f': \"6\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 1287,\n 'f': \"1287\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978302039,\n 'f': \"978302039\",\n }],\n [{\n 'v': 7,\n 'f': \"7\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 2804,\n 'f': \"2804\",\n },\n{\n 'v': 5,\n 'f': \"5\",\n },\n{\n 'v': 978300719,\n 'f': \"978300719\",\n }],\n [{\n 'v': 8,\n 'f': \"8\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 594,\n 'f': \"594\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978302268,\n 'f': \"978302268\",\n }],\n [{\n 'v': 9,\n 'f': \"9\",\n },\n{\n 'v': 1,\n 'f': \"1\",\n },\n{\n 'v': 919,\n 'f': \"919\",\n },\n{\n 'v': 4,\n 'f': \"4\",\n },\n{\n 'v': 978301368,\n 'f': \"978301368\",\n }]],\n columns: [[\"number\", \"index\"], [\"number\", \"0\"], [\"number\", \"1\"], [\"number\", \"2\"], [\"number\", \"3\"]],\n rowsPerPage: 25,\n });\n ", "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123
0111935978300760
116613978302109
219143978301968
3134084978300275
4123555978824291
5111973978302268
6112875978302039
7128045978300719
815944978302268
919194978301368
\n", "
" ], "text/plain": [ " 0 1 2 3\n", "0 1 1193 5 978300760\n", "1 1 661 3 978302109\n", "2 1 914 3 978301968\n", "3 1 3408 4 978300275\n", "4 1 2355 5 978824291\n", "5 1 1197 3 978302268\n", "6 1 1287 5 978302039\n", "7 1 2804 5 978300719\n", "8 1 594 4 978302268\n", "9 1 919 4 978301368" ] }, "metadata": { "tags": [] }, "execution_count": 10 } ] },
{ "cell_type": "code", "metadata": { "id": "xU2S9y8NgRPW", "colab_type": "code", "colab": {} }, "source": [ "# Preparing the training set and the test set\n", "# u1.base / u1.test are headerless tab-separated rows; header=None stops pandas\n", "# from consuming the first rating of each file as column names.\n", "training_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.base', delimiter='\\t', header=None)\n", "training_set = np.array(training_set, dtype='int')\n", "test_set = pd.read_csv(DATASET_DIRECTORY + 'ml-100k/u1.test', delimiter='\\t', header=None)\n", "test_set = np.array(test_set, dtype='int')" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "G7WbrddJl3Q9", "colab_type": "code", "colab": {} }, "source": [ "# Getting the number of users and movies (the largest id in either split is used as the count)\n", "nb_users = int(max(training_set[:, 0].max(), test_set[:, 0].max()))\n", "nb_movies = int(max(training_set[:, 1].max(), test_set[:, 1].max()))" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "yRUTR_K3_rzP", "colab_type": "code", "colab": {} }, "source": [ "# Converting the data into an array with users in lines and movies in columns\n", "def convert(data, n_users=None, n_movies=None):\n", "    \"\"\"Pivot (user id, movie id, rating, ...) rows into a dense users x movies table.\n", "\n", "    Args:\n", "        data: 2-D int array; columns 0, 1, 2 hold 1-based user ids, 1-based movie ids, ratings.\n", "        n_users: number of rows to build; defaults to the global nb_users.\n", "        n_movies: number of columns to build; defaults to the global nb_movies.\n", "\n", "    Returns:\n", "        A list of n_users lists of n_movies floats: row u-1 holds user u's ratings,\n", "        0.0 where the user did not rate the movie. Rows whose user id falls outside\n", "        1..n_users are ignored.\n", "    \"\"\"\n", "    if n_users is None:\n", "        n_users = nb_users\n", "    if n_movies is None:\n", "        n_movies = nb_movies\n", "    in_range = (data[:, 0] >= 1) & (data[:, 0] <= n_users)\n", "    rows = data[in_range]\n", "    matrix = np.zeros((n_users, n_movies))\n", "    # One vectorised scatter instead of re-masking the whole array once per user.\n", "    matrix[rows[:, 0] - 1, rows[:, 1] - 1] = rows[:, 2]\n", "    return matrix.tolist()\n", "\n", "training_set = convert(training_set)\n", "test_set = convert(test_set)" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "u0Fk8Q0YCNZr", "colab_type": "code", "colab": {} }, "source": [ "# Converting the data into Torch tensors\n", "training_set = torch.FloatTensor(training_set)\n", "test_set = torch.FloatTensor(test_set)" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "q2arIJufDBYd", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 136 }, "outputId": "6f90b00a-799a-4db0-f1ac-82241e375329" }, "source": [ "training_set" ], "execution_count": 25, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[0., 3., 4., ..., 0., 0., 0.],\n", " [4., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [5., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 5., 0., ..., 0., 0., 0.]])" ] }, "metadata": { "tags": [] }, "execution_count": 25 } ] },
{ "cell_type": "markdown", "metadata": { "id": "wwWlLU5ODGgF", "colab_type": "text" }, "source": [ "# STEP 6" ] },
{ "cell_type": "code", "metadata": { "id": "m0KbLrFZDCjK", "colab_type": "code", "colab": {} }, "source": [ "# Converting the ratings into binary ratings 1 (Liked) or 0 (Not Liked); -1 marks an unrated movie\n", "def binarize(ratings):\n", "    \"\"\"Map star ratings to RBM visible states: 0 -> -1 (unrated), 1-2 -> 0, 3 and above -> 1.\n", "\n", "    Returns a new tensor and leaves the input untouched. Apply it once, to raw 0-5\n", "    ratings only; re-applying it to its own output scrambles the labels.\n", "    \"\"\"\n", "    binary = torch.zeros_like(ratings)\n", "    binary[ratings >= 3] = 1\n", "    binary[ratings == 0] = -1\n", "    return binary\n", "\n", "training_set = binarize(training_set)\n", "test_set = binarize(test_set)" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "zHEwVvTlD-DK", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 136 }, "outputId": "08d4e08c-6c99-4fa5-e528-ca995f016ed2" }, "source": [ "training_set" ], "execution_count": 27, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[-1., 1., 1., ..., -1., -1., -1.],\n", " [ 1., -1., -1., ..., -1., -1., -1.],\n", " [-1., -1., -1., ..., -1., -1., -1.],\n", " ...,\n", " [ 1., -1., -1., ..., -1., -1., -1.],\n", " [-1., -1., -1., ..., -1., -1., -1.],\n", " [-1., 1., -1., ..., -1., -1., -1.]])" ] }, 
"metadata": { "tags": [] }, "execution_count": 27 } ] },
{ "cell_type": "markdown", "metadata": { "id": "a4YrUGpDEFEV", "colab_type": "text" }, "source": [ "# STEP 7 - 10 Building RBM Object" ] },
{ "cell_type": "code", "metadata": { "id": "S3i8jV-RD_MV", "colab_type": "code", "colab": {} }, "source": [ "# Creating the architecture of the Neural Network\n", "# nv = number visible nodes, nh = number hidden nodes\n", "class RBM():\n", "    \"\"\"Bernoulli restricted Boltzmann machine updated with contrastive divergence.\n", "\n", "    Attributes:\n", "        W: (nh, nv) weight matrix, drawn from a standard normal.\n", "        a: (1, nh) hidden-unit bias, drawn from a standard normal.\n", "        b: (1, nv) visible-unit bias, drawn from a standard normal.\n", "    \"\"\"\n", "    def __init__(self, nv, nh):\n", "        self.W = torch.randn(nh, nv)\n", "        self.a = torch.randn(1, nh)\n", "        self.b = torch.randn(1, nv)\n", "\n", "    def sample_h(self, x):\n", "        \"\"\"Return (p(h=1 | v=x), Bernoulli sample of h) for a (batch, nv) visible tensor x.\"\"\"\n", "        wx = torch.mm(x, self.W.t())\n", "        activation = wx + self.a.expand_as(wx)\n", "        p_h_given_v = torch.sigmoid(activation)\n", "        return p_h_given_v, torch.bernoulli(p_h_given_v)\n", "\n", "    def sample_v(self, y):\n", "        \"\"\"Return (p(v=1 | h=y), Bernoulli sample of v) for a (batch, nh) hidden tensor y.\"\"\"\n", "        wy = torch.mm(y, self.W)\n", "        activation = wy + self.b.expand_as(wy)\n", "        p_v_given_h = torch.sigmoid(activation)\n", "        return p_v_given_h, torch.bernoulli(p_v_given_h)\n", "\n", "    def train(self, v0, vk, ph0, phk):\n", "        \"\"\"Apply one contrastive-divergence update, summed over the batch (unit learning rate).\n", "\n", "        v0/ph0 are the input visibles and their hidden probabilities; vk/phk are the\n", "        visibles after the Gibbs chain and their hidden probabilities.\n", "        \"\"\"\n", "        self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()\n", "        self.b += torch.sum((v0 - vk), 0)\n", "        self.a += torch.sum((ph0 - phk), 0)" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "GVmKKbizKFV2", "colab_type": "text" }, "source": [ "# STEP 11" ] },
{ "cell_type": "code", "metadata": { "id": "Itwi6_KlKGmf", "colab_type": "code", "colab": {} }, "source": [ "# Weight init and Bernoulli sampling are random: seed them so losses are repeatable.\n", "torch.manual_seed(0)\n", "nv = len(training_set[0])\n", "nh = 100\n", "batch_size = 100\n", "rbm = RBM(nv, nh)" ], "execution_count": 0, "outputs": [] },
{ "cell_type": "markdown", "metadata": { "id": "45nKkm5QK3hx", "colab_type": "text" }, "source": [ "# STEP 12-13" ] },
{ "cell_type": "code", "metadata": { "id": "NJ94UFahKrOw", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 187 }, "outputId": "841eb9e8-91ba-4f7b-cae3-b6370b9181fc" }, "source": [ "# Training the RBM\n", "nb_epochs = 10\n", "nb_gibbs_steps = 10  # k in CD-k\n", "for epoch in range(1, nb_epochs + 1):\n", "    train_loss = 0\n", "    s = 0.\n", "    # Visit every user: the final batch may hold fewer than batch_size users.\n", "    # (An upper bound of nb_users - batch_size would skip the last users entirely.)\n", "    for id_user in range(0, nb_users, batch_size):\n", "        v0 = training_set[id_user:id_user+batch_size]\n", "        vk = v0.clone()  # a copy, so no in-place write can reach training_set\n", "        ph0,_ = rbm.sample_h(v0)\n", "        for k in range(nb_gibbs_steps):\n", "            _,hk = rbm.sample_h(vk)\n", "            _,vk = rbm.sample_v(hk)\n", "            vk[v0<0] = v0[v0<0]  # keep unrated (-1) cells fixed; only rated cells are reconstructed\n", "        phk,_ = rbm.sample_h(vk)\n", "        rbm.train(v0, vk, ph0, phk)\n", "        train_loss += torch.mean(torch.abs(v0[v0>=0] - vk[v0>=0]))\n", "        s += 1.\n", "    print('epoch: '+str(epoch)+' loss: '+str(train_loss/s))" ], "execution_count": 39, "outputs": [ { "output_type": "stream", "text": [ "epoch: 1 loss: tensor(0.3424)\n", "epoch: 2 loss: tensor(0.2527)\n", "epoch: 3 loss: tensor(0.2509)\n", "epoch: 4 loss: tensor(0.2483)\n", "epoch: 5 loss: tensor(0.2474)\n", "epoch: 6 loss: tensor(0.2478)\n", "epoch: 7 loss: tensor(0.2467)\n", "epoch: 8 loss: tensor(0.2461)\n", "epoch: 9 loss: tensor(0.2482)\n", "epoch: 10 loss: tensor(0.2491)\n" ], "name": "stdout" } ] },
{ "cell_type": "markdown", "metadata": { "id": "qTvbF-u9NBdl", "colab_type": "text" }, "source": [ "# STEP 14" ] },
{ "cell_type": "code", "metadata": { "id": "RSlbxB8ZLoy9", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "ee26a78e-47ac-4efc-d36d-9048dbae6cd8" }, "source": [ "# Testing the RBM\n", "# One Gibbs step from each user's training ratings, scored only on movies the user rated in the test set.\n", "test_loss = 0\n", "s = 0.\n", "for id_user in range(nb_users):\n", "    v = training_set[id_user:id_user+1]\n", "    vt = test_set[id_user:id_user+1]\n", "    if len(vt[vt>=0]) > 0:\n", "        _,h = rbm.sample_h(v)\n", "        _,v = rbm.sample_v(h)\n", "        test_loss += torch.mean(torch.abs(vt[vt>=0] - v[vt>=0]))\n", "        s += 1.\n", "print('test loss: '+str(test_loss/s))" ], "execution_count": 40, "outputs": [ { "output_type": "stream", "text": [ "test loss: tensor(0.2403)\n" ], "name": "stdout" } ] } ] }