{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "For an explanation of the code in this notebook, see the blog post [Multi-label classification with MLPClassifier](https://tensorflow.blog/2018/02/18/mlpclassifier%EC%9D%98-%EB%8B%A4%EC%A4%91-%EB%A0%88%EC%9D%B4%EB%B8%94-%EB%B6%84%EB%A5%98/) (in Korean)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPython 3.5.6\n", "IPython 6.5.0\n", "\n", "sklearn 0.20.1\n", "numpy 1.15.2\n", "scipy 1.1.0\n", "matplotlib 3.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -v -p sklearn,numpy,scipy,matplotlib" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# preamble provides np, plt and mglearn, which the cells below use without importing them\n", "from preamble import *" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((75,),\n", " array([1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1,\n", " 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1,\n", " 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0,\n", " 0, 0, 0, 1, 0, 0, 1, 0, 0]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import make_moons\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.neural_network import MLPClassifier\n", "\n", "X, y = make_moons(n_samples=100, noise=0.25, random_state=3)\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)\n", "y_train.shape, y_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert the binary labels `y_train` and `y_test` into two-column one-hot matrices `Y_train` and `Y_test` (column 0 marks class 0, column 1 marks class 1)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[0., 1.],\n", " [0., 1.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.],\n", " [1., 0.],\n", " [0., 1.],\n", " [1., 0.],\n", " [1., 0.],\n", " [0., 1.]]), array([[0., 1.],\n", " [0., 1.],\n", " [0., 1.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.],\n", " [0., 1.]]))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One-hot encode the 0/1 labels: row k of np.eye(2) is the indicator vector for class k\n", "Y_train = np.eye(2)[y_train]\n", "Y_test = np.eye(2)[y_test]\n", "\n", "Y_train[:10], Y_test[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MLPClassifier infers the problem type from the target it is given. Because `Y_train` is 2-D (its second dimension has one column per label), it is treated as a multilabel indicator matrix: the network gets one logistic output unit per column, and `predict` returns a 0/1 matrix of the same shape." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Y_train is 2-D, so this fits a multilabel model with one output unit per column\n", "mlp_multi = MLPClassifier(solver='lbfgs', random_state=0).fit(X_train, Y_train)\n", "mglearn.plots.plot_2d_separator(mlp_multi, X_train, fill=True, alpha=.3)\n", "\n", "mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train)\n", "plt.xlabel(\"Feature 0\")\n", "plt.ylabel(\"Feature 1\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For comparison, the binary-classification case as in the book's example: the same model fitted on the 1-D label vector `y_train`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
gVtAEGbmUVFVXD7w+huPhs1N+/3LUxx5Sfb/eNg8gpzG79i4aBW1HZ2VNQWbkUdXWfYtmyHuTnDyIjQ0N+/iCWLetBXd2npnZt2H3jIHICK+va4bg4qbDc3BLMnXsXpk37wPKujcs3jkYEArls9yPPsbquHY6BW3HZ2VNQVrbQlq4NO28cRE5gR9AGGLgpTXbeOIjsYleJRMcaNxFREuwskeiYcbuIUycFErmNnUEbYOB2DSdMCiRyOzta/6JhqcQFeL4jkfnsrmuHY+B2AZ7vSGQuJ9S1wzFwuwDPdyQyn1OCNsDA7Qoc+ERkHqfUtcMxcLsABz4RmcNJde1wDNwuwIFPRPI5ra4djoHbBTjwicgcTgzaAPu4XYEDn4jkcmqJRMfA7RIc+EQkh5NLJDoGbhfhwCciOZwctAHWuImIRjix9S8aBm4iIji/rh2OpRJyHE45JKupUNcOx8BNjsIph2QXVYI2wMBNDhI+5TB8YJY+5bC19QrU1TVi7ty7pGTezOwJUKtEomPgJscwOuVw2rQP0u6cYWZPgHolEh0Dt42Y8Y02POWwNO5rAoE8+P0daQVuqzN7ciZVgzbAwG0b2Rmf7JtA5OPjierFSb9HsqyacmhlZk/OpmLQBtgOaAvZJ9aEQl04dGgH6ut7sHFjKZ588lps3FiK+voeHDq0A6FQV0rXWVKu4eevPI3e0yctqQNaNeWQ88tJxbp2OAZuG8g8scbMY8vWb9+OQGsrdh15w9B1dHbuwb5927B//2+xb982dHbuSernWjXlkPPLvU3lEomOgdsGMjM+M44tK2xpwhengtjW0IDGhQuxraEBpbNjZymyMn6rphxyfrl3uSFoAwzctpCZ8Zn12L/rwzexurgYX83Lw6riYqx/6SUAY4O3zIz/8pTDT7FsWQ/y8weRkaEhP38Qy5b1oK7uUylTDjm/3NtUD9oAFydtoWd8wWDsjM5oxmfGY7+ebR++6SYAwNrycsxpaMDalStxtGM4eOuLleku9EVbVM3Pn4VlywT8/qOmTDkczux3oLX1iqjXfTmzX5r2zyLnUL2uHY6B2wbDGV8PGhqmx3yN0YxP5k1A9/vdr2F1cTFmTpgAAJg5YcJI1r3hRz9CINA38tp0Wvhid9achN8fMq2XmvPLvcctJRIdA7cNZGZ86dwEomW7V5y+iJ1/2o3DN9886rXhWTdwOetONeO3u5ea88u9w21BG2DgtoXMjC/Vm0CsbHfq5N/gP4Zl27poWXdhS1PKGb8Teqk5v9w73BS0AQZu28jK+FK5CcTOds/gXCiA//ofbo76s8Kzbp+vAIFAH4pxBWprTyad8Vu1S9Jq3A3rLG6qa4dj4LaRrIwv2ZtArGw3K+M11M0qGZNt6yKzbp+vAGeaB+H3H00643djLzXnnziLG0skOgZul0jmJhAr2x0n2vFs55/xbGd73O9fJMTIPy/xz8DZ9wYhVneheV++4bKPGYuq6Ug3U7a7Zk+juTloAwzcnhQr2z178emRf87I0PDzdUfwt4sXJXy/ry8oxcTmLJT4g/D7uwyVfWR21qRLRqbshJo9jebWoA0wcHuS0Ww3QzP+12OJfwYmBbIwF8YGUkVbVM3PH4TP14fq6hAmTryIixcFLl7Mw8DAKdOyVFmZsltr9ipya107HHdOepCRnYN+XxDzCq9M6n31DMfIBydyl+S8eUE8/HAnLlwQeOGFMjz11LX49a/LsXPnybQGZSUia2SAG2v2KnJ7iUTHwO1BRmaCLPCFsKhkZtLvnUzw1hdVly3Lx4oVn2H79hI0NhZKHZSViKyRAbLmn8gY1uVVXgnaAEslnhSvhdDvC2KBL4SV1ZUoyMlJ6f31VsHwrfHxriUzczz27ZtqS31YiEH4/X2YM2e4PHP2bCZaWnIRCBSMlJKMZMoyavbsSkmfF4I2wMDtWWNbCAeQqY3DvMIrsahkbspBWxcevIH4dW+76sOhUBfOnxc4f364PKMHypqa4bLNa6
8VoaNjsqFMOdWNUHo3yxdfdCAjYxADA5kQYggARj11sCslPi/UtcMxcDtA5If3wgUBITRkZIzH1KkVpm3eCG8hLGxpkp6t6O+XKPu2oz6sL0r+7nezxixKNjYW4qOPJuOBB7rw/PNlhrpbUtkIFSvDjrxxAOxKicdLJRIdA7fN4n14b7jhJPbsOYaqqnZTH5PNzlYuB/Do2bcdPd1GFiUPHMjDkiU9uO66c4bmxiSzESpeN0vkjUP/c2FXylheDNoAFydtFW+WdWNjIV56qQSzZp3Bq69ON21xTmfFX/xYC5d2zMc2sih54EA+rr32dFKTAvWnmPnzV+GGGx7G/PmrUFa20PDuVZ1+4/D5Lk9iZFdKdF4L2gADtymMdgYY/fCWlp5L+hQbp/L5CuDzFaCwpWkkgFt18k04o+WZceNgypOO0RtHdXVo5N95Ks9oXqtrh2OpRLJkOgOMLModOJCPhx7qxAsvlLnqMTmyfNJv8Xxsu7fcG71xTJx4ceTfeSrPZV4tkegYuCVKdhdeMh9etz4m690nlUc/wVXZpZi2YoIl87Ht3nJv9MZx9mwmAJ7KE87rQRtg4JYq2XkVyXx43fyYHN59suDLISD7GkPb5tNh9/FlRm4cNTVBtLVdgWXLengqTwQvB22ANW6pkt2FZ2RRrqYmiJaWXE88Juv1bwCjauBmsOpg4liM1PX9/j7Mm3cOK1YUYu7cu7j5Bt6ua4djxi1Rsv3IRrK+mpqT+Ld/m4F77unxzGNyZP83YGxwVbLsPL7MWN/3MgbrMCyRXMbALVGyC17hH97hxcy8MX3cn3wyCffc02Na9mfGxhtZrAjgdh5fxnMvjWPQHo2BW6JUFrxGf3jbw3ZOAhkZ47FwYQmKiuZ4+kNsdAeminjupXHpBO3jvb2o+7u/w4s/+xlmFKgf/Bm4JUp1wYsfXmMS7cAkd5JR116/fTsCra0jx+6pjouTEtm94OUV0TbxkDvJKJEc7+3FtoYGNC5ciG0NDfisry/xNzmcJzJuK0/eZt3SOslMICS19PWdwG+eXIXH712D5UuvSeu91m/fjtXFxfhqXt6ow65V5vrAbceMY5Y+rGNVB0qyrEwW3Oj1V36FI11t2HXkDSxfmnqQ1bPtwzfdBABYW16OOQ0NWLtypdK1bleXSuINcTL7ZBWylpU94ImEQl04dGgH6ut7sHFjKZ588lps3FiK+voeU49hc4u+vhN45w8vSylt6Nn2zAkTAAAzJ0wYybpV5urALes8QVJHZAC3GpOF9L313M+wuqRkVGkjFXq2vba8fNTX15aXK1/rdnXglnWeIKnHruBtJFl4773JOHjwH3meZBSZe/4Vb7y/G/9l9nCwTSfIRmbbOjdk3XEDtxDipVi/rLrAdKh28jYPipXLjuBtJFnYt68A/f2ZLJ9E8fvdr6GutCTt0kasbFunetadaHFyFoBfAjhhwbVIZ/XoznQWpOxYRDU7oDlh04PeeWKVZCY+8jzJ0fRs+8iSm0d9PZUFxVjZti78hqBih0miwF0HwK9p2m4rLkY2K0d3phN4kx0HK5OZW4idsulhOHhbs+My2XGtAM+TBIaTiGcism1dKkF275Ej2NPejmfa2+O+bpEQKV+zneKWSjRNa9M07f9YdTGyWXWySroLUm5cRHXipgcrSibJTHwM5/W1li9OBbHz4B+llTaaNm+GtmtXwl9NmzfHfI/jvb247bHHHPF3N5Kyi5NG6sFW7WRMN/C6cRE12qYHO1lV7zaSLNTUnEQgMPpJx0lrLVYrbGnCrg/fNFzasEr4E6PTKLkBJ5myhBU7GY0cQRbvhG7VFlETceqmh/CdlmaVTeKNa62pCaKm5iRee61oTCnFzQdlxKPfSJ1W2gh/YrzFAX93IykXuFOpB5u9kzHdwGv3+Yeyxdv0YPdCkBWLlWOThQEMDGTg/ffz8PzzZVH/O3vhoIxI4XNImnyxSxZ2cPo2+ZilEiHELiFEQ8SvbwkhioUQ/y
KEeF4IEf25xkROrAfrgTeeeIHXSF1UlQ+2Cpse9AFVZtKThfnzV2HOnPswfnwmDh+eEjVom3GKvSqcOF878u+wk/7u6uLVuB8G8H0A0wF879KvnQA2ANgC4GMAf2Py9Y0hqx4ss2c63cBr1SKqFWRvejBzgciq/m5OjRzLyQd4qLBNPmapRNO0TwBACPElgKMAfqJp2v8SQpRqmva6EOJdAL+z5jIvk1EPlt0zne7Bs8aOsXL+Bzuyth0p1X5cM1oKrah3h+PUyMucPIo31t9hp6zT6BLWuDVNWyyE+AmALy996eKl/z0DYJJZFxb7etKrB5vRMy0j8Fr9wTbjwyN704PZC0Th9W4rpvlxaqTzjyAz8sTohFp33MAthHgKQA6AM5qmbbj05QwhxCQAVRgul1gq3U01RmvkyW6GkBF4zfxgRwYmXMhETdF09J3LQUFOjpSfIbszwKoFooyDO3Fo8DNLd616mVODthlPjGZJlHF3AlgAYIoQYoKmaf0AngOwG8MBfZXJ1zdGumWJdFv34pEZeGVmgLFKQ/7aIN73H8LK6kpU5Oenfc3xNjMkS0ZLoZEt97Orc1Df3Iat20os37XqNU6uawNqbZOPG7g1TXsBwAtCiNsAbAbwkKZp24QQ7wE4rWlatxUXGS7dsoQKPdMya/DxSkM7GwrxQetkoK4Na2rnSsu8ZZDRUmikPt7UdRyBvXnSn8C8KF6yUdp2yO7LS8hpveTxGOrj1jRtpxCiUggxXtO085qmfWT2hcWTTlnC6T3TsmvwhsaMBnJRM+M4bq9M74goWWQsEBmtjx888Tma914d971SfQLzkmjJxuTJPZg07nHc/+07kDOtCl9fEP9J124ynxjNlsyW9/cB3GvWhSQrvE/2hhsexvz5q1BWtjBhMHN6z7TsPnUj7ZPNgXz86cTnSV+rWWS0FBrdcj8kLjj+CczpYs3qOXe6HmdCH+PFLQfwTm8bbnn0UUf1Qqss0TzuxUKIQiHEtwBMATBNCPG1SxtxvmXNJcrl9J5p2XNLjJaGhsQFw9doJhlzlJPZQJGhjUu4eaq09CwuXMjgnPQYoicbQWTgXbyzeCHOndmDHa/uxoEPP3RUL7TKEmXc6wHcDeBHADQAAsBTAKoBOH83SBRO3wwhuwZvdFdnhuaM6QfJLBAZfY9433N94ZXw1wZjvtfs2adx//1dCATyeH5kDNGSjayM11A3a/j4sXuvmoHPju931KRI1SX6tIqwXyM0TVtn2hVZwMmbIWTX4I20T/p9QcwrvDLpazVDugtEydbHF5fMxPv+Q/igdfKY8lR+/iDuuutT/P73pew4iWNssjGcba+ruhkAkJ0xhAdLnTv3Q0VG0yzN1KuwgVM3Q8g+/MFI++QCXwiLSuamfM0ypbtAlOwGioKcHKysrgTq2vBeIBfNgfyRLp677/4U77+fz46TBCKTDT3bnjlhAo739+PV7m4cXjrcnuukXmiVGVmcdF3QdjLZNfh4paHblp3Aw3XHsLK60lGtgKlKtT5ekZ+PNbVz8cNvZOGRNR9j3boj+PEPOzFjxiD27Yvf367anHQzjF7w17Pt4f8G69vbsbq01NFzP1SUKOOeAeBmAEXmXwoB5swtiVYawoUM3FA0HYtKnNW/nY50NlAU5OTg9sprRloiA4E+vDF4hB0nBlSfHkBvbRCtrVegp3t0tr3t6NGRbFvHrDt9iQL3/7z0v+FnTr5q0rXQJWbU4CNLQ4UtTfBVuutDI30DxYVMR/f8221k7oh/BoqD2Thzz348t/kdrKtaAmBstq1z0g5EVSXaOTlScBRCLAFwStO0raZfFTm2Bu9kMjdQ+HwFOPjvuaitPWnJYdMqCR9Qpm9hr8jPx5n2D/BQgmxbx6w7PYZ7wDRNe8fMCyFymrvnX41jFw+mPBfHbRJN9mtpb8eejg78qqMDEzMysHrWLCXmfqjIGc27abBiHCepyciQqXgKcnLgm1AErO5C8758ZeekpyNadh
2L/sTzyKZN2FJfj2c7O/FsZ2fc73HC3A8VKR24ZR+IQO4i4xCGry8oxcTmLExdMcFxPf9mSiZgh9M7e95dvBi3NDfj8IsvshRiAmUDtxkHIpB7yDyEYZLIwoIvh3BivuVTjC2XasDWOf2QXbdIZsiUozjx0GByDqNDpoxw8gxpWQpbmkbVsFP5/6zCIbtuoWzglj2MidzDrADi5LMSU6EHa/2Ag1QDtk6FQ3bdQtnArcKBCGSP9du3454ZM1B34AA+6++XEkDclHXLyK4jxdq1yqzbHMoGbqNT77y6OcKr9ACCoSEEgkGsv7QZx46se2DgFDo791g+Drav7wR+8fi9CAZ7Rn3djICtkzFDnYxTNnA7/UAEsoeebb/a3Y3GRYuw7ehRW7LuUKgLhw7tQH19j+XjYF9/5Vf484f78No//Up6OSQaGTPUKTnKBm6nH4hA1gvPtleXlg4vTJaWWp51xzoRBgCEGMLAwBA6OnZi7175GXhf3wm80/hPaFy4EO/+4WX0nj5pSrAOJ2OGOiVH2XZAM4YxkdpGsu1jxy6PEa2owJzGRqytqJCyW8/nK0AgED/wR+t4mj37NO68sxsHDuTht78tM2XPQWFLE7b/6xasLr4KX83Lw4Olxdh15A0sX2puO55Kh+y6hbKBG3D2gQhkvb1HjmDvJ5/gobCt1jMnTMB9xcWY1dCAwaEhAOkHkOHg3YQT1Yuj/v5wx9Plg3Hz8wdx553d2L69xJQ9B/oTwBength58I8jh0ikOw/E6M5TlQ7ZdQulAzfAYUx02Su/+AWuq6vDuqqqUV9fV1WFf/zsM9N38enjFy5ePI81azpw9mwmWlpykZ19EQcO5Ek/kCFys8wjm16K2Y6XyhOGjJ2nZA5la9zkLcd7e3HbY4/FrU9b2dng8xWMCpzhi5G//nU5nnrqWrzwQhkuXBCYM+cUDhyQcyBDrMVG2e144TtPubDoPAzcpITw7C8auzobCluaYi5GBoNZaGwsxLhxWtp7DhK18sm+acnceUryMXCT4xnJ/ox0NnzHhKwbSDx+4ezZzJT2HIRn1/rPi9YdIvumxa3rzsfATY5nJPvbe+QInmlvh3j99Zi/ftnejsCRI9Kvr6/no7jjF1paclFTE4z7HuF7DiKDdUm5hp+/8nTMwCm7HY9b151P+cVJcjc9+0vUKRGrs+F4by+uq6tDo9+PW5qb8eqTT0q9Pp+vAG80XYxbCgkECvDww5346KPJcQ9kqD6di0lRDitItEgosx0v8s9bxxNrnIWBmxwtXvZnpNPBijGjmdq4uGdTBoNZ2LVrOlat+gT7908dveegNgh/7Un4soqwxD9jzPcaGU8rsx3PSK2cHSb2Y6mEkmKku0Pmz0qnU8KqWu31hVfCXxu/FFJYeB7Tps3GihWF+PEPO7Fu3RH8+Ied+OtvZuFv/Nfj6wtKo36flYuE3LquDgZuSkqi7g7ZPyudTgmrarWLS2ZiQYLxC/75Qcw7O4QFXw7hiZsX4InFi/DEzQtwe+U1KMjJifp9Vi8SGqmVPzBzJp7YyvPC7cZSiQfpPcBG6Tvo/scPfiDtVBkjPzNarVWXqOZqZa22ICcHK6srgbo2vBfIRXMgf6QU4vcFscAXwsrqKlTkx+/ljpRumShZRmvlV5+MP9yNzCc0TZP+phUV12sbNrwp/X1JjmQD9yObNmHbm29idnExFmVm4h+uuw7fP3gQ75w5g3/fuNGU4P3Ipk1Aayv+4brrYr/m8GGIr3wlahCL9/3xvi8dfefOYU/XcfzpxOcYEheQoY3DvMIrsahkZsysOhZ9UfXwTTeNyoCP9/djzu7dlp/lGLnIy7MkzSFuvXW/pmnzE72OGTfFpWeuL99wA+4IBPDPt946/BtDQ/i8rw9PbN2K5x59VPrPTadTIt1sPVUFOTm4vfIa3F55Tdrv5bRFwsha+8+3bMHREycSzjEhczBwU1z6B/bNnh48VFqKmRMm4Hh/P17t7sY7ixfjps
ZGPFFXJ/3Dm06nRDJ9zU7skLDrxmP0etaWl+OvGhuRKYRj/wzdjoGbYtI/sG/7/fjanj0jo1LXt7ePzLv+dnGxaVl3qlQfM+q0G0+06xnSNLy7aJHpax0UHQM3xaR/YLd1dWF1WLa97ejRkSC+rqoKVSZl3alq2rx5pC7/4PLlymWETrrxRMv+17e346FLN+7vOPjJxc0YuCmqRNl2eKeDGVm30VnQ8a7diu4XMzhpvnVkth154/4pd1Tagn3cFFW8bHttRcWo166rqsJLjY1S+4vT6RfnZDs5om3IiXbjTnZ4l5WbuNyKgZvGCP/A7g0G8cyf/wzx+uu4uqEB98fodJA5eS+dWdCcbCdPrGw78sb90xh/xrECtJWbuNyKgZvGCP/ANt14I7Q77kD3bbdhUmbmmNNldI9XVUkLkOlkzJxsJ0/kxMV4N+5of8bRAjQPaJCDgZvGiDYiNdaHVicrQKaTMRudbcJHdWOaNm+GtmsXtF270P3yy5iUnR3zxh3tzzhagGYZSw4Gbo8JP24rlvAPrP6rdvZsPNvZGXfe9TMS5l2nkzEbnW3CR/XkJTvzO1qAZhlLHnaVeFAy2911VnQ6pDNfxOimldXLlyvdcWKXZFoUY81Q/7K/39LZK27GwE2Okc42b6MZ4ffWr485nzudFkS3S+bG/cimTWMC9HeKi7Hl7bfx0aU2Qh0PaEgNSyXkCOnOgjZydNkz7e04fPRozEd1llDSF+u/40/LywFNQ+SWIS4ep4aBmxwh3XMTo9XlI3/9529+E98vK4v6qM5uBzniPTV9t7QU66OUWljrTh5LJeQIZm/zTlQ/1+uvZh5x5naJ1hl+WlmJOY2NWFtRgRlhgT3ZWjdLWgzcnmKko8QuZi9+xssEI+uvrLumxshT033FxZjV0IDBoaExv2/0ppzo8GQvYOD2mFQ6SlSXMBMsL8eWzs6R+iu7HVJj+KmpsjLlG7Xqc2hkYeAm1zOSCer11w3V1QCYdafCipbRaP3hXry5MnCT6xnOBMMCNLNu54nVH+7FmysDN7lerEww1rmOOi8HBiey+vBkJ2M7IHlWui2IZB2jc2i8ghm3RyR7srsXOOmkGYrPaYcn242BmzzLSSfNUGxOOzzZCVgqIek4NpVkYklrLGbcJB03SAzjDj85WNIai4GbpOIGicusvIG5+SbBktZYLJWQVDzhZJjVQ6s42dBbXBO4BwZOobNzD/bt24b9+3+Lffu2obNzDwYGTtl9aZ7BE04us/IGxsmG3uOKwB0KdeHQoR2or+/Bxo2lePLJa7FxYynq63tw6NAOhEJddl+iJ3jxoN5oC7FW38D4lOM9ygfugYFTaGtrxNatV6GhYTqCwSwMDQkEg1loaJiOrVuvQltbIzNvk3lhg0S0IB2tRGHlDczoTYKdPu6ifODu7m5Bc3Mujh2bGPX3jx2biEAgF93dH1h8Zd5i9KBelUUG6WglCqtvYEZvEqyBu4vygbu3twN79+bFfU0gkIfe3g6Lrsh70j12TAXRgnS0EoWVNzCjNwnWwN1H+cAtxCBCofFxXxMKjYcQgxZdkfd4YYNEZJD+2y1bxpQotr71Fra99ZZlNzCjNwnWwN1H+T5uTctCbu55BINZMV+Tm3semhb79yk9bt8gEW2caFVjI+6PKFHMzsnBgvx8QzewdPsjXxPaAAAE3klEQVS6jW4DX718OUehupDygXvq1Nmore1BQ8P0mK/x+U5i6tTZFl6Vt7h9g0S0OvK3i4txMeL4rQwAm/7yF2z6y1/ivp+MG5jRp5zv/f3fx6yBP/bAA67dtON2ygfuoqJq+P070Np6RdQFyuLis/D5QigqWmrD1ZHqYmW266qqMKexEf/t2mtHDr5tvukmPHL4MMRXvmL6TkmjTzlTxo/HPy8d/Xdfz7rP9PdzNIGilA/c2dlTUFm5FHV1jQgEchEI5CEUGo/c3PPw+U7C5wuhsnIpsrOn2H2ppKC4deSI486A6KUIM7ajG3
nKeWTTJqC1Neq13z1jBl56+238cdEiz48mUJHyi5MAkJtbgrlz78KKFYVYs6YL69Z9iDVrurBiRSHmzr0Lubkldl8iKShht0xFBbYdPYrP+vtHvhZtIdaOVrxE146hIXyHC5bKUj7j1mVnT0FZ2UKUlS20+1JIMbEyYiN15PuKizGroQGDEfVuvY5t19CteNd+vL8fr3Z34/ClEgoXLNXjmsBNlKpYU/wMd8tUVsYsXdh1Knm8a8/KyMBDs2bx7EaFCU3TpL9pRcX12oYNb0p/X0odjy6LTj8wuNHvxy3NzTj84ovSss7Iw4iP9/djzu7dUn9Gutc08nUHXBsB4tZb92uaNj/R61xR4yZKlZmbU5w4dMsLowm8gIGbPMvMKX5OHLrlhdEEXsHATZ5lZkbsxMzWC6MJvIKLk+RJsTbWyOiwcOqp5G4fTeAlDNzkSUYy4lQ7LJLJbK3s4nD7aAIvYeAmzzE7I2ZmS2Zj4CbPMTsjZmZLZmPgJs9hRkyqY+Amz2FGTKpjOyARkWIYuImIFMPATUSkGAZuIiLFMHATESmGgZuISDEM3EREimHgJiJSDAM3EZFiGLiJiBTDwE1EpBgGbiIixTBwExEphoGbiEgxDNxERIph4CYiUgwDNxGRYhi4iYgUw8BNRKQYBm4iIsUwcBMRKYaBm4hIMULTNPlvKsTnAD6R/sZERO42S9O0KxO9yJTATURE5mGphIhIMQzcRESKYeAmzxJCXC2E+IPd10GULAZuUp4Q4kdCiI+FEO9F/Lox4nUPCyH+u8H3+3+X3uM+866cKDXj7L4AIkn2AXg34mvtEf++AMBpABBC7AEwGUBP+AuEEOUAvnvptdkAAkKIBk3TgiZcM1FKGLjJLWoAzIj42hcAXgYAIcRyAALAGSHEowBuBFAC4PmI7/kagH/RNG0QwKAQ4o8AFgKoN/HaiZLCwE1KE0J8DcDHAH4c4/dvBPBVAPMA/DWAcwB+AOAnAF6J8i3TMBzwdV8ASNhXS2QlBm5S3dUApsT5/S8APKdp2kDY154FACFELsZm3EEAU8P+PffS14gcgxtwSHmXSh/3RvmtxzVN+0PY6x7EcLatYXhhvgvATzRNOxr2mjkA/jeGSylZAPYCuFXTtOOm/R8gShIDN7mSEOIpAB9omqbXuP8KwP8FsETTtFOXvnYngO9qmnZ7xPf+DMAdGA7wv9E0bYulF0+UAEsl5BWnAUwEUCmEOARgEoA5AD6PfKGmaU8DeNrayyMyjhk3eYYQohbAfwJQBuAMgD0Afqlp2llbL4woSQzcRESK4c5JIiLFMHATESmGgZuISDEM3EREimHgJiJSDAM3EZFi/j91hGDB4jNs2wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# y_train is 1-D, so this is an ordinary binary classifier\n", "mlp = MLPClassifier(solver='lbfgs', random_state=0).fit(X_train, y_train)\n", "mglearn.plots.plot_2d_separator(mlp, X_train, fill=True, alpha=.3)\n", "\n", "mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train)\n", "plt.xlabel(\"Feature 0\")\n", "plt.ylabel(\"Feature 1\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.88, 0.88)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# mlp_multi is scored with subset accuracy (both columns of a row must match), mlp with ordinary accuracy\n", "mlp_multi.score(X_test, Y_test), mlp.score(X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Multilabel classification test on the Yeast dataset from OpenML: 2,417 samples, 103 features, and 14 binary labels per sample." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_openml\n", "# downloads (and caches) the multilabel Yeast dataset from openml.org\n", "yeast = fetch_openml('yeast', version=4)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((2417, 103),\n", " array([[False, True, True, True, False, False, False, False, False,\n", " True, True, True, True, False],\n", " [ True, True, False, False, False, True, True, False, False,\n", " False, False, True, True, False],\n", " [False, True, True, True, True, True, False, False, False,\n", " False, False, False, False, False],\n", " [False, False, False, False, False, False, True, True, False,\n", " False, False, True, True, False],\n", " [ True, True, False, False, False, False, False, False, False,\n", " False, False, True, True, False],\n", " [False, False, False, False, False, False, False, True, True,\n", " False, False, False, False, False],\n", " [False, False, False, True, True, False, False, False, False,\n", " False, False, True, True, False],\n", " [False, False, False, False, False, False, False, False, False,\n", " False, True, True, True, False],\n", " [False, False, False, False, True, True, False, False, False,\n", " False, False, True, True, False],\n", " [False, False, False, False, True, True, True, True, False,\n", " False, False, True, True, False]]))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = yeast['data']\n", "Y = yeast['target']\n", "# the labels are stored as the strings 'TRUE'/'FALSE'; turn them into a boolean indicator matrix\n", "Y = Y == 'TRUE'\n", "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)\n", "X.shape, Y_train[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For multilabel targets, `score` reports subset accuracy: a test row counts as correct only if all 14 of its labels are predicted exactly, which makes it a very strict measure." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.12396694214876033" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mlp_multilabel = MLPClassifier(hidden_layer_sizes=(300,100), max_iter=10000, random_state=42).fit(X_train, Y_train)\n", "mlp_multilabel.score(X_test, Y_test)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],\n", " [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n", " [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],\n", " [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0],\n", " [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0],\n", " [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_pred = mlp_multilabel.predict(X_test)\n", "Y_pred[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now count the fraction of test rows in which at least one true label is also predicted."
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.890495867768595" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# a row counts if at least one label is both true (Y_test) and predicted (Y_pred == 1)\n", "np.mean(np.any(Y_test & (Y_pred == 1), axis=1))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.6" } }, "nbformat": 4, "nbformat_minor": 2 }