{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "fastai_collab.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyN1v8T4yqV4jCEYOlmXCsEI",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/bipinKrishnan/fastai_course/blob/master/fastai_collab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "12UvxQcEEo4o"
      },
      "source": [
        "# The notebook uses the fastai v2 API (CollabDataLoaders, Learner, sigmoid_range); pin the major version.\n",
        "%pip install -q --upgrade \"fastai>=2.0,<3\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "neBLmF00EsVi"
      },
      "source": [
        "import os\n",
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "\n",
        "import fastai.collab as c\n",
        "import fastai.tabular.all as t\n",
        "\n",
        "# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All reproduces the\n",
        "# random data split, batch order and weight initialisation (presumably all drawn from these).\n",
        "SEED = 42\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That gives accurate\n",
        "# tracebacks for device-side errors but slows training, so it is opt-in. It only takes\n",
        "# effect if set before CUDA is initialised, which is why the switch lives in this first cell.\n",
        "DEBUG_CUDA = False\n",
        "if DEBUG_CUDA:\n",
        "    os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\""
      ],
      "execution_count": 25,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pZHSdM3ZE3XT",
        "outputId": "e5ea3cdf-5399-4219-9a42-d0b5eb451d02",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "# Download and extract the MovieLens 100k archive, then list the extracted files.\n",
        "path = c.untar_data(url=c.URLs.ML_100k)\n",
        "path.ls()"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(#23) [Path('/root/.fastai/data/ml-100k/ub.base'),Path('/root/.fastai/data/ml-100k/u5.base'),Path('/root/.fastai/data/ml-100k/u5.test'),Path('/root/.fastai/data/ml-100k/u.info'),Path('/root/.fastai/data/ml-100k/u2.base'),Path('/root/.fastai/data/ml-100k/u3.base'),Path('/root/.fastai/data/ml-100k/ua.test'),Path('/root/.fastai/data/ml-100k/u4.test'),Path('/root/.fastai/data/ml-100k/mku.sh'),Path('/root/.fastai/data/ml-100k/u.genre')...]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ldPlun7NFBqw",
        "outputId": "99e6eadd-d2dd-4e69-9d49-da0d5ab40f1b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# u.data is tab-separated with no header row, so the column names are supplied here.\n",
        "df = pd.read_csv(\n",
        "    path/'u.data',\n",
        "    sep='\\t',\n",
        "    names=['user', 'movie', 'rating', 'timestamp'],\n",
        ")\n",
        "df.head()"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>user</th>\n",
              "      <th>movie</th>\n",
              "      <th>rating</th>\n",
              "      <th>timestamp</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>196</td>\n",
              "      <td>242</td>\n",
              "      <td>3</td>\n",
              "      <td>881250949</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>186</td>\n",
              "      <td>302</td>\n",
              "      <td>3</td>\n",
              "      <td>891717742</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>22</td>\n",
              "      <td>377</td>\n",
              "      <td>1</td>\n",
              "      <td>878887116</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>244</td>\n",
              "      <td>51</td>\n",
              "      <td>2</td>\n",
              "      <td>880606923</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>166</td>\n",
              "      <td>346</td>\n",
              "      <td>1</td>\n",
              "      <td>886397596</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   user  movie  rating  timestamp\n",
              "0   196    242       3  881250949\n",
              "1   186    302       3  891717742\n",
              "2    22    377       1  878887116\n",
              "3   244     51       2  880606923\n",
              "4   166    346       1  886397596"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SnFX7W8NF3K2",
        "outputId": "f79c3180-7c71-42db-c167-66f8183e9910",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 225
        }
      },
      "source": [
        "# .info() prints its report and returns None, so call it on its own line\n",
        "# rather than inside the displayed tuple.\n",
        "df.info()\n",
        "df.shape"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "<class 'pandas.core.frame.DataFrame'>\n",
            "RangeIndex: 100000 entries, 0 to 99999\n",
            "Data columns (total 4 columns):\n",
            " #   Column     Non-Null Count   Dtype\n",
            "---  ------     --------------   -----\n",
            " 0   user       100000 non-null  int64\n",
            " 1   movie      100000 non-null  int64\n",
            " 2   rating     100000 non-null  int64\n",
            " 3   timestamp  100000 non-null  int64\n",
            "dtypes: int64(4)\n",
            "memory usage: 3.1 MB\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(None, (100000, 4))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LHL4rnnPGf2E",
        "outputId": "a323bdaa-2c89-49ac-8816-f749451c6294",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# u.item is '|'-separated and latin-1 encoded; only the first two columns (id, title) are kept.\n",
        "movies = pd.read_csv(\n",
        "    path/'u.item',\n",
        "    sep='|',\n",
        "    encoding='latin-1',\n",
        "    usecols=(0, 1),\n",
        "    names=['movie', 'title'],\n",
        ")\n",
        "movies.head()"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>movie</th>\n",
              "      <th>title</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>1</td>\n",
              "      <td>Toy Story (1995)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>2</td>\n",
              "      <td>GoldenEye (1995)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>3</td>\n",
              "      <td>Four Rooms (1995)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>4</td>\n",
              "      <td>Get Shorty (1995)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>5</td>\n",
              "      <td>Copycat (1995)</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   movie              title\n",
              "0      1   Toy Story (1995)\n",
              "1      2   GoldenEye (1995)\n",
              "2      3  Four Rooms (1995)\n",
              "3      4  Get Shorty (1995)\n",
              "4      5     Copycat (1995)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pL6bEndiIzyT",
        "outputId": "026ecd26-0e1c-4854-b968-0b1ff5d21773",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# Attach titles to the ratings via the shared `movie` column (inner join).\n",
        "# validate='many_to_one' raises if a movie id occurs more than once in `movies`,\n",
        "# which would otherwise silently duplicate ratings.\n",
        "df = df.merge(movies, validate='many_to_one')\n",
        "df.head()"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>user</th>\n",
              "      <th>movie</th>\n",
              "      <th>rating</th>\n",
              "      <th>timestamp</th>\n",
              "      <th>title</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>196</td>\n",
              "      <td>242</td>\n",
              "      <td>3</td>\n",
              "      <td>881250949</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>63</td>\n",
              "      <td>242</td>\n",
              "      <td>3</td>\n",
              "      <td>875747190</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>226</td>\n",
              "      <td>242</td>\n",
              "      <td>5</td>\n",
              "      <td>883888671</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>154</td>\n",
              "      <td>242</td>\n",
              "      <td>3</td>\n",
              "      <td>879138235</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>306</td>\n",
              "      <td>242</td>\n",
              "      <td>5</td>\n",
              "      <td>876503793</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "   user  movie  rating  timestamp         title\n",
              "0   196    242       3  881250949  Kolya (1996)\n",
              "1    63    242       3  875747190  Kolya (1996)\n",
              "2   226    242       5  883888671  Kolya (1996)\n",
              "3   154    242       3  879138235  Kolya (1996)\n",
              "4   306    242       5  876503793  Kolya (1996)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YXShxk83Jmnm",
        "outputId": "19e3c40e-aa0a-4cb9-8f28-446988ea5ee7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 363
        }
      },
      "source": [
        "# Random 80/20 train/validation split. Fixing the seed keeps the split, and therefore\n",
        "# the validation losses compared below, reproducible across runs.\n",
        "dls = c.CollabDataLoaders.from_df(\n",
        "    df,\n",
        "    valid_pct=0.2,\n",
        "    user_name='user',\n",
        "    item_name='title',\n",
        "    rating_name='rating',\n",
        "    seed=42,\n",
        "    bs=32\n",
        ")\n",
        "\n",
        "dls.show_batch()"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>user</th>\n",
              "      <th>title</th>\n",
              "      <th>rating</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>303</td>\n",
              "      <td>Shall We Dance? (1937)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>880</td>\n",
              "      <td>Get Shorty (1995)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>355</td>\n",
              "      <td>Kolya (1996)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>334</td>\n",
              "      <td>Sound of Music, The (1965)</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>826</td>\n",
              "      <td>Striking Distance (1993)</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>942</td>\n",
              "      <td>It Happened One Night (1934)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>289</td>\n",
              "      <td>Time to Kill, A (1996)</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>405</td>\n",
              "      <td>My Left Foot (1989)</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>452</td>\n",
              "      <td>Fantasia (1940)</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>747</td>\n",
              "      <td>Thirty-Two Short Films About Glenn Gould (1993)</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NTOLvhrqKgs0",
        "outputId": "7cedb26d-ecb2-44f6-bb0a-f436a752ef96",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 72
        }
      },
      "source": [
        "# Vocabulary of each categorical column; the '#na#' entry sits at index 0 and counts towards each length.\n",
        "dls.classes"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'title': (#1665) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...],\n",
              " 'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...]}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_lneJ6HaKqZR"
      },
      "source": [
        "# Table sizes come from the vocab lengths (which include the '#na#' entry).\n",
        "n_users, n_mov = len(dls.classes['user']), len(dls.classes['title'])\n",
        "n_factors = 5  # deliberately tiny: these random tables are used below to illustrate the one-hot lookup\n",
        "\n",
        "user_factors, mov_factors = torch.randn(n_users, n_factors), torch.randn(n_mov, n_factors)"
      ],
      "execution_count": 33,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BZRvoW1eLL02",
        "outputId": "999f3bdc-f7fc-4ba0-e4ed-fa8588018a99",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Expect (n_users, n_factors) and (n_mov, n_factors).\n",
        "user_factors.shape, mov_factors.shape"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([944, 5]), torch.Size([1665, 5]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 34
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SDCVJkEkLXBX",
        "outputId": "9a482c2a-9b4c-4889-c82d-f9638a7df559",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Multiplying the (transposed) factor matrix by a one-hot vector selects a single row,\n",
        "# i.e. it is equivalent to indexing -- compare with user_factors[3] in the next cell.\n",
        "torch.matmul(user_factors.t(), t.one_hot(3, n_users).float())"
      ],
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([-0.6902,  0.9260, -1.2527,  0.5396,  0.1056])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 35
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qaAy_vwAMQGG",
        "outputId": "e566925f-4ba7-415f-e29f-1023f4b31c10",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Plain indexing returns the same vector as the one-hot product above.\n",
        "user_factors[3]"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([-0.6902,  0.9260, -1.2527,  0.5396,  0.1056])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 36
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cz5KC59ASGch"
      },
      "source": [
        "# Baseline model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7XsV_cbKM0Ez"
      },
      "source": [
        "class DotProduct(torch.nn.Module):\n",
        "    \"\"\"Baseline matrix-factorisation model: score = dot(user vector, movie vector).\n",
        "\n",
        "    ``x`` holds (user index, movie index) pairs in columns 0 and 1.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_users, n_movies, n_factors):\n",
        "        super().__init__()\n",
        "        # One learnable n_factors-dimensional vector per user and per movie.\n",
        "        self.user_factors = c.Embedding(n_users, n_factors)\n",
        "        self.movie_factors = c.Embedding(n_movies, n_factors)\n",
        "\n",
        "    def forward(self, x):\n",
        "        user_idx, movie_idx = x[:, 0], x[:, 1]\n",
        "        u = self.user_factors(user_idx)\n",
        "        m = self.movie_factors(movie_idx)\n",
        "        # Row-wise dot product: one score per example, shape (batch,).\n",
        "        return (u * m).sum(dim=1)"
      ],
      "execution_count": 37,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WLLPAeKZPanH",
        "outputId": "d1319df4-c0f9-4c46-a226-558297ee7177",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Fetch one batch and inspect both halves, instead of drawing two separate batches.\n",
        "xb, yb = dls.one_batch()\n",
        "xb.shape, yb.shape"
      ],
      "execution_count": 38,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(torch.Size([32, 2]), torch.Size([32, 1]))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 38
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZnGFkHQ6Pofs"
      },
      "source": [
        "# 50 latent factors per user/movie. MSELossFlat flattens predictions and targets,\n",
        "# so the model's (batch,) output can be compared with the (batch, 1) targets.\n",
        "model = DotProduct(n_users=n_users, n_movies=n_mov, n_factors=50)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())"
      ],
      "execution_count": 39,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-UM8TmTOQsxV",
        "outputId": "bff9753b-57a4-4160-ef09-34bca03e11cb",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# 5 epochs of the one-cycle schedule with a maximum learning rate of 5e-3.\n",
        "learn.fit_one_cycle(n_epoch=5, lr_max=5e-3)"
      ],
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.280856</td>\n",
              "      <td>1.324047</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.139197</td>\n",
              "      <td>1.145741</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.951527</td>\n",
              "      <td>1.011723</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.780049</td>\n",
              "      <td>0.897130</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.747144</td>\n",
              "      <td>0.874317</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "APF-yfBMSeWG"
      },
      "source": [
        "# Adding sigmoid range"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z3LOXdCcRHgu"
      },
      "source": [
        "class DotProduct(torch.nn.Module):\n",
        "    \"\"\"Dot-product model whose scores are squashed into ``y_range`` by ``sigmoid_range``.\n",
        "\n",
        "    A sigmoid never reaches its asymptotes, so the default upper bound (5.5) sits a little\n",
        "    above the highest rating (presumably 5) to let predictions get close to it.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
        "        super().__init__()\n",
        "        self.user_factors = c.Embedding(n_users, n_factors)\n",
        "        self.movie_factors = c.Embedding(n_movies, n_factors)\n",
        "        self.y_range = y_range\n",
        "\n",
        "    def forward(self, x):\n",
        "        u = self.user_factors(x[:, 0])\n",
        "        m = self.movie_factors(x[:, 1])\n",
        "        low, high = self.y_range\n",
        "        return c.sigmoid_range((u * m).sum(dim=1), low, high)"
      ],
      "execution_count": 41,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SEpfIhH2Qzef",
        "outputId": "23177f6d-af02-4827-c628-d1d90afe3300",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# Same recipe as the baseline (50 factors, 5 epochs, lr_max=5e-3) so the losses are comparable.\n",
        "model = DotProduct(n_users=n_users, n_movies=n_mov, n_factors=50)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())\n",
        "learn.fit_one_cycle(n_epoch=5, lr_max=5e-3)"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.038298</td>\n",
              "      <td>1.009943</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.880899</td>\n",
              "      <td>0.922624</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.685187</td>\n",
              "      <td>0.892175</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.453028</td>\n",
              "      <td>0.901179</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.308590</td>\n",
              "      <td>0.908446</td>\n",
              "      <td>00:16</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DaC8i0nVSNYU"
      },
      "source": [
        "# Adding user and movie bias terms"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PH2yzzsPSk87"
      },
      "source": [
        "class DotProduct(torch.nn.Module):\n",
        "    \"\"\"Dot-product model plus a learned scalar bias per user and per movie.\n",
        "\n",
        "    The biases are added to the dot product before it is squashed into ``y_range``\n",
        "    with ``sigmoid_range``. ``x`` holds (user index, movie index) pairs in columns 0 and 1.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
        "        super().__init__()\n",
        "        self.user_factors = c.Embedding(n_users, n_factors)\n",
        "        self.user_bias = c.Embedding(n_users, 1)\n",
        "\n",
        "        self.movie_factors = c.Embedding(n_movies, n_factors)\n",
        "        self.movie_bias = c.Embedding(n_movies, 1)\n",
        "\n",
        "        self.y_range = y_range\n",
        "\n",
        "    def forward(self, x):\n",
        "        user_idx, movie_idx = x[:, 0], x[:, 1]\n",
        "        # keepdim=True keeps shape (batch, 1) so the dot product lines up with the bias embeddings.\n",
        "        dot = (self.user_factors(user_idx) * self.movie_factors(movie_idx)).sum(dim=1, keepdim=True)\n",
        "        bias = self.user_bias(user_idx) + self.movie_bias(movie_idx)\n",
        "        return c.sigmoid_range(dot + bias, *self.y_range)"
      ],
      "execution_count": 43,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FeTFDuhBSk9O",
        "outputId": "99da5355-1941-4e0c-81df-7b5d90888b26",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# Same recipe as before, now with the bias-augmented DotProduct.\n",
        "model = DotProduct(n_users=n_users, n_movies=n_mov, n_factors=50)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())\n",
        "learn.fit_one_cycle(n_epoch=5, lr_max=5e-3)"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>0.981148</td>\n",
              "      <td>0.945925</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.842618</td>\n",
              "      <td>0.882195</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.604646</td>\n",
              "      <td>0.913617</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.407241</td>\n",
              "      <td>0.941966</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.237349</td>\n",
              "      <td>0.950696</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rpLDE7RwSMvZ",
        "outputId": "8783ea29-ce75-401d-b01d-8ea78964025d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# Same model and schedule, adding weight decay (wd=0.1) to counter the rising\n",
        "# validation loss seen in the previous run.\n",
        "model = DotProduct(n_users=n_users, n_movies=n_mov, n_factors=50)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())\n",
        "learn.fit_one_cycle(n_epoch=5, lr_max=5e-3, wd=0.1)"
      ],
      "execution_count": 45,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>0.979508</td>\n",
              "      <td>0.941654</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.858030</td>\n",
              "      <td>0.898452</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.798919</td>\n",
              "      <td>0.857021</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.672152</td>\n",
              "      <td>0.827143</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.545060</td>\n",
              "      <td>0.824844</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bVaSDYAaWP3z"
      },
      "source": [
        "## Embedding layer from scratch"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RTnq89LkVh4w"
      },
      "source": [
        "def create_params(size, std=0.01):\n",
        "  \"\"\"Return a trainable tensor of shape `size` drawn from N(0, std**2).\n",
        "\n",
        "  Wrapping it in `nn.Parameter` means that assigning it as a Module attribute\n",
        "  registers it, so it shows up in `.parameters()` and is optimised.\n",
        "  `std` defaults to the original hard-coded 0.01.\n",
        "  \"\"\"\n",
        "  return torch.nn.Parameter(torch.zeros(*size).normal_(0, std))"
      ],
      "execution_count": 46,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wg_XfhQ3WkXv",
        "outputId": "81e90c5f-9045-4403-ff01-8a22323087a8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        }
      },
      "source": [
        "create_params((3, 4))"
      ],
      "execution_count": 47,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Parameter containing:\n",
              "tensor([[ 0.0043, -0.0059, -0.0025, -0.0032],\n",
              "        [ 0.0016, -0.0125,  0.0049,  0.0115],\n",
              "        [ 0.0234,  0.0013, -0.0026, -0.0069]], requires_grad=True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 47
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1CYDjgIGWmu3"
      },
      "source": [
        "class DotProduct(torch.nn.Module):\n",
        "  \"\"\"Biased dot-product recommender built from raw parameter tables.\n",
        "\n",
        "  `x[:, 0]` holds user indices and `x[:, 1]` movie indices; the prediction is\n",
        "  squashed into `y_range` with `sigmoid_range`.\n",
        "  \"\"\"\n",
        "  def __init__(self, n_users, n_movies, n_factors, y_range=(0, 5.5)):\n",
        "    super().__init__()\n",
        "    # assignment order fixes both the RNG draw order and the order of .parameters()\n",
        "    self.user_factors = create_params((n_users, n_factors))\n",
        "    self.user_bias = create_params((n_users, 1))\n",
        "    self.movie_factors = create_params((n_movies, n_factors))\n",
        "    self.movie_bias = create_params((n_movies, 1))\n",
        "    self.y_range = y_range\n",
        "\n",
        "  def forward(self, x):\n",
        "    user_idx, movie_idx = x[:, 0], x[:, 1]\n",
        "    interaction = (self.user_factors[user_idx] * self.movie_factors[movie_idx]).sum(dim=1, keepdim=True)\n",
        "    bias = self.user_bias[user_idx] + self.movie_bias[movie_idx]\n",
        "    return c.sigmoid_range(interaction + bias, *self.y_range)"
      ],
      "execution_count": 48,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ivmygidxW8LV",
        "outputId": "c1c399a1-04f7-46e9-83b4-ddb75d8c09d3",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        }
      },
      "source": [
        "# one factor matrix and one bias column per side: users first, then movies\n",
        "for param in DotProduct(3, 4, 5).parameters():\n",
        "  print(param.shape)"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "torch.Size([3, 5])\n",
            "torch.Size([3, 1])\n",
            "torch.Size([4, 5])\n",
            "torch.Size([4, 1])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_X1lQzkmXCSX",
        "outputId": "387a8399-a372-4cf3-afc3-2250c8064d12",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# retrain using the from-scratch DotProduct defined above\n",
        "model = DotProduct(n_users, n_mov, 50)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())\n",
        "learn.fit_one_cycle(5, lr_max=5e-3, wd=0.1)"
      ],
      "execution_count": 50,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>0.934559</td>\n",
              "      <td>0.959005</td>\n",
              "      <td>00:19</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.847594</td>\n",
              "      <td>0.889536</td>\n",
              "      <td>00:19</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.797565</td>\n",
              "      <td>0.851402</td>\n",
              "      <td>00:19</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.709136</td>\n",
              "      <td>0.822152</td>\n",
              "      <td>00:19</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.585161</td>\n",
              "      <td>0.821672</td>\n",
              "      <td>00:19</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZzbSdATHYBi8",
        "outputId": "8dcb5ca1-06e6-4cba-b20b-04b6aadc0894",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        }
      },
      "source": [
        "# smallest movie bias: rated low even after accounting for user/movie factor match\n",
        "b = learn.model.movie_bias.squeeze()\n",
        "idx = b.argsort()[:5]\n",
        "[dls.classes['title'][i] for i in idx.tolist()]"
      ],
      "execution_count": 51,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['Children of the Corn: The Gathering (1996)',\n",
              " 'Mortal Kombat: Annihilation (1997)',\n",
              " 'Crow: City of Angels, The (1996)',\n",
              " 'Island of Dr. Moreau, The (1996)',\n",
              " 'Lawnmower Man 2: Beyond Cyberspace (1996)']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 51
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WLnDvYYgaFKn",
        "outputId": "697abb6f-30d5-412e-add3-59c336225c9c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        }
      },
      "source": [
        "# largest movie bias: rated high independent of the latent-factor match\n",
        "b = learn.model.movie_bias.squeeze()\n",
        "idx = b.argsort(descending=True)[:5]\n",
        "[dls.classes['title'][i] for i in idx.tolist()]"
      ],
      "execution_count": 52,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['Titanic (1997)',\n",
              " \"Schindler's List (1993)\",\n",
              " 'Star Wars (1977)',\n",
              " 'Silence of the Lambs, The (1991)',\n",
              " 'L.A. Confidential (1997)']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 52
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6yyEP2jFeusk"
      },
      "source": [
        "## Training with fastai's `collab_learner`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sDYjMvOccvcH"
      },
      "source": [
        "# fastai's built-in model (EmbeddingDotBias): user/item weight and bias embeddings\n",
        "learn = c.collab_learner(dls, n_factors=50, y_range=(0, 5.5))"
      ],
      "execution_count": 53,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0e489As1e9Th",
        "outputId": "873671c4-34d7-460b-d460-6360169fb5e8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        }
      },
      "source": [
        "learn.model, learn.loss_func"
      ],
      "execution_count": 54,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(EmbeddingDotBias(\n",
              "   (u_weight): Embedding(944, 50)\n",
              "   (i_weight): Embedding(1665, 50)\n",
              "   (u_bias): Embedding(944, 1)\n",
              "   (i_bias): Embedding(1665, 1)\n",
              " ), FlattenedLoss of MSELoss())"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 54
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZIJroJerfATp",
        "outputId": "c543e862-4d14-4ade-c0bf-f4560212eb69",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# same lr as the DotProduct runs above (5e-5 was a typo and left the model badly under-trained)\n",
        "learn.fit_one_cycle(5, 5e-3, wd=0.1)"
      ],
      "execution_count": 55,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.894737</td>\n",
              "      <td>1.829872</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.807373</td>\n",
              "      <td>1.755659</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.684563</td>\n",
              "      <td>1.698134</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.673523</td>\n",
              "      <td>1.668491</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.664241</td>\n",
              "      <td>1.663463</td>\n",
              "      <td>00:18</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_Omm2PYTfIvz",
        "outputId": "fdcd73d5-6a79-4df8-9f59-c762909095cd",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "dls.classes['title'].o2i['2 Days in the Valley (1996)']"
      ],
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "6"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 56
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "buDN2R2Qf_mE",
        "outputId": "ac686ec7-63db-43d0-c0d9-49d99edb8921",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "movie_emb = learn.model.i_weight.weight\n",
        "# look the query movie up by title instead of hard-coding its index\n",
        "query_idx = dls.classes['title'].o2i['2 Days in the Valley (1996)']\n",
        "sims = torch.nn.CosineSimilarity(dim=1)(movie_emb[query_idx][None], movie_emb)\n",
        "# position 0 is the query movie itself (similarity 1), so take the runner-up\n",
        "sims.argsort(descending=True)[1]"
      ],
      "execution_count": 57,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor(37, device='cuda:0')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 57
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DZl4FVkYgtK8",
        "outputId": "0291cda3-c609-4ba1-c577-3e2eb955ac45",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "movie_emb = learn.model.i_weight.weight\n",
        "query_idx = dls.classes['title'].o2i['2 Days in the Valley (1996)']\n",
        "sims = torch.nn.CosineSimilarity(dim=1)(movie_emb[query_idx][None], movie_emb)\n",
        "# derive the nearest neighbour here; the hard-coded 686 did not match the index the previous cell returned\n",
        "dls.classes['title'][sims.argsort(descending=True)[1]]"
      ],
      "execution_count": 58,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'Highlander (1986)'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 58
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1NyNymNplOn2"
      },
      "source": [
        "## Using deep learning"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Qtj1sDoWktPs"
      },
      "source": [
        "class CollabNN(c.Module):\n",
        "  \"\"\"Neural collaborative filter: concatenate user and item embeddings, then an MLP.\n",
        "\n",
        "  `user_sz` / `item_sz` are (n_embeddings, embedding_dim) pairs such as those from\n",
        "  `get_emb_sz`; the single output is squashed into `y_range` with `sigmoid_range`.\n",
        "  \"\"\"\n",
        "  def __init__(self, user_sz, item_sz, n_act=100, y_range=(0, 5.5)):\n",
        "    super().__init__()\n",
        "    self.user_factors = c.Embedding(*user_sz)\n",
        "    self.movie_factors = c.Embedding(*item_sz)\n",
        "    n_in = user_sz[1] + item_sz[1]  # width of the concatenated embeddings\n",
        "    self.layers = torch.nn.Sequential(\n",
        "        torch.nn.Linear(n_in, n_act),\n",
        "        torch.nn.ReLU(),\n",
        "        torch.nn.Linear(n_act, 1),\n",
        "    )\n",
        "    self.y_range = y_range\n",
        "\n",
        "  def forward(self, x):\n",
        "    user_emb = self.user_factors(x[:, 0])\n",
        "    movie_emb = self.movie_factors(x[:, 1])\n",
        "    hidden = torch.cat([user_emb, movie_emb], dim=1)\n",
        "    return c.sigmoid_range(self.layers(hidden), *self.y_range)"
      ],
      "execution_count": 59,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Xrj70IfSm04W",
        "outputId": "1dc3be0a-c23f-49be-a635-ef3cd64a5bba",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# fastai's default embedding sizes: one (n_categories, dim) pair per categorical column\n",
        "sz = c.get_emb_sz(dls)\n",
        "sz"
      ],
      "execution_count": 60,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[(944, 74), (1665, 102)]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 60
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JUe26GT4m78Z"
      },
      "source": [
        "model = CollabNN(*sz)"
      ],
      "execution_count": 61,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oSFfmWHJnLKX",
        "outputId": "93dc3396-1890-41e8-a703-81b13f363480",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# note: lighter weight decay (0.01) than the dot-product runs (0.1)\n",
        "learn = c.Learner(dls, model, loss_func=c.MSELossFlat())\n",
        "learn.fit_one_cycle(5, lr_max=5e-3, wd=0.01)"
      ],
      "execution_count": 62,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>0.942369</td>\n",
              "      <td>0.967605</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.902668</td>\n",
              "      <td>0.912845</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.868656</td>\n",
              "      <td>0.884016</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.841839</td>\n",
              "      <td>0.872575</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.764020</td>\n",
              "      <td>0.873379</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VAYV4lo6qGjV",
        "outputId": "5bea2e98-95be-4343-8918-6137a3dc4481",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# use_nn=True builds fastai's EmbeddingNN: embeddings -> hidden layers of 10 then 50 units\n",
        "learn = c.collab_learner(dls, use_nn=True, layers=[10, 50], y_range=(0, 5.5))\n",
        "learn.fit_one_cycle(5, lr_max=5e-3, wd=0.1)"
      ],
      "execution_count": 63,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>0.978604</td>\n",
              "      <td>0.979922</td>\n",
              "      <td>00:23</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.958372</td>\n",
              "      <td>0.930927</td>\n",
              "      <td>00:24</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.900048</td>\n",
              "      <td>0.904019</td>\n",
              "      <td>00:23</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.877367</td>\n",
              "      <td>0.876002</td>\n",
              "      <td>00:23</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.816585</td>\n",
              "      <td>0.871966</td>\n",
              "      <td>00:24</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DttVouzZqwY1",
        "outputId": "855f0e14-7deb-4701-99cd-6de2d73d4efa",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 434
        }
      },
      "source": [
        "learn.model"
      ],
      "execution_count": 64,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "EmbeddingNN(\n",
              "  (embeds): ModuleList(\n",
              "    (0): Embedding(944, 74)\n",
              "    (1): Embedding(1665, 102)\n",
              "  )\n",
              "  (emb_drop): Dropout(p=0.0, inplace=False)\n",
              "  (bn_cont): BatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "  (layers): Sequential(\n",
              "    (0): LinBnDrop(\n",
              "      (0): BatchNorm1d(176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (1): Linear(in_features=176, out_features=10, bias=False)\n",
              "      (2): ReLU(inplace=True)\n",
              "    )\n",
              "    (1): LinBnDrop(\n",
              "      (0): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
              "      (1): Linear(in_features=10, out_features=50, bias=False)\n",
              "      (2): ReLU(inplace=True)\n",
              "    )\n",
              "    (2): LinBnDrop(\n",
              "      (0): Linear(in_features=50, out_features=1, bias=True)\n",
              "    )\n",
              "    (3): SigmoidRange(low=0, high=5.5)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 64
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5P_LC2p0rW9J",
        "outputId": "c98e77a9-f757-4ad5-f321-c883e4dd4e45",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173
        }
      },
      "source": [
        "CollabNN((944, 74), (1665, 102))"
      ],
      "execution_count": 65,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "CollabNN(\n",
              "  (user_factors): Embedding(944, 74)\n",
              "  (movie_factors): Embedding(1665, 102)\n",
              "  (layers): Sequential(\n",
              "    (0): Linear(in_features=176, out_features=100, bias=True)\n",
              "    (1): ReLU()\n",
              "    (2): Linear(in_features=100, out_features=1, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 65
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "arUEK0RbvuI8",
        "outputId": "2c3f11db-f504-44fd-bd46-8b70278d99b2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# inspect fastcore's `delegates` decorator (reachable via fastai.tabular.all) before it is used on EmbeddingNN\n",
        "t.delegates"
      ],
      "execution_count": 66,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<function fastcore.meta.delegates>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 66
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WKWbh8YKsArO"
      },
      "source": [
        "@t.delegates(t.TabularModel)\n",
        "class EmbeddingNN(t.TabularModel):\n",
        "  \"\"\"TabularModel specialised for collaborative filtering.\n",
        "\n",
        "  Fixes `n_cont=0` (no continuous inputs) and `out_sz=1` (one predicted value);\n",
        "  every other keyword argument is forwarded to TabularModel, and `@delegates`\n",
        "  makes those arguments visible in this class's signature.\n",
        "  \"\"\"\n",
        "  def __init__(self, emb_sz, layers, **kwargs):\n",
        "    super().__init__(emb_sz, n_cont=0, out_sz=1, layers=layers, **kwargs)"
      ],
      "execution_count": 67,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y22fCzLzuIAt"
      },
      "source": [
        "### `**kwargs` and `*args`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cp4Rd-D7wy1I"
      },
      "source": [
        "`**kwargs` collects any extra keyword arguments into a dict:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xQnyy_I0s3MR"
      },
      "source": [
        "def s(a, **kwargs):\n",
        "  \"\"\"Demo: keyword arguments not matching a named parameter arrive as the dict `kwargs`.\"\"\"\n",
        "  extra = kwargs\n",
        "  return extra"
      ],
      "execution_count": 68,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wRH7RrBjtIsJ",
        "outputId": "d272d756-cd43-4d55-f6cf-445ca1b6ae9c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "s(2, s='dhdj', j=3)"
      ],
      "execution_count": 69,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'j': 3, 's': 'dhdj'}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 69
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T4C4MTfhwxYJ"
      },
      "source": [
        "`*args` collects any extra positional arguments into a tuple:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hFbqzi1Et_Yb"
      },
      "source": [
        "def a(b, *size):\n",
        "  \"\"\"Demo: positional arguments after `b` are packed into the tuple `size`.\"\"\"\n",
        "  result = (b, size)\n",
        "  return result"
      ],
      "execution_count": 70,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F99TCHZVwjNz",
        "outputId": "43bef199-62e5-47bf-9176-279bed4d6b0b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "a(2, 3, 'klk', 5)"
      ],
      "execution_count": 71,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2, (3, 'klk', 5))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 71
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "voovs7v5yLfn",
        "outputId": "d5fd900c-1807-4de2-daf4-879f91e572ca",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 332
        }
      },
      "source": [
        "learn.show_results()"
      ],
      "execution_count": 72,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              ""
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>user</th>\n",
              "      <th>title</th>\n",
              "      <th>rating</th>\n",
              "      <th>rating_pred</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>621</td>\n",
              "      <td>1366</td>\n",
              "      <td>2</td>\n",
              "      <td>4.168081</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>577</td>\n",
              "      <td>1498</td>\n",
              "      <td>4</td>\n",
              "      <td>3.707471</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>846</td>\n",
              "      <td>920</td>\n",
              "      <td>4</td>\n",
              "      <td>4.199018</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>793</td>\n",
              "      <td>785</td>\n",
              "      <td>2</td>\n",
              "      <td>2.995941</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>336</td>\n",
              "      <td>407</td>\n",
              "      <td>3</td>\n",
              "      <td>3.256781</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>394</td>\n",
              "      <td>1544</td>\n",
              "      <td>5</td>\n",
              "      <td>4.228620</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>345</td>\n",
              "      <td>1336</td>\n",
              "      <td>4</td>\n",
              "      <td>4.130409</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>815</td>\n",
              "      <td>1581</td>\n",
              "      <td>4</td>\n",
              "      <td>4.287453</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>99</td>\n",
              "      <td>552</td>\n",
              "      <td>4</td>\n",
              "      <td>4.037339</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fpl_W8kG_It8"
      },
      "source": [
        "## Using cross entropy loss"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P6_WKrdT-WnY"
      },
      "source": [
        "# shift ratings down by one so they can be used as 0-based class ids for cross entropy\n",
        "# (assumes the raw ratings start at 1 — TODO confirm against df['rating'].min())\n",
        "df1 = df.assign(rating=df['rating'] - 1)"
      ],
      "execution_count": 73,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FA9yUSor-VgU",
        "outputId": "c5c34ea5-8333-450f-dec7-5d7e80d6290e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 363
        }
      },
      "source": [
        "# Same user/title/rating columns as before, but the target is now a 0-4 class index.\n",
        "# A fixed seed makes the random train/valid split reproducible across re-runs.\n",
        "dls = c.CollabDataLoaders.from_df(\n",
        "    df1,\n",
        "    user_name='user',\n",
        "    item_name='title',\n",
        "    rating_name='rating',\n",
        "    seed=42,\n",
        "    bs=32\n",
        ")\n",
        "\n",
        "dls.show_batch()"
      ],
      "execution_count": 74,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>user</th>\n",
              "      <th>title</th>\n",
              "      <th>rating</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>891</td>\n",
              "      <td>Courage Under Fire (1996)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>806</td>\n",
              "      <td>Raiders of the Lost Ark (1981)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>790</td>\n",
              "      <td>Romeo Is Bleeding (1993)</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>311</td>\n",
              "      <td>Last of the Mohicans, The (1992)</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>790</td>\n",
              "      <td>Emma (1996)</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>151</td>\n",
              "      <td>Crimson Tide (1995)</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>489</td>\n",
              "      <td>Devil's Advocate, The (1997)</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>314</td>\n",
              "      <td>Pallbearer, The (1996)</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>411</td>\n",
              "      <td>Rear Window (1954)</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>483</td>\n",
              "      <td>Powder (1995)</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8Aa0_9dL0yhn"
      },
      "source": [
        "class CollabNN(c.Module):\n",
        "  \"\"\"Neural collaborative-filtering classifier.\n",
        "\n",
        "  Concatenates a user embedding and an item embedding and passes the result\n",
        "  through a one-hidden-layer MLP that outputs one logit per rating class.\n",
        "\n",
        "  Args:\n",
        "    user_sz: (n_users, n_factors) used to build the user embedding table.\n",
        "    item_sz: (n_items, n_factors) used to build the item embedding table.\n",
        "    n_act: width of the hidden layer.\n",
        "    n_out: number of output classes (default 5, one per shifted 0-4 rating).\n",
        "  \"\"\"\n",
        "  def __init__(self, user_sz, item_sz, n_act=100, n_out=5):\n",
        "    super().__init__()\n",
        "    self.user_factors = c.Embedding(*user_sz)\n",
        "    self.movie_factors = c.Embedding(*item_sz)\n",
        "    self.layers = torch.nn.Sequential(\n",
        "        torch.nn.Linear(user_sz[1]+item_sz[1], n_act),\n",
        "        torch.nn.ReLU(),\n",
        "        torch.nn.Linear(n_act, n_out)\n",
        "    )\n",
        "\n",
        "  def forward(self, x):\n",
        "    \"\"\"Map a batch of (user_idx, item_idx) pairs to raw class logits.\n",
        "\n",
        "    Column 0 of x indexes the user table and column 1 the item table. No softmax\n",
        "    is applied here; the logits are meant to be fed to a cross-entropy loss.\n",
        "    \"\"\"\n",
        "    embs = self.user_factors(x[:, 0]), self.movie_factors(x[:, 1])\n",
        "    return self.layers(torch.cat(embs, dim=1))"
      ],
      "execution_count": 75,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1zl2oS8607PD"
      },
      "source": [
        "# Recompute embedding sizes from the dls built just above: an `sz` left over from an\n",
        "# earlier cell reflects the previous DataLoaders' vocab, which may not match this one.\n",
        "sz = t.get_emb_sz(dls)\n",
        "model = CollabNN(*sz)\n",
        "learn = c.Learner(dls, model, loss_func=c.CrossEntropyLossFlat(), metrics=c.accuracy)"
      ],
      "execution_count": 76,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_4TX0Hvt0_EU",
        "outputId": "39180f1e-4a9c-4471-b6c0-3cfacf32e1d6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        }
      },
      "source": [
        "# One-cycle schedule: 5 epochs, peak learning rate 5e-3, weight decay 0.01.\n",
        "learn.fit_one_cycle(n_epoch=5, lr_max=5e-3, wd=0.01)"
      ],
      "execution_count": 77,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: left;\">\n",
              "      <th>epoch</th>\n",
              "      <th>train_loss</th>\n",
              "      <th>valid_loss</th>\n",
              "      <th>accuracy</th>\n",
              "      <th>time</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>0</td>\n",
              "      <td>1.284872</td>\n",
              "      <td>1.289579</td>\n",
              "      <td>0.418500</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>1.248881</td>\n",
              "      <td>1.264122</td>\n",
              "      <td>0.440350</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.215094</td>\n",
              "      <td>1.232578</td>\n",
              "      <td>0.447150</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.162452</td>\n",
              "      <td>1.226760</td>\n",
              "      <td>0.456100</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.091269</td>\n",
              "      <td>1.234871</td>\n",
              "      <td>0.454250</td>\n",
              "      <td>00:20</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    }
  ]
}