{ "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ "YearPredictionMSD\n", "\n", "Предсказание года выпуска аудиотрека. Features extracted from the 'timbre' features from The Echo Nest API. Выборка состоит из 90 признаков и 515345 объектов." ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "%matplotlib inline" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# The file has no header row: column 0 is the year, followed by\n", "# 12 'average' columns and 78 'covariance' columns (91 columns in total).\n", "column_names = (\n", "    ['year']\n", "    + ['average%d' % k for k in range(12)]\n", "    + ['covariance%d' % k for k in range(78)]\n", ")\n", "data = pd.read_csv(\"YearPredictionMSD.txt\", sep=\",\", header=None, names=column_names)" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
yearaverage0average1average2average3average4average5average6average7average8...covariance68covariance69covariance70covariance71covariance72covariance73covariance74covariance75covariance76covariance77
0200149.9435721.4711473.077508.74861-17.40628-13.09905-25.01202-12.232577.83089...13.01620-54.4054858.9936715.373441.11144-23.0879368.40795-1.82223-27.463482.26327
1200148.7321518.4293070.3267912.94636-10.32437-24.837778.76630-0.9201918.76548...5.66812-19.6807333.0496442.87836-9.90378-32.2278870.4938812.0494158.4345326.92061
2200150.9571431.8560255.8185113.41693-6.57898-18.54940-3.27872-2.3503516.07017...3.0380026.05866-50.9277910.93792-0.0756843.20130-115.00698-0.0585939.67068-0.66345
3200148.24750-1.8983736.297722.587760.97170-26.216835.05097-10.341243.55005...34.57337-171.70734-16.96705-46.67617-12.5151682.58061-72.089939.90558199.6297118.85382
4200150.9702042.2099867.099648.46791-15.85279-16.81409-12.48207-9.3763612.63699...9.92661-55.9572464.92712-17.72522-1.49237-7.5003551.766317.8871355.6692628.74903
\n", "

5 rows × 91 columns

\n", "
" ], "text/plain": [ " year average0 average1 average2 average3 average4 average5 average6 \\\n", "0 2001 49.94357 21.47114 73.07750 8.74861 -17.40628 -13.09905 -25.01202 \n", "1 2001 48.73215 18.42930 70.32679 12.94636 -10.32437 -24.83777 8.76630 \n", "2 2001 50.95714 31.85602 55.81851 13.41693 -6.57898 -18.54940 -3.27872 \n", "3 2001 48.24750 -1.89837 36.29772 2.58776 0.97170 -26.21683 5.05097 \n", "4 2001 50.97020 42.20998 67.09964 8.46791 -15.85279 -16.81409 -12.48207 \n", "\n", " average7 average8 ... covariance68 covariance69 covariance70 \\\n", "0 -12.23257 7.83089 ... 13.01620 -54.40548 58.99367 \n", "1 -0.92019 18.76548 ... 5.66812 -19.68073 33.04964 \n", "2 -2.35035 16.07017 ... 3.03800 26.05866 -50.92779 \n", "3 -10.34124 3.55005 ... 34.57337 -171.70734 -16.96705 \n", "4 -9.37636 12.63699 ... 9.92661 -55.95724 64.92712 \n", "\n", " covariance71 covariance72 covariance73 covariance74 covariance75 \\\n", "0 15.37344 1.11144 -23.08793 68.40795 -1.82223 \n", "1 42.87836 -9.90378 -32.22788 70.49388 12.04941 \n", "2 10.93792 -0.07568 43.20130 -115.00698 -0.05859 \n", "3 -46.67617 -12.51516 82.58061 -72.08993 9.90558 \n", "4 -17.72522 -1.49237 -7.50035 51.76631 7.88713 \n", "\n", " covariance76 covariance77 \n", "0 -27.46348 2.26327 \n", "1 58.43453 26.92061 \n", "2 39.67068 -0.66345 \n", "3 199.62971 18.85382 \n", "4 55.66926 28.74903 \n", "\n", "[5 rows x 91 columns]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You should respect the following train / test split:\n", "train: first 463,715 examples\n", "test: last 51,630 examples\n", "It avoids the 'producer effect' by making sure no song\n", "from a given artist ends up in both the train and test set." 
] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(463715, 51630)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Official split: the first TRAIN_SIZE rows are train, the remaining rows are test.\n", "TRAIN_SIZE = 463715\n", "train = data.iloc[:TRAIN_SIZE, :]\n", "test = data.iloc[TRAIN_SIZE:, :]\n", "len(train), len(test)" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Years are treated as num_years classes; class k corresponds to year first_year + k.\n", "first_year, last_year = 1922, 2011\n", "num_years = last_year - first_year + 1" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn" ] },
{ "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "# Single linear layer over the 90 features, followed by a softmax over the year classes.\n", "# The softmax must run along dim=1 (classes) so that every sample gets its own\n", "# probability distribution; dim=0 would normalise across the samples of a batch instead.\n", "model = nn.Sequential()\n", "model.add_module('l1', nn.Linear(90, num_years))\n", "model.add_module('smax', nn.Softmax(dim=1))" ] },
{ "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [], "source": [ "opt = torch.optim.Adam(model.parameters(), lr=1e-3)" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([249807, 398891, 389318, 385061, 425867]),\n", " array([297735, 435139, 453997, 407096, 194297]))" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# check that the sampled indices differ every time\n", "a = np.random.randint(0, len(train), 5)\n", "b = np.random.randint(0, len(train), 5)\n", "a, b" ] },
{ "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "# Features are every column after the first; the target is the first column (year).\n", "X = train.iloc[:, 1:].values\n", "Y = train.iloc[:, 0].values" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Peek at the first few labels only; displaying the whole array bloats the notebook.\n", "Y[:10]" ] }, {
"cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# One-hot encode the training years: row j has a 1 in column (year_j - first_year).\n", "# The encoding is derived from `train` instead of rewriting Y in place, so this cell\n", "# can be re-run safely, and it is vectorised instead of looping over every row in Python.\n", "year_index = train['year'].values - first_year\n", "assert year_index.min() >= 0 and year_index.max() < num_years, 'year outside [first_year, last_year]'\n", "Y_onehot = np.eye(num_years, dtype=np.float32)[year_index]" ] },
{ "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "Y = torch.tensor(Y_onehot, dtype=torch.float32)" ] },
{ "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step #0 | mean loss = 0.011\n", "step #1 | mean loss = 0.011\n", "step #2 | mean loss = 0.011\n", "step #3 | mean loss = 0.011\n", "step #4 | mean loss = 0.011\n", "step #5 | mean loss = 0.011\n", "step #6 | mean loss = 0.011\n", "step #7 | mean loss = 0.011\n", "step #8 | mean loss = 0.011\n", "step #9 | mean loss = 0.011\n" ] } ], "source": [ "history = []\n", "\n", "n_steps = 10\n", "batch_size = len(train) // 50  # 2% of the training rows per step (roughly 9k samples)\n", "for step in range(n_steps):\n", "    # sample batch_size random row indices (with replacement)\n", "    ix = np.random.randint(0, len(train), batch_size)\n", "    x_batch = torch.tensor(X[ix], dtype=torch.float32)\n", "    y_batch = Y[ix]  # Y is already a float32 tensor, so no re-wrapping is needed\n", "\n", "    # model output, shape (batch_size, num_years)\n", "    y_predicted = model(x_batch)\n", "\n", "    # mean squared error between the model output and the one-hot targets\n", "    loss = torch.mean((y_predicted - y_batch) ** 2)\n", "\n", "    loss.backward()  # accumulate gradients\n", "    opt.step()       # update weights\n", "    opt.zero_grad()  # clear gradients before the next step\n", "\n", "    history.append(loss.item())\n", "\n", "    # running mean over (at most) the last 10 recorded losses\n", "    print(\"step #%i | mean loss = %.3f\" % (step, np.mean(history[-10:])))" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }