{ "cells": [ { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import math\n", "import scipy.special,scipy.linalg\n", "import numpy as np\n", "import time\n", "from matplotlib import pyplot as plt\n", "from sklearn.datasets import fetch_mldata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate Data (MNIST or Gaussian mixture)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def gen_data(testcase,Tr,Te,prop,means=None,covs=None):\n", " rng = np.random\n", " \n", " if testcase is 'MNIST':\n", " mnist=fetch_mldata('MNIST original')\n", " X,y = mnist.data,mnist.target\n", " X_train_full, X_test_full = X[:60000], X[60000:]\n", " y_train_full, y_test_full = y[:60000], y[60000:]\n", "\n", " selected_target = [7,9]\n", " K=len(selected_target)\n", " X_train = np.array([]).reshape(p,0)\n", " X_test = np.array([]).reshape(p,0) \n", " \n", " y_train = []\n", " y_test = []\n", " ind=0\n", " for i in selected_target:\n", " locate_target_train = np.where(y_train_full==i)[0][range(np.int(prop[ind]*Tr))]\n", " locate_target_test = np.where(y_test_full==i)[0][range(np.int(prop[ind]*Te))]\n", " X_train = np.concatenate( (X_train,X_train_full[locate_target_train].T),axis=1)\n", " y_train = np.concatenate( (y_train,2*(ind-K/2+.5)*np.ones(np.int(Tr*prop[ind]))) )\n", " X_test = np.concatenate( (X_test,X_test_full[locate_target_test].T),axis=1)\n", " y_test = np.concatenate( (y_test,2*(ind-K/2+.5)*np.ones(np.int(Te*prop[ind]))) )\n", " ind+=1 \n", " \n", " X_train = X_train - np.mean(X_train,axis=1).reshape(p,1)\n", " X_train = X_train*np.sqrt(784)/np.sqrt(np.sum(X_train**2,(0,1))/Tr)\n", " \n", " X_test = X_test - np.mean(X_test,axis=1).reshape(p,1)\n", " X_test = X_test*np.sqrt(784)/np.sqrt(np.sum(X_test**2,(0,1))/Te)\n", " \n", " else:\n", " X_train = np.array([]).reshape(p,0)\n", " X_test = np.array([]).reshape(p,0) \n", " y_train = []\n", " y_test = []\n", " K = len(prop)\n", " for i in range(K): \n", " X_train = np.concatenate((X_train,rng.multivariate_normal(means[i],covs[i],size=np.int(Tr*prop[i])).T),axis=1)\n", " X_test = np.concatenate((X_test, rng.multivariate_normal(means[i],covs[i],size=np.int(Te*prop[i])).T),axis=1)\n", " y_train = np.concatenate( (y_train,2*(i-K/2+.5)*np.ones(np.int(Tr*prop[i]))) )\n", " y_test = np.concatenate( (y_test,2*(i-K/2+.5)*np.ones(np.int(Te*prop[i]))) ) \n", " \n", " X_train = X_train/math.sqrt(p)\n", " X_test = X_test/math.sqrt(p)\n", " \n", " return X_train, X_test, y_train, y_test\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate $\\sigma(\\cdot)$ activation functions" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def gen_sig(fun,Z,polynom=None):\n", " \n", " if fun is 'poly2':\n", " sig = polynom[0]*Z**2+polynom[1]*Z+polynom[2]\n", " elif fun is 'ReLu':\n", " sig = np.maximum(Z,0)\n", " elif fun is 'sign':\n", " sig = np.sign(Z)\n", " elif fun is 'posit':\n", " sig = (Z>0).astype(int)\n", " elif fun is 'erf':\n", " sig = scipy.special.erf(Z)\n", " elif fun is 'cos':\n", " sig = np.cos(Z)\n", " elif fun is 'abs':\n", " sig = np.abs(Z)\n", " \n", " return sig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate matrices $\\Phi_{AB}$" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def gen_Phi(fun,A,B,polynom=None,distrib=None,nu=None):\n", " normA = np.sqrt(np.sum(A**2,axis=0))\n", " normB = np.sqrt(np.sum(B**2,axis=0))\n", " \n", " AB = A.T @ B\n", " angle_AB = np.minimum( (1/normA).reshape((len(normA),1)) * AB * (1/normB).reshape( (1,len(normB)) ) ,1.)\n", " \n", " if fun is 'poly2':\n", " mom = {'gauss': [1,0,3],'bern': [1,0,1],'bern_skewed': [1,-2/math.sqrt(3),7/3],'student':[1,0,6/(nu-4)+3]}\n", " A2 = A**2\n", " B2 = B**2\n", " Phi = polynom[0]**2*(mom[distrib][0]**2*(2*AB**2+(normA**2).reshape((len(normA),1))*(normB**2).reshape((1,len(normB))) )+(mom[distrib][2]-3*mom[distrib][0]**2)*(A2.T@B2))+polynom[1]**2*mom[distrib][0]*AB+polynom[1]*polynom[0]*mom[distrib][1]*(A2.T@B+A.T@B2)+polynom[2]*polynom[0]*mom[distrib][0]*( (normA**2).reshape( (len(normA),1) )+(normB**2).reshape( (1,len(normB)) ) )+polynom[2]**2\n", " \n", " elif fun is 'ReLu':\n", " Phi = 1/(2*math.pi)* normA.reshape((len(normA),1)) * (angle_AB*np.arccos(-angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n", " \n", " elif fun is 'abs':\n", " Phi = 2/math.pi* normA.reshape((len(normA),1)) * (angle_AB*np.arcsin(angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n", " \n", " elif fun is 'posit':\n", " Phi = 1/2-1/(2*math.pi)*np.arccos(angle_AB)\n", " \n", " elif fun is 'sign':\n", " Phi = 1-2/math.pi*np.arccos(angle_AB)\n", " \n", " elif fun is 'cos':\n", " Phi = np.exp(-.5*( (normA**2).reshape((len(normA),1))+(normB**2).reshape((1,len(normB))) ))*np.cosh(AB)\n", " \n", " elif fun is 'erf':\n", " Phi = 2/math.pi*np.arcsin(2*AB/np.sqrt((1+2*(normA**2).reshape((len(normA),1)))*(1+2*(normB**2).reshape((1,len(normB))))))\n", "\n", " return Phi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate $E_{\\rm train}$ and $E_{\\rm test}$" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def gen_E_th():\n", " d=0\n", " dt=-1\n", "\n", " while np.abs(d-dt)>1e-6:\n", " dt=d\n", " d=np.mean(L/(L*n/Tr/(1+d)+gamma))\n", " \n", " L_psi = L*n/Tr/(1+d)\n", " L_bQ = 1/(L_psi+gamma)\n", "\n", " # E_train\n", " E_train_th = gamma**2*np.mean(Uy_train**2*L_bQ**2*(1/n*np.sum(L_psi*L_bQ**2)/(1-1/n*np.sum(L_psi**2*L_bQ**2))*L_psi+1)) \n", " \n", " #E_test\n", " E_test_th = np.mean((y_test-n/Tr/(1+d)*UPhi_cross.T@(L_bQ*Uy_train))**2)+(1/n*np.sum(Uy_train**2*L_psi*L_bQ**2))/(1-1/n*np.sum(L_psi**2*L_bQ**2))*(np.mean((n/Tr/(1+d))*D_Phi_test)-Tr/Te*np.mean( (n/Tr/(1+d))**2*D_UPhi_cross2*(1+gamma*L_bQ)*L_bQ))\n", " \n", " return E_train_th,E_test_th\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Main code" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time for Theoretical Computation 0min 1s\n", "Time for Simulations Computation 0min 16s\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdcAAAEACAYAAADhvzxWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8U1XaB/DfSUtbyr4vBdpXoEAFAalUZxhBEAQFFMuA\nUkVAQQTGFwRGpYgIgqIMCoKUvSgFYRCYsqgviizjArZYsOyLlEVK2VqWLjTNef84RNI2KUmb5N4k\nv+/nk0/bm5vkuZTk6Tn33OcRUkoQERGR8xi0DoCIiMjbMLkSERE5GZMrERGRkzG5EhERORmTKxER\nkZMxuRIRETkZkysREZGTMbkSERE5GZMrERGRkzG5EhEROZm/1gGUpGbNmjIsLEzrMIiI9CsrC/L4\nCQhI3EIA/sAtXJJSaB2Wr9N1cg0LC0NSUpLWYRAR6YuUwNdfI3/iZJQ7fhyn0Agz8AaqIBPfYoLW\n0RF0nlyJiMjC7aRaMGky/JL24A8RimliETLC/4IZR55CMxxDpNYxEgCecyUi0j8pga++ginqQeDx\nx3FubwZewiKM63UUo1NfwobDEbj4yts46xeqdaR0m9Bzy7nIyEjJaWEi8jkJCUBsLJCWBtSuDVmx\nEsTJEzjjH4Z3jLE49beBePeDADz4YPGHCiGSpZQcwGqM08JERHqSkADj4KHwz8+BBCAyMvBHhj/e\nxiKktBiIqTMC0L07IBxYspScnFzb399/MYCW4Iyls5gApBqNxpfatWuXUfROJlciIh258eqbqJif\nAwBIR128hanYhQ74Z8X5WJgSAEMpUqO/v//iunXrtqhVq9ZVg8Gg3+lKD2IymcTFixcj0tPTFwPo\nXfR+Xf4FI4ToJYRYmJWVpXUoRETus38/Klw5AxOAyXgbD2AP2iAFv+E+DL7xSakS620ta9WqdY2J\n1XkMBoOsVatWFtRsQPH73RyPXaSUG6WUw6pUqaJ1KEREriclsHAhZFQUrqA6umAb0hCKQ4jAKMxD\nAPJxGo3K8goGJlbnu/1vajWP6jK5EhH5jGvXgAEDgJdfxg9+DyMCB/AENmEphqASbgAAbiIYs2pM\n0zhQcgSTKxGRVvbuBdq1g2nNv/FO4HT0CfwKL02oiwPl7kcaQmGCwCmEYlS5hYiaHeO2sOLiUL1+\nfbQyGNCufn20iotD9bI+p5+fX7vmzZtHmG8TJkyoa2vfKVOm1L5+/brD+Wn06NH1N2zYUKlskToH\nFzQREbmblMC8eZBjx+JGUC08btqO68074JcNQFgYkBARg06xMTh9GmjUCJg2DYhxU26Ni0P1MWMQ\nmpurBl/nzyNgzBiEAsDw4bhS2ucNDAw0HT58+KA9+y5YsKDO0KFDr1SqVMlU9D6j0Qh/f+up6+OP\nP/6jtPE5G0euRETulJkJ9O0L/OMf+LVmV/zPtRSE9O+AH35QiRVQifTUKcBkUl/dlVgBYMoUhJgT\nq1luLgxTpiDEHa//7rvv1s7IyCjXsWPH8KioqHAACA4Objt06NAGzZo1i/juu+8qjhs3rl7Lli1b\nNG3a9N5nn3021GRSOTg6Ojps2bJl1QAgJCSk1ZgxY+pHRES0CA8Pj/j111+D3BG/GZMrEZG77NkD\ntG0LmZiIWfU/xAN/JGL8+zWxahVQoYLWwSnp6QhwZLu98vLyDJbTwosWLapmbb+JEydm1K5dO3/H\njh1Hd+/efRQAcnJyDFFRUTePHDly8LHHHrsxfvz4jNTU1EPHjh07kJOTY/jiiy+srn6tWbOm8eDB\ng4eGDBly8f33369TlvgdxWlhIiJXsay0VK0acO0acms2wFMVduHnmw9i0xagRw+tgyysbl3cOn++\neCKtWxe3yvK8jkwLF+Xn54dBgwZdNf/81VdfVZo1a1bd3NxcQ2Zmpn9EREQOgGLXbg4YMOAqALRv\n3z47MTHRajJ3FY5ciYhcISEBGDZMJVYAuHoVJ0xhaJCRjLR6D2LPHv0lVgCYNAnngoJQ6FxnUBBM\nkybhnFYxBQQEmMznWbOzs8XYsWND161bd+Lo0aMHn3vuuUu5ublWc1lQUJAEAH9/f2k0Gt3aho/J\nlYjIFWJjgexsAIAEsBLPoIk8hr8E7cXu3UB4uLbh2TJ8OK589BHS6tXDLSGAevVw66OPkFaWxUyO\nqlChQkFWVpbV/JSdnW0AgLp16xqzsrIMGzdudOuI1F6cFiYicgGZlgYBoAAGvIjFWI7BeAtT8HbO\nO/CrXKB1eCUaPhxXnJ1MzedczT937tw569NPP7U6Gn7hhRcude/ePbxOnTq3zOddzWrWrFkQExNz\nsUWLFvfWqlXL2Lp165vOjNNZdNkVRwjRC0CvJk2aDD127JjW4RAROebwYRS0aAkJoAu+RTIi8RkG\n4mmsx1m/UDQwnnLZS1vrirNv375TrVu3vuSyF/Vh+/btq9m6deuwott1OS3M8odE5LEOHwYeeQTX\nUAmP4WscQgR+wkN4GutxE8F4vYCVlnwBp4WJiJzl8GGgUyeYJNDR/0ccMjbBl3ga9+IATiEUEzAN\nP4a68aJVnevatWvjM2fOBFpumzZt2tno6OhrWsXkLEyuRETOcOgQ8MgjkACGhH2Pg5dbwD8QeDJv\n45+7BAcDCzlw/dPWrVtPaB2Dq+hyWpiIyKNYJNbX2m7H8j0tsHw5sGQJEBqqGpuHhgILF7q32hJp\nhyNXIqKyOHgQ6NwZUghM6/I9Pk5ojpkz7yRRJlPfxJErEVFp3U6sEAILn/kebyU0x2uvAWPHah0Y\naY3JlYioNA4eBB55BBAC61/9HsM/bo6YGODDD7UOjPSAyZWIyFEHDqjEajBg++Tt+PtbzdGtG7B0\nKWDgp6pVeuvn+uWXX1Y2xxIcHNw2LCysZfPmzSP69OkTNmfOnBoDBw5s5OjrW+I5VyIiRxw4oKaC\n/fyQ8tH3eHxwM7RpA6xdCwSUqW+Md9NbP9fo6Ohr0dHRBwGgffv2zWbOnHnm4YcfzgaAOXPm1LD3\neWxhciUiuhvL7jYGA1C5Mk6u/BldnmuG+vWBLVuASnaNl7Q3ZAgapqYi2JnP2bIlspcuxRlnPJdl\nP9dq1aoZd+/efTQ4OLhtTEzMxZ07d1aeM2fO6a1bt1b6+uuvq+bl5RkiIyNvJCQkpBkMBkRHR4f1\n7Nkza/DgwVdDQkJa9evX7/I333xTxWg0itWrV59s27Ztrr1xpKenl/vb3/7W9PTp04E9evTIjIuL\nO+vIcXACg4ioJEW725hMuJBdEY883wD+/sA33wC1a2sboifwtH6uBw8eDN6wYcPJQ4cOHUhMTKx2\n/Pjxco48XpcjV4vawlqHQkS+zqK7DQBcQ0V0vbUZV65I7EgCGjfWMLZScNYI01Ge1s+1Q4cO12rU\nqFEAAE2aNMk9ceJEYJMmTfLtfbwuR66sLUxEumEesQLIQzn0wiYcRnOsl31w//0axuVDtOjnGhAQ\n8GdXGz8/P5mfn+/Q43WZXImIdCE3FwhUpW8LIDAAq7ATHfEZBuLRUHbschX2cyUi8lYmEzB4MJCX\nh1wEYCxmYR2i8RFGoxc24r+PL0QHrWP0IOznqiORkZEyKSlJ6zCIyBfFxgLTp+P9qu9jfWZH7MGD\nGIcPMRLz/uxuc+qU1kEWx36u7mWrnytHrkRERS1dCkyfDgwdijcX/ROAOt02E+MxE+MBAOK0hvGR\n7jG5EhFZ2roVePll4LHHcGPGPPgvEzAai+/WqEz1ewhgP1ciIt+Qmgr07Qu0aAG5eg1eGVUOBQVq\nTVNe3p3dgoOBaezLWmbs50pE5O3++AN4/HGgYkVg82YsXVsZK1YA77zDvqzkOI5ciYhu3AB69QKu\nXAF27cJvmQ0xahTQpQswYQLg58dkSo5hciUi31ZQADz7LJCSAmzciBtN26LfA0DVqqryoZ+f1gGS\nJ+K0MBH5LimB0aOBTZuATz6B7PE4RowAjhxRibWOQ9VoqSTuaDkHAJ9//nnV5OTkoNJH6hxMrkTk\nu2bPBubOBcaOBUaMQHw88PnnwNtvq65yPisurjrq128Fg6Ed6tdvhbi46mV9SnNtYfNt+vTp6bb2\nXbBgQZ0bN26UKj9t2LCh6v79+8uXPlLnYHIlIt+0fj3w2mtAdDTwwQc4cAAYOVIl1YkTtQ5OQ3Fx\n1TFmTCjOnw+AlMD58wEYMybUGQnWHpYt56KiosIBYN26dZXbtGnTPCIiokWPHj3uMZdGHDFiREjj\nxo3vDQ8Pjxg2bFiDrVu3Vvj222+rTpw4sUHz5s0jDhw4EFjyq7kOKzQRke/Zswfo1Am47z7g++9x\n01QeDzyg1jOlpAB1bU5Y6l+ZKzTVr98K588Xb/ter94t/PHHb6WNy8/Pr13Tpk1zzD+PHTv2/NCh\nQ69a2zckJKRVUlLSoXr16hnPnz/v36tXr8bbtm07VrlyZVNsbGzdvLw8MW7cuIwHH3ywxcmTJ1MN\nBgMuXbrkV7NmzQLLnq6ljdURHlWhiS3niMglEhKAf/5TXXbj7w8MHAiUL4+Rg4DDh1X9CE9OrE6R\nnl48sZa03U6lbTm3ffv2CidOnAhq3759cwDIz88X7dq1u1GjRo2CwMBAU//+/cN69uyZ2b9//2It\n57Sky2lhtpwjIqdLSACGDlWJFQCMRmD8eMS//BOWLwfeektdeuPz6ta95dB2F5NSokOHDtfM52pP\nnDhxYM2aNWnlypVDSkrKob59+17dtGlT1U6dOjXVIj5bdJlciYicbsIEICen0KaD2aEYuag1OnUC\nJk3SJizdmTTpHIKCTIW2BQWZMGmS1Q42rmDZcq5Tp043k5KSKqampgYCwLVr1wz79+8PzMrKMly5\ncsWvf//+WXFxcWcOHz4cDAAVK1YsuHbtmua5TZfTwkRETmUyAacLV9q/iWD8Hf9GRXkdK1cG83pW\ns+HDrwAApkwJQXp6AOrWvYVJk879ub2UytJybsGCBaeeeeaZe27duiUA4O233z5XpUoVU8+ePZvk\n5eUJAJg6deoZAIiJibnyyiuvhMXFxdVZu3btiXvvvTfP2mu4Ghc0EZF3M1/LOmdOoc1DsATxGIRv\nag9E1wsrNArO+dhyzr1sLWjSfOhMRORS//qXSqw9eqiK+wA+w/NYhiGI9f8AXWf10DhA8kacFiYi\n77VqFTB+PNC/P7ByJf47ahXyFyzFK6b5iMJudHmxERAzQOsofRZbzhEReZpt24AXXgA6dgSWL0fC\nKgOGxscgx6Qq8O9GFJ74PAoL/8ai/FphyzkiIk+yfz/Qpw8QHg5s2AAEBiI2tthiYWRnA7Gx2oRI\n3o3JlYi8y+nT6vxqpUrAV1+p9jYA0tJs707kbJwWJiLvcfWqSqw3bwK7dgENGwIAzp4FDAZ1RU5R\njRq5OUbyCUyuROQdcnOBJ58Ejh8HvvkGaNUKgGrX+txzQLlygBBqN7PgYGDaNI3iJa/GaWEi8nwm\nE/D882q0+tlnqij/be+/D+zYAcyfDyxeDISGqiQbGgosXMjFTO6idT/X119/va75tS1jeffdd2tH\nR0eHLVu2rFppXs8WjlyJyLNJqVrHrV2rrmnt3//Pu376SfVm7d8fGDRIJVUmU204Urh/wYIFdYYO\nHXqlUqVKVibyS7Zhw4aqRqMxq127drmW22fMmJE+Y8aMdAAIDg5uaxlLdHR0mKOvczccuRKRZ5s1\nSzU9HzNGJdnbsrKAAQPUade4OJVYSf+06ue6Y8eOim3btm3eoEGDVs4YxXLkSkSeJyFBXUNjXgIc\nFQXMnPnn3VICr7wCnDmjZopvLxgmABgypCFSU4Od+pwtW2Zj6dIzJe1StLawrX6uEydOzJg/f36d\nHTt2HDX3c50+fXq9nTt3HjX3c506dWqdcePGZWzZsqVa0X6ujz76aGZp+rleuHChXFJS0uGUlJSg\nPn36NClrP1gmVyLyLAkJwLBh6iJVs/37VTWm23O+n32mfpw6FXjoIY3ipEL03s+1d+/emX5+fmjX\nrl3u5cuXy5XluQAmVyLyNLGxhRMroKpDxMYCMTE4ehQYOVIVZnrzTW1C1LW7jDD1xtzPdePGjb8X\nvS8lJeVQYmJi5bVr11abP39+7Z9//vloaV8nKCjozy42zmhoo8tzrkKIXkKIhVlZumosT0R6UEI1\niFu31HnWgADg88/BNnIeyhv6uWoegDVSyo1SymFVqlTROhQi0pPVq23f16gRYmOB5GRgyZI/60eQ\nTpjPuZpvI0aMCLG1r7mfa1RUVHj9+vWN5n6u4eHhEZGRkc1/++23oMzMTL/u3bs3DQ8Pj3jooYea\nWfZznTNnTt0WLVo4tKDJ2djPlYg8w8KFwPDhQLNmavRqWSg4OBj/949EPDajC4YPV9e0+ir2c3Uv\n9nMlIs/1wQfAyy+r0oZ79wKLFhWqBpHx4XIMjO+CiAh1qSuR1vSdXJOTgbAwtTrQHgkJan+DwbHH\nEZE+SalWJb3+OvDMM8D69UD58mpV8KlTgMkE+fspDN7cF5mZwBdf/NkPnTxA165dG1tOFTdv3jzi\nyy+/rKx1XM6g/9XCaWnAiy8CJ06oFlKBgUBQ0J2vQUFq9cLKlYWX56elqZ+Bu5dkMV8zd/q0quI9\nbRrLuBBpzWRSy37j4tSodd48qyuU5swBtmwBPvnkz3LC5CG8uZ+rvs+5CiHLfMY1KAjo3Vu1n7J2\nS04GPv0UyMu785jgYPuKjjIpE7lGfr6qV7hypRq1vvee1RJLv/4KPPgg8NhjwH/+wypMgM1zridb\ntWp11WAw6PcD3wOZTCbx22+/VWvduvU9Re/T/8jV0tq1Kgnm5t75av7+nXesPyY3F9i3D7h+Xd1u\n3FBTTSXJzlZv7Llzgdq11a1WrTvf164N/PKLukLdvKjCkZEyEdmWkwP06wds2qSS6htvFNslIUHN\nFp85owazjz/OxHoXqRcvXoyoVatWFhOsc5hMJnHx4sUqAFKt3e85yTU0FIiOtn1/fLz1699CQ4HD\nh+/8bDKp5GlOts2aWX8+o1GNbNPSVCLNyFC9q0qSnQ2MGqVGy02aAI0bAxUrFt6Ho10i265dUzNN\nO3eqJb/DhxfbpWiBpoICYOxY9XblW8k6o9H4Unp6+uL09PSW0PtaG89hApBqNBpfsnanZ0wL2zNN\na60kmj2PCwuznZRPnbrzs8kEZGaqJJuRocq/2KNOHZVkmzRRDZw3bgRu3XIsRiJfcOmSWg2ckqLq\nFz77rNXd7H3L+ipr08LkfvpPrqGh9o/uSjMqdHZSbtgQ2LBBLcA6flx9NX9/9qz156pWDUhMBO6/\nn0sdybdYFuAvd7uc6/r1wBNP2HyIrelfIdTfwL6OyVUnpJS6vbVr1066xYoVUoaGSimE+rpihX2P\nCQ6WUp3BVbfg4JIfa7mvtZufn5Rt2kg5bJiUixdLuX+/lEZj6WMk0jNr76HAwBL/b6enq7eJtbdP\naKj7QtczAElSB5/fvn7T98hV7xWaHB0p2xrtNmigzi/t3g3s2aNumZnqvgoV1Gj4+HF1HtiM08nk\n6Ryc383PBx59FPj5Z3Upe65FK2y+He7gyFUfmFzdyd4paClVMjUn2wULCp+nNatTBzh3jtXJyfPk\n5qpiENbYmN999VV1LWtCgnqLcF2gdUyu+sDk6m6lOS9sMNi+fKhGDbUIpGdPdbEfu0KT3qWmqtY1\nv/1m/X4rI9f4eGDwYGDMGGDWLJdH6NGYXPWBS7LdzaJsG06dsu/P7UaNrG+vWVMt/PjqK1UarlYt\noHNn9elz9HZbQ5aEJL2QUl07HhmpVtyPH198AV9wsPqD08Ivv6grcjp3ViWGiTyC1id9S7q5bUGT\n3t1t8ZTRKOUPP0j55ptStmx5Z5+6daX097f9OCJ3uXBByieeUP8Hn3hC/SzlXRfqXbggZYMG6q6L\nF90dtGcCFzTp4sZpYU/hyHTyqVPA5s3AuHGFV32YNWpku+E0kbN9/bWqeJaZCcycqeoF21FOybyA\nac8e4Icf1JVqdHecFtYHTgt7Ckemk8PC1AeYZb1kS6dPq/m18+ddECjRbbm5wOjRak1ArVpAUpKq\nYGZnncJx41ShpkWLmFjJ8zC5ejNb52oDA1Ux9AYN1Dlbc81mImc5cACIigJmzwb+8Q81/GzZ0u6H\nf/aZ6nYzejTw3HMujJPIRZhcvdm0adYXjCxZAhw5ohLsvn3A3/8O1K+vPgSTk+/e2ICoKPPCOSHU\nCva2bYH0dHV6Ys4c25fdWJGcrK5Y69QJ+PBDl0VM5FJMrt4sJkZdQxsaqj70QkPvXFMbHg5Mn67O\nvX79NdCtm5p/i4wEWrdW+zRsyFXGdHfm67fN5/GvXFHV9N9+W7WrcUBGhmrbXKcOsGYN4O85rUWI\nCuGCJrrj6lVg9Wq16OREkR7G5cur5Msr9amo0FB1Ht/adgcq6efnA127qtopXMBUelzQpA+6HLkK\nIXoJIRZmZWVpHYpvqVZNXVBoWWbRLCdHTRubyzISSQmsW2c9sQK2t9swfjywY4eaXGFiJU+ny+Qq\npdwopRxWpUoVrUPxTbY+FK9eBUJCVAK2VV2HfMNPPwEdOqgey+ZuNkXZWlBnwfJU7ezZqsjY8887\nN1QiLegyuZLGbH0o1q2rKkEtXw7cd5/qabtmjZrPI99w/DjQty/wl78AJ0+qYebixXZVWiqq6Kla\nQF16w9P75A2YXKk4W6uMZ85UK43PnlXLOM+cAfr3V0OPKVPUdbMst+idLl1SlfNbtFAL4CZPBo4d\nA4YOBQYOtL1wrgSxsYV7WADq7ENsrOsOg8hduKCJrLOnIlRBgfqgnTtXfTV35ykouLMPe4F5tpwc\nNV/73nvAjRsqmU6erGYxyohNz12DC5r0gcmVnOPYMaBdO+D69eL3ObhqlDRk/qMqLU1dryqlurSm\nVy9gxgw1cnWCa9dU0SZrnRT536VsmFz1gdPC5BxNm6qRjTVpaVwA5QkSEtTI1HwS9PJltTp8wgQg\nMdFpiTU3F3jqKbUoPTCw8H12nKol8ghMruQ8Ja0Ove8+tbo0IYGlFvXoxg11TjUnp/B2k8mp582N\nRrUm7vvvVYnDJUscPlVL5BGYXMl5bC2Emj9fLYa6cEEVim3QQJVePHlSmzjpjpMngbFj1e/kyhXr\n+zh4vaotJhPw0kvAf/4DfPKJSqKlaW9M5AmYXMl5bJVbHD5cfYAfOQJs3Qo8/DDwr38BjRsD3bur\nT9vPPuMqY3eREti2DXjySaBJE1X7t0cP24uU7Lhe1Z6XHDtWXcX1zjuqOQ6RV9O6oWxJNzZL92Jn\nz0o5ebKUISFSAqpZNpu6u9bNm1IuWCDlvfeqf+OaNaWMjVW/CynVv3dwsEt+D1Onqqf73/+V0mQq\n89NRCcBm6bq4aR5ASTcmVx+Qny9lrVqy0Ae6+daggdbRea4VK6QMDVV/tISESPnEE1JWq6b+Xdu0\nkXLZMilzckp+XGioUxLr3LnqZQcOlLKgoMxPR3fB5KqPGy/FIe0ZDLbb3D3xhFoB07s3ULmye+Py\nVOZVv0UXJ7Vvr6bj//pXuxuWl9XKlepsQe/ewJdfssuNO/BSHH3gOVfSnq1zepUrq0t4nn9e9SDr\n2xf497+Ll/UhJTsbWL8eePnl4okVUAvKOnRwW2LdvBl44QXVl3X1aiZW8i1MrqQ9W6uMP/0U+P13\n4McfVRHaH34A+vUDatdWw6GNG4H4eN9eCHXliloM1qcPULMm8PTTwM2b1vd10qpfe+zapf4Wat1a\nrVcLCnLbSxPpg9bz0iXdeM7Vh9hzrs9olHLbNimHDZOyRg1p9Tytty2EsvbvcvaslPPmSdmli5R+\nfuq4Q0KkHDlSyu++k7JRI+v/NqGhbgl5714pK1eWslkzKTMy3PKSZAE856qLG8+5kmfKz1fXZmZk\nFL+vUiV1CVCHDmofT2VuG2M5DW4w3Cm826yZGrH26QNERqr7bD3OxTWeLasmGgyqNfCvvwING7rk\n5agEPOeqD5wWJs9Urhxw8aL1+65fB559Vn2yh4WpwhVxcUBq6p3EpNfuPVKq6dvERGDkyOLnl00m\noGpV4OBB4PBhVVC/ffs7iRWwfb2xCxOrZes4k0nNTO/c6ZKXI/IIHLmS5woLK9wM1KxRI7Ww57//\nVbddu4D0dHVf1arqcQcOFO5Da+/Izp5uQfY+rl8/4NAhICWl8O3q1ZKfS2dtY2z9GliAXxscueoD\nkyt5LnunP6VUC6PMyTY+3nqD96AgtQqnXj11q1+/8PcbNtz99aRUrV5u3lT1em/cANauBaZPL1xT\nWQjVos9ovPPa990HtGkDtG2rvvbrp3rmFqWzrMXWcfrC5KoPTK7k2UozkizputqwMNX03VpzASGs\nP87fX5UONCdTc8K8m8qV1XR1mzaqq1DRa1U0OHfqqFOnVOjWDllnfwP4DCZXfeCVZ+TZzNXfHdGo\nke15zN9/Vwk0MxP44w+VaM1f33jD+vMZjUDXrkDFinduFSrc+f6ZZ6w/znxuuKRjA0o3De0GJ04A\nnTsDAQFqEG759whbx5Gv48iVfE9pR4SlPbnohScljxwBunRRvVm3blXrq3T6N4DP4chVH7hamHxP\naVfT2ip2cbchWmkfp1MHD6qqS7duqb6sbduydRxRUUyu5JtKkw1Km5TdfGmMK+3frxIrAGzfDrRq\npWU0RPrFaWEissveverUcvnyqh1seLjWEZE1nBbWB45cieiu9uxR51grVlTFIZhYiUrG5EpEJfrx\nR+DRR4Hq1VVivecerSMi0j8mVyKyaccOoFs3VUdjxw51upiI7o7JlYis+u47oEcPdWnN9u2e3QOB\nyN2YXInoT+Z+BkKoxUs1a6rEWq+e1pEReRYmVyICULy7jZTApUuqSAQROYbJlYgAqApLRTvc5eSo\n7UTkGF0mVyFELyHEwqysLK1DIfIZ1io0AqqkIRE5RpfJVUq5UUo5rEqVKlqHQuQT5s+3fV+jRu6L\ng8hb6DK5EpH7/OtfwIgRqkZw+fKF7/PgEshEmmJyJfJRUgJTpgDjxqm+7Lt3A4sWeUUJZCLNsZ8r\nkQ+SUrWn/eADYNAgYPFi1ZO1NO1xiag4JlciH2MyAa++Csybp6aDP/kEMHAOi8ipmFyJfEhBAfDS\nS0B8vJrzsRf+AAAOmElEQVQO/uADNQVMRM7Fv1eJfER+vpryjY8HJk9mYiVyJY5ciXxAbi7Qvz+Q\nmAh8+KEatRKR6zC5Enm57GzgqadUGUPzeVYici1OCxN5IcsC/NWrA99+q6aDmViJ3IMjVyIvYy7A\nb64TnJcHBAQA/ny3E7kNR65EXsZaAf5bt1iAn8idmFyJvAwL8BNpj8mVyIucO2d7+pcF+Inch8mV\nyEucPg107KiSa2Bg4ftYgJ/IvZhcibzA77+rxHrpErB9O7BkCQvwE2mJ6weJPNyxY0DnzmoR07Zt\nwP33A1FRTKZEWmJyJfJghw4BXbqo0obffw/cd5/WERERwORK5LFSU1ViFUJNBd97r9YREZEZz7kS\neaCUFKBTJ7V4accOJlYivWFyJfIwSUnqHGuFCsDOnUCzZlpHRERFMbkSeZCfflJTwVWrqhFr48Za\nR0RE1jC5EnmIXbuAbt2AOnVUYg0L0zoiIrKFyZVIxyy723TsCFSurBJrw4ZaR0ZEJWFyJdIpc3cb\nc61gKYGrV9W1rESkb0yuRDplrbtNTg672xB5AiZXIp1idxsiz8XkSqRDX35p+z52tyHSPyZXIp1Z\nvRro3x8IDwfKly98H7vbEHkGJlciHVmxAhgwAPjrX4HkZGDRIna3IfJErC1MpBPx8cCQIcAjjwCJ\niaoCU0wMkymRJ+LIlUgHFi0CBg8GunYFNm1SiZWIPBeTK5HGPv1UXc/6+OPAf/5T/DwrEXkeJlci\nDc2eDYwcCfTuDaxbBwQFaR0RETkDkyuRRmbOBEaPBp5+Gvj3v4HAQK0jIiJnYXIl0sB77wHjxwP9\n+gFffAEEBGgdERE5E5MrkRuYC/AbDKpd3IQJ6pKbhASgXDmtoyMiZ+OlOEQuZi7Ab64TnJUF+PkB\n3bsD/nwHEnkljlyJXMxaAf6CAuCtt7SJh4hcj8mVyMVYgJ/I9zC5ErmQ0Wi7IAQL8BN5LyZXIhfJ\ny1MF+G/eLL5oiQX4ibwbkyuRC2RnA08+qQpDfPQRsGwZC/AT+RKuVSRysmvXgJ49gf/+F1i8GHjx\nRbWdyZTIdzC5EjnR5cvqEpuUFGDVKjUtTES+h8mVyEnOn1ddbY4fV9PBvXppHRERaYXJlcgJ0tKA\nRx9VCXbzZqBLF60jIiItMbkSldHRoyqxXrsGbN0KPPSQ1hERkdaYXInKYP9+oFs3wGQCtm8H2rTR\nOiIi0gNeikPkAMsC/PXqAX/5i6oPvHMnEysR3cGRK5GdihbgT09X161OnQo0b65tbESkL24buQoh\n7hFCLBFCrHXXaxI5k7UC/FICs2drEw8R6ZddyVUIsVQIkSGESC2yvbsQ4ogQ4rgQ4o2SnkNKeVJK\n+WJZgiXSkq1C+yzAT0RF2TstHA9gLoDPzBuEEH4A5gHoCuAsgF+EEIkA/AC8V+TxQ6SUGWWOlkgj\nJhNQqZJaEVwUC/ATUVF2JVcp5U4hRFiRze0BHJdSngQAIcQXAJ6UUr4HoKczgyTSUm4uMHiwSqz+\n/qrTjRkL8BORNWU55xoC4IzFz2dvb7NKCFFDCBEHoK0Q4s0S9hsmhEgSQiRdvHixDOERld3ly6rq\n0hdfADNmsAA/EdnHbauFpZSXAQy3Y7+FABYCQGRkpHR1XES2nDwJ9OgBnDqlkqu5TvBzz2kaFhF5\ngLIk13MAGlr83OD2NiKPt3u3qg1cUAB89x3QoYPWERGRJynLtPAvAJoKIf5HCBEA4BkAic4Ji0g7\n69cDnTqpBUw//sjESkSOs/dSnFUAfgLQTAhxVgjxopTSCGAUgG8AHAKwRkp5wHWhErnexx8D0dGq\n2tLPPwPNmmkdERF5IntXCz9rY/sWAFucGhGRBgoKgNdeA+bMAZ5+GlixAihfXuuoiMhTsbYw+STL\nGsGNGgFRUSqxjhkDrFnDxEpEZcPawuRzitYIPnNG3QYOBGbN0jY2IvIOHLmSz7FWIxgAduxwfyxE\n5J10mVyFEL2EEAuzsrK0DoW8EGsEE5Gr6TK5Sik3SimHValSRetQyMtkZwMVKli/jzWCichZdJlc\niVzh8GG1cOnmTVUj2BJrBBORMzG5kk/4/HMgMhK4cAH4+msgPp41gonIdbhamLxadjYwapQquN+x\nI7ByJVC/vrqPyZSIXIUjV/JaBw8C7durUerEicC3395JrERErsSRK3ml5cuBESPU4qVvvlFt44iI\n3IUjV/IqN2+qxuaDBgEPPACkpDCxEpH7MbmSR7MsY1i/viq0v3w58NZbnAYmIu3oclpYCNELQK8m\nTZpoHQrpWNEyhufPq69vvAFMmaJdXEREuhy5sogE2cNWGcNVq9wfCxGRJV0mV6K7yc0F0tKs38cy\nhkSkNSZX8jibNwMtW9q+n2UMiUhrTK7kMU6cAHr1Anr2BMqVA15/XZUttMQyhkSkB0yupHvZ2Wr1\nb0QEsH078OGHwL59wPvvq7KFLGNIRHqjy9XCRAAgJbBuHfDaa+o8akwM8MEHhS+viYlhMiUi/eHI\nlXTB8nrVsDCVRLt1A/r2BapWVY3MV6zgdatE5Bk4ciXNFb1eNS3tzvnUTz4Bhg8v3iKOiEjP+JFF\nmrN1vWr16qqjDRGRp+G0MGnqyBHb16ueO+feWIiInEWXyVUI0UsIsTArK0vrUMgFTCZg0ybgsceA\n5s1t78frVYnIU+kyubL8oXfKzARmzQKaNlXXq6amAlOnAvPm8XpVIvIuukyu5LmKrvpNSFBJdPhw\nICQEGDtWrfhdvRo4dUo1MR8xgterEpF3EVJKrWOwKTIyUiYlJWkdBtmp6KpfQCVZkwkIClLJctQo\noE0b7WIk8nZCiGQpZaTWcfg6rhYmp3nzzeKrfk0mdZ3q8eNAjRraxEVE5G5MrlRqBQXA3r3A1q3q\nduaM9f2ysphYici38JwrWWXt3CkA/P67Oh/6978DtWsD7dur61SvXgUqV7b+XFz1S0S+hsmVijGf\nO01LU/V909KAQYOAOnWAe+4BXn4Z+OknoHdvtW96OpCSAnz6KVf9EhEBTK5ez9YI1FJBgVq5+3//\nB8ydC7zySvFzp0YjcP06MHs2cPCgmgJetgwYMEAlXUAtWOKqXyIirhb2GAkJavr19Gk1zTpt2t2T\nlrXVu4GBwPPPq3OgR4+q2/HjQF7e3WMQQi1QIiL94mphfdD1yDU52fZoyxp7RmlaP660jyk6TTt0\nKDBzJvDzz8CWLapjzJw5wOTJwKuvqsT70kvFR6B5ecDixaqYw+HDQJMmav9Fi1TnmfPnbZ8j5blT\nIiL76HrkKkSkBJIQGAiMHg107qy2Wwt52zY1ZWk5AgsMVInjkUcKP8b8vZSq+fbcuYUfFxCgpkYf\nfljtY+22a5dKUrduFX5cTAwQGammWk0m9dV827sX2LAByM+/8xh/f6BDB6BhQyA3V93y8gp/PXBA\nTcvaq2pVVfT+5Enb++Tn2+40Y23EGxzMKV4iT8CRqz54RHL1dn5+alQYGKiKLZi/mr9PTLT92M2b\n1RRv9erqVrWqej5AjYytFcUPDVXnWEtSmmloItIek6s+6DK5CiF6AegF1BgKhFnck5xs+1Ht2tm+\nr0yPqwLASgcBhx53+/vSxnhfK6BcQPHt+UZg/z7bj6tZHWgUCgiL6X9pAk6nAZeuFI6tEHu2WTm+\nQt/XBHDJdmx2sfFv79B+pTk+e47VW47P1vd6OD57tzv6fxMo+/G56ndnbbujx9dMSlnJjtjIlaSU\nur0BSCrj4xeWdT9b9xXdXtLP1r4v67G58vjs2ebNx2fPsXrL8ZXwvebH56r3njOOz9s/W3gr+03X\nC5qcYKMT9rN1X9HtJf1s6/uyctXx2bPNm4/P3mMtKz0cn6uOzZHns/c9Zmu7N/3ftLZdi+OjMtLl\ntLCZECJJeum5A28+NoDH5+l4fJ7Lm4/Nk+h95LpQ6wBcyJuPDeDxeToen+fy5mPzGLoeuRIREXki\nvY9ciYiIPA6TKxERkZMxuRIRETmZxyZXIUQFIUSSEKKn1rE4mxCihRAiTgixVgjxitbxOJsQ4ikh\nxCIhxGohRDet43E2IcQ9QoglQoi1WsfiDLffa8tv/868rk6Xt/2+ivL295teuT25CiGWCiEyhBCp\nRbZ3F0IcEUIcF0K8YcdTvQ5gjWuiLD1nHJ+U8pCUcjiAfgD+6sp4HeWk49sgpRwKYDiA/q6M11FO\nOr6TUsoXXRtp2Th4nE8DWHv7d9bb7cGWgiPH5wm/r6IcPD7dvt+8mRYj13gA3S03CCH8AMwD0ANA\nBIBnhRARQohWQohNRW61hRBdARwEkOHu4O0QjzIe3+3H9AawGcAW94Z/V/FwwvHdNvH24/QkHs47\nPj2Lh53HCaABgDO3dytwY4xlEQ/7j88TxcPx49Pj+81r2eiL4jpSyp1CiLAim9sDOC6lPAkAQogv\nADwppXwPQLFpXyFEJwAVoP4D5QghtkgpddFp1BnHd/t5EgEkCiE2A1jpuogd46TfnwDwPoCvpJR7\nXRuxY5z1+9M7R44TwFmoBJsCDzmV5ODxHXRvdGXnyPEJIQ5Bp+83b6aXN0oI7vxlDKg3c4itnaWU\nsVLK0VBJZ5FeEmsJHDo+IUQnIcQcIcQC6G/kao1DxwfgHwAeBdBXCDHclYE5iaO/vxpCiDgAbYUQ\nb7o6OCeydZzrAEQLIebDs0vsWT0+D/59FWXr9+dp7zev4PaRqzNJKeO1jsEVpJTbAWzXOAyXkVLO\nATBH6zhcRUp5Ger8lleQUt4EMFjrOFzF235fRXn7+02v9DJyPQegocXPDW5v8xY8Ps/m7cdn5u3H\nyeMjt9FLcv0FQFMhxP8IIQIAPAOghBbhHofH59m8/fjMvP04eXzkNlpcirMKwE8AmgkhzgohXpRS\nGgGMAvANgEMA1kgpD7g7Nmfg8fH4PIG3HyePz7OPzxuwcD8REZGT6WVamIiIyGswuRIRETkZkysR\nEZGTMbkSERE5GZMrERGRkzG5EhERORmTKxERkZMxuRIRETkZkysREZGT/T/pspA/hIoxbwAAAABJ\nRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "## Parameter setting\n", "n=512\n", "p=256\n", "Tr=1024 # Training length\n", "Te=Tr # Testing length\n", "\n", "prop=[.5,.5] # proportions of each class\n", "K=len(prop) # number of data classes\n", "\n", "gammas = [10**x for x in np.arange(-4,2.25,.25)] # Range of gamma for simulations\n", "\n", "testcase='MNIST' # testcase for simulation, among 'iid','means','var','orth','mixed',MNIST'\n", "sigma='ReLu' # activation function, among 'ReLu', 'sign', 'posit', 'erf', 'poly2', 'cos', 'abs'\n", "\n", "\n", "# Only used for sigma='poly2'\n", "polynom=[-.5,0,1] # sigma(t)=polynom[0].t²+polynom[1].t+polynom[2]\n", "distrib='student' # distribution of Wij, among 'gauss','bern','bern_skewed','student'\n", "\n", "# Only used for sigma='poly2' and distrib='student'\n", "nu=7 # degrees of freedom of Student-t distribution\n", " \n", "\n", "## Generate X_train,X_test,y_train,y_test\n", "if testcase is 'MNIST':\n", " p=784\n", " X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop)\n", "else: \n", " means=[]\n", " covs=[]\n", " if testcase is 'iid':\n", " for i in range(K):\n", " means.append(np.zeros(p))\n", " covs.append(np.eye(p)) \n", " elif testcase is 'means':\n", " for i in range(K):\n", " means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n", " covs.append(np.eye(p))\n", " elif testcase is 'var':\n", " for i in range(K):\n", " means.append(np.zeros(p))\n", " covs.append(np.eye(p)*(1+8*i/np.sqrt(p)))\n", " elif testcase is 'orth':\n", " for i in range(K):\n", " means.append(np.zeros(p))\n", " covs.append( np.diag(np.concatenate( (np.ones(np.int(np.sum(prop[0:i]*p))),4*np.ones(np.int(prop[i]*p)),np.ones(np.int(np.sum(prop[i+1:]*p))) ) ) ))\n", " elif testcase is 'mixed':\n", " for i in range(K):\n", " means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n", " covs.append((1+4*i/np.sqrt(p))*scipy.linalg.toeplitz( [(.4*i)**x for x in range(p)] )) \n", "\n", " X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop,means,covs)\n", "\n", "##Theory\n", "start_th_calculus = time.time()\n", "\n", "Phi=gen_Phi(sigma,X_train,X_train,polynom,distrib,nu)\n", "L,U = np.linalg.eigh(Phi)\n", "Phi_cross = gen_Phi(sigma,X_train,X_test,polynom,distrib,nu)\n", "UPhi_cross = U.T@Phi_cross\n", "D_UPhi_cross2 = np.sum(UPhi_cross**2,axis=1)\n", "\n", "Phi_test = gen_Phi(sigma,X_test,X_test,polynom,distrib,nu)\n", "D_Phi_test = np.diag(Phi_test)\n", "Uy_train = U.T@y_train\n", "\n", "E_train_th=np.zeros(len(gammas))\n", "E_test_th =np.zeros(len(gammas))\n", "\n", "ind=0\n", "for gamma in gammas:\n", " E_train_th[ind],E_test_th[ind] = gen_E_th()\n", " ind+=1\n", " \n", "end_th_calculus = time.time() \n", "\n", "m,s = divmod(end_th_calculus-start_th_calculus,60)\n", "print('Time for Theoretical Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) \n", " \n", "## Simulations\n", "start_sim_calculus = time.time()\n", "\n", "loops = 10 # Number of generations of W to be averaged over\n", "\n", "E_train=np.zeros(len(gammas))\n", "E_test =np.zeros(len(gammas))\n", "\n", "\n", "rng = np.random\n", "\n", "for loop in range(loops): \n", " if sigma is 'poly2':\n", " if distrib is 'student':\n", " W = rng.standard_t(nu,n*p).reshape(n,p)/np.sqrt(nu/(nu-2))\n", " elif distrib is 'bern':\n", " W = np.sign(rng.randn(n,p))\n", " elif distrib is 'bern_skewed':\n", " Z = rng.rand(n,p)\n", " W = (Z<.75)/np.sqrt(3)+(Z>.75)*(-np.sqrt(3))\n", " elif distrib is 'gauss':\n", " W = rng.randn(n,p)\n", " else:\n", " W = rng.randn(n,p)\n", "\n", " S_train = gen_sig(sigma,W @ X_train,polynom)\n", " SS = S_train.T @ S_train\n", "\n", " S_test = gen_sig(sigma, W @ X_test,polynom)\n", "\n", " ind = 0\n", " for gamma in gammas:\n", "\n", " inv_resolv = np.linalg.solve( SS/Tr+gamma*np.eye(Tr),y_train)\n", " beta = S_train @ inv_resolv/Tr\n", " z_train = S_train.T @ beta\n", "\n", " z_test = S_test.T @ beta\n", "\n", "\n", " E_train[ind] += gamma**2*np.linalg.norm(inv_resolv)**2/Tr/loops\n", " E_test[ind] += np.linalg.norm(y_test-z_test)**2/Te/loops\n", "\n", " ind+=1 \n", " \n", "end_sim_calculus = time.time() \n", "\n", "m,s = divmod(end_sim_calculus-start_sim_calculus,60)\n", "print('Time for Simulations Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) \n", " \n", "#Plots \n", "p11,=plt.plot(gammas,E_train,'bo')\n", "p21,=plt.plot(gammas,E_test,'ro')\n", "\n", "p12,=plt.plot(gammas,E_train_th,'b-')\n", "p22,=plt.plot(gammas,E_test_th,'r-')\n", "plt.xscale('log')\n", "plt.yscale('log')\n", "plt.xlim( gammas[0],gammas[-1] )\n", "plt.ylim(np.amin( (E_train,E_train_th) ),np.amax( (E_test,E_test_th) ))\n", "plt.legend([p11,p12,p21,p22], [\"E_train\", \"E_train Th\",\"E_test\",\"E_test Th\"],bbox_to_anchor=(1, 1), loc='upper left')\n", "plt.show()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 2 }