{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
},
"orig_nbformat": 4,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.13 64-bit ('pyes': conda)"
},
"interpreter": {
"hash": "5febc64283966fe46fc465ad8ab242eb0484fa28506faf7951a93cb8efa4edfd"
}
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# 基于单一种群的特征选择方法\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import pyes\n",
"import pandas as pd\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"from tensorflow.python.framework.errors_impl import InvalidArgumentError\n",
"from tensorflow.contrib.distributions import MultivariateNormalFullCovariance\n",
"from loguru import logger\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import SVC\n",
"\n",
"tf.logging.set_verbosity(tf.logging.ERROR)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = np.genfromtxt('../exampleData/glass.data', delimiter=',')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9 10\n",
"0 1.0 1.52101 13.64 4.49 1.10 71.78 0.06 8.75 0.00 0.0 1.0\n",
"1 2.0 1.51761 13.89 3.60 1.36 72.73 0.48 7.83 0.00 0.0 1.0\n",
"2 3.0 1.51618 13.53 3.55 1.54 72.99 0.39 7.78 0.00 0.0 1.0\n",
"3 4.0 1.51766 13.21 3.69 1.29 72.61 0.57 8.22 0.00 0.0 1.0\n",
"4 5.0 1.51742 13.27 3.62 1.24 73.08 0.55 8.07 0.00 0.0 1.0\n",
".. ... ... ... ... ... ... ... ... ... ... ...\n",
"209 210.0 1.51623 14.14 0.00 2.88 72.61 0.08 9.18 1.06 0.0 7.0\n",
"210 211.0 1.51685 14.92 0.00 1.99 73.06 0.00 8.40 1.59 0.0 7.0\n",
"211 212.0 1.52065 14.36 0.00 2.02 73.42 0.00 8.44 1.64 0.0 7.0\n",
"212 213.0 1.51651 14.38 0.00 1.94 73.61 0.00 8.48 1.57 0.0 7.0\n",
"213 214.0 1.51711 14.23 0.00 2.08 73.36 0.00 8.62 1.67 0.0 7.0\n",
"\n",
"[214 rows x 11 columns]"
],
"text/html": "
\n\n
\n \n \n | \n 0 | \n 1 | \n 2 | \n 3 | \n 4 | \n 5 | \n 6 | \n 7 | \n 8 | \n 9 | \n 10 | \n
\n \n \n \n | 0 | \n 1.0 | \n 1.52101 | \n 13.64 | \n 4.49 | \n 1.10 | \n 71.78 | \n 0.06 | \n 8.75 | \n 0.00 | \n 0.0 | \n 1.0 | \n
\n \n | 1 | \n 2.0 | \n 1.51761 | \n 13.89 | \n 3.60 | \n 1.36 | \n 72.73 | \n 0.48 | \n 7.83 | \n 0.00 | \n 0.0 | \n 1.0 | \n
\n \n | 2 | \n 3.0 | \n 1.51618 | \n 13.53 | \n 3.55 | \n 1.54 | \n 72.99 | \n 0.39 | \n 7.78 | \n 0.00 | \n 0.0 | \n 1.0 | \n
\n \n | 3 | \n 4.0 | \n 1.51766 | \n 13.21 | \n 3.69 | \n 1.29 | \n 72.61 | \n 0.57 | \n 8.22 | \n 0.00 | \n 0.0 | \n 1.0 | \n
\n \n | 4 | \n 5.0 | \n 1.51742 | \n 13.27 | \n 3.62 | \n 1.24 | \n 73.08 | \n 0.55 | \n 8.07 | \n 0.00 | \n 0.0 | \n 1.0 | \n
\n \n | ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n ... | \n
\n \n | 209 | \n 210.0 | \n 1.51623 | \n 14.14 | \n 0.00 | \n 2.88 | \n 72.61 | \n 0.08 | \n 9.18 | \n 1.06 | \n 0.0 | \n 7.0 | \n
\n \n | 210 | \n 211.0 | \n 1.51685 | \n 14.92 | \n 0.00 | \n 1.99 | \n 73.06 | \n 0.00 | \n 8.40 | \n 1.59 | \n 0.0 | \n 7.0 | \n
\n \n | 211 | \n 212.0 | \n 1.52065 | \n 14.36 | \n 0.00 | \n 2.02 | \n 73.42 | \n 0.00 | \n 8.44 | \n 1.64 | \n 0.0 | \n 7.0 | \n
\n \n | 212 | \n 213.0 | \n 1.51651 | \n 14.38 | \n 0.00 | \n 1.94 | \n 73.61 | \n 0.00 | \n 8.48 | \n 1.57 | \n 0.0 | \n 7.0 | \n
\n \n | 213 | \n 214.0 | \n 1.51711 | \n 14.23 | \n 0.00 | \n 2.08 | \n 73.36 | \n 0.00 | \n 8.62 | \n 1.67 | \n 0.0 | \n 7.0 | \n
\n \n
\n
214 rows × 11 columns
\n
"
},
"metadata": {},
"execution_count": 3
}
],
"source": [
"# 数据概览\n",
"pd.DataFrame(data)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"number of rows: 214, number of column: 11\n"
]
}
],
"source": [
"n_row, n_col = data.shape\n",
"print('number of rows: {n_row}, number of column: {n_col}'.format(n_row=n_row, n_col=n_col))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"def encode(solution):\n",
" \"\"\"用于将个体编码成二进制形式,用于选择特征\"\"\"\n",
" v_func = lambda x: 1.0 / (1.0 + np.exp(-x*2))\n",
" return [True if v_func(x) > 0.5 else False for x in solution]"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[]"
]
},
"metadata": {},
"execution_count": 51
},
{
"output_type": "display_data",
"data": {
"text/plain": "",
"image/svg+xml": "\r\n\r\n\r\n\r\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAe1klEQVR4nO3deXiU9b338fc3yYQkQBKWsCfsKIgiGnFrC1q1aF1qWxUebbW1cmm1p6e1i572sR67nJ72XG0fT+1CW2utC+JOFWu1aql7wiKyExGyCCQsIUDIMjPf548MNEYwA0xyz0w+r+vKNXPfczPzKU4+/XFvP3N3REQk9WUEHUBERBJDhS4ikiZU6CIiaUKFLiKSJlToIiJpIiuoDx44cKCPGjUqqI8XEUlJixcv3ubuRQd7LbBCHzVqFOXl5UF9vIhISjKzTYd6TbtcRETShApdRCRNqNBFRNKECl1EJE2o0EVE0kSnhW5md5tZrZmtOMTrZmZ3mlmFmS03s5MSH1NERDoTzwj9HmDmh7x+PjA+9jMH+PXRxxIRkcPV6Xno7r7IzEZ9yCaXAPd62314XzezQjMb6u6bExVSRCRe0ajTEonSEonSGt7/6LREIrSE214LR6JEok7EnWgUIu5EolEiUYhEnaj7+x7/9ZzYn/nXOnfaHgF3cNrW+YHX2q2jbf05Ewczpbgw4f/bE3Fh0XCgqt1ydWzdBwrdzObQNoqnpKQkAR8tIqkuHInS0BSmvrGFXftaqd/Xyq7G1thymL0tYRpbwjQ2R2hsidDYGqGxOUxjS4R9rRH2NodpDkdpCUdpjUQJR5N7jgczGFKQk7SFHjd3nwvMBSgtLU3uv3UROWr1jS1s2t5I1c5GtjY0U9vQRO3uZrY2NLG1oYm63c00NIU/9D1yQ5nkZWeSm51J7+wscrPblgvzssmLPc8JZZKdlUEo08jOzCSUZWRnZpCdlXHgMdRuOSvTyDQjI8PIzDAyrO2xbR3tnrc9Zmb863lGBu9bl2GG0VbUhrU9xp5nGFj718269O87EYVeAxS3Wx4RWyciPUBrJMqGur2s2dLAuq272bi9kcrtjWzavvcDZZ2dlcHg/F4M6pvDMUP68tHxRfTLy6YgN4vCvGwK8kIU5IYozA1RmJdNfk4WWZk6GS9eiSj0BcBNZjYPOBXYpf3nIukpEnXWbGlgyaadLKvaxerNDVTU7qElEgUgK8MY0S+XkgG9ObG4kJED8ijpn0fJgDyG5OdQkBvq8lFqT9ZpoZvZg8AMYKCZVQPfA0IA7v4bYCFwAVABNAJf6KqwItK9olFn1eYGFq2v47V3trO0sp49zW2j7oF9ejFpWD4fnTCQiUPymTg0nzFFvQlpRB2YeM5ymd3J6w7cmLBEIhKoptYIL66p5bnVW1m0bhvb9jQDcOyQvnxq6jBKR/bn5JH9GNEvV6PtJBPY7XNFJHm0hKO8XFHHX97azN9WbmFvS4TCvBAfG1/E9AlFfHTCQAb1zQk6pnRChS7Sg1XtaOSBNyt5uLyKbXtaKMgNcdGUYVw0ZRinju6vA5IpRoUu0gOVb9zBr196hxfW1mLAxycOZtYpxXx0fBHZWSrxVKVCF+kh3J1/rt/GL1+o4M2NO+iXF+IrZ41j9qklDC3IDTqeJIAKXaQHWFGzix8/s4aXK7YxtCCH2y6cxKxpxeRlqwLSif5riqSxbXua+dHC1Ty2pIZ+eSFuu3ASV55WQq+szKCjSRdQoYukoWjUeXhxFT9auIbGljA3zBjLDTPGkp8TCjqadCEVukiaqanfx83zl/H6hh1MG9WfH316MuMG9Q06lnQDFbpIGln49mZueXQ5kajz408fz+WlxWRk6OKfnkKFLpIGmsMRbl+wigffrGRKcSF3zjqRkQN6Bx1LupkKXSTF1e5u4vo/L2ZJZT3XTx/LzedN0P1UeigVukgKW1Gzi+vuLWdnYwu/uvIkLjh+aNCRJEAqdJEU9c/1dcy5dzH98kI8cv0ZTB5eEHQkCZgKXSQF/XXFFv7twaWMKerNvddO042zBFChi6ScRxZX861H3mJKcSH3XDONgjydWy5tVOgiKeTxpdV84+G3OHPcAOZ+rpTevfQrLP+ib4NIinh25Ra+8fByzhg7gD9cfQo5IV2+L++nc5tEUsArFdv4ygNLOX54AXM/X6oyl4NSoYskuf2nJo4p6s09XziFPtrNIoegQhdJYrW7m7ju3nIKc0Pc+8VpFOZlBx1Jkpj+r14kSTW1Rphz72LqG1t55IbTGZSvUxPlw6nQRZKQu3PLo8tZVlXPb646ieOG6aIh6Zx2uYgkoT+9upEnlr3HN86bwMzJupxf4qNCF0kyb1fv4kcL1/DxYwdx41njgo4jKUSFLpJEdje1ctODSxjQJ5v/uWwKZrqXucRP+9BFkoS7c+tjb1O9cx/z5pxGv946o0UOj0boIkni8aU1PLV8M18/dwKnjOofdBxJQSp0kSSwtaGJ2xespHRkP66fPjboOJKiVOgiAXN3/uOxt2kOR/nJZ08gU3OAyhFSoYsE7PGlNfx9TS3f/MQxjCnqE3QcSWEqdJEA1bbb1fKFM0cHHUdSXFyFbmYzzWytmVWY2S0Heb3EzF40s6VmttzMLkh8VJH088OFq2nSrhZJkE4L3cwygbuA84FJwGwzm9Rhs+8C8919KjAL+FWig4qkm9fe2c6Ty97j+uljtatFEiKeEfo0oMLdN7h7CzAPuKTDNg7kx54XAO8lLqJI+mmNRLntyRWM6JfLl2forBZJjHgKfThQ1W65OrauvduBq8ysGlgIfOVgb2Rmc8ys3MzK6+rqjiCuSHr44yvvsr52D7dfdJwmq5CESdRB0dnAPe4+ArgA+LOZfeC93X2uu5e6e2lRUVGCPloktWzZ1cQvnl/PORMHcc6kwUHHkTQST6HXAMXtlkfE1rV3LTAfwN1fA3KAgYkIKJJu/udvawlHnNsuPC7oKJJm4in0MmC8mY02s2zaDnou6LBNJfBxADObSFuha5+KSAer3mvg0SXVXHPmKEoG5AUdR9JMp4Xu7mHgJuBZYDVtZ7OsNLM7zOzi2GY3A9eZ2VvAg8A17u5dFVokVf3XM6vJzwlx4wzdFlcSL667Lbr7QtoOdrZfd1u756uAMxMbTSS9LFpXxz/Xb+O7n5xIQV4o6DiShnSlqEg3iESdHy1cTXH/XD53+sig40iaUqGLdIMnltawZstuvvWJY+mVpdMUpWuo0EW6WDgS5c4X1nPcsHw+ebzmB5Wuo0IX6WKPLa1h0/ZG/v2cCWTofi3ShVToIl2oNRLlf19Yz/HDCzhn4qCg40iaU6GLdKFHF1dTtWMfXzt3vCZ8li6nQhfpIi3hKP/7QgVTigs56xiNzqXrqdBFusijS6qpqd/H187R6Fy6hwpdpAuEI1F+8493mDKigOkTdCM66R4qdJEu8MyKLWza3sgNM8ZqdC7dRoUukmDuzq9feocxRb05b9KQoONID6JCF0mwReu3sWpzA9dPH6vzzqVbqdBFEuxXL1YwJD+HT53YcWIvka6lQhdJoCWVO3nj3R186aOjyc7Sr5d0L33jRBLod4s2UJAbYva0kqCjSA+kQhdJkKodjTy7cguzp5XQu1dcUw2IJJQKXSRB7n1tI2bG53W/cwmICl0kAfY2h5lXVsX5k4cwrDA36DjSQ6nQRRLg0SXV7G4K88WPjA46ivRgKnSRoxSNOn98ZSMnFhdyUkm/oONID6ZCFzlKL62r5d1tezU6l8Cp0EWO0t0vb2RIfg7nT9Zl/hIsFbrIUVi7ZTcvV2zj82eMJJSpXycJlr6BIkfhnlc3khPKYPYpupBIgqdCFzlCDU2tPLG0hkumDKdf7+yg44io0EWO1ONLatjXGuGq03QhkSQHFbrIEXB37nt9E1NGFHD8iIKg44gAKnSRI/LmuztYX7uHKzU6lySiQhc5Ave9UUl+ThYXnTAs6CgiB6jQRQ5T3e5m/rpiM589uZjc7Myg44gcEFehm9lMM1trZhVmdsshtrnczFaZ2UozeyCxMUWSx/zyKlojzpWn6VRFSS6d3rTZzDKBu4BzgWqgzMwWuPuqdtuMB24FznT3nWY2qKsCiwQpEnUeeKOSM8YOYGxRn6DjiLxPPCP0aUCFu29w9xZgHnBJh22uA+5y950A7l6b2JgiyeEf62qpqd+nUxUlKcVT6MOBqnbL1bF17U0AJpjZK2b2upnNPNgbmdkcMys3s/K6urojSywSoPtfr6Soby/OnTQ46CgiH5Cog6JZwHhgBjAb+J2ZFXbcyN3nunupu5cWFRUl6KNFuseWXU28uLaWy0tH6L4tkpTi+VbWAMXtlkfE1rVXDSxw91Z3fxdYR1vBi6SNR5dUE3W4vLS4841FAhBPoZcB481stJllA7OABR22eYK20TlmNpC2XTAbEhdTJFjRqPNQWRWnjxnAyAG9g44jclCdFrq7h4GbgGeB1cB8d19pZneY2cWxzZ4FtpvZKuBF4Jvuvr2rQot0t9ff3U7ljkauOEWjc0lenZ62CODuC4GFHdbd1u65A1+P/YiknYfKqsjPyWKmJrGQJKYjOyKd2NXYyjMrtvCpqcPJCenKUEleKnSRTjyxrIaWcFS7WyTpqdBFPoS7M6+sisnD8zlumG6TK8lNhS7yIVbUNLB6cwNXaIo5SQEqdJEPMa+skpxQBhdP0W1yJfmp0EUOYV9LhAXL3uOCyUMpyA0FHUekUyp0kUNY+PZmdjeHdTBUUoYKXeQQHiqvYvTA3kwb3T/oKCJxUaGLHMSGuj28+e4OLi8txsyCjiMSFxW6yEE8VF5FZobxmZM73ilaJHmp0EU6aI1EeXRxDWcfO4hBfXOCjiMSNxW6SAcvrKll255mrtBtciXFqNBFOphfVsWgvr2YcYwmYZHUokIXaWf/rESfPXkEWZqVSFKMvrEi7TyyuEqzEknKUqGLxESjzvzyak4fM4BRAzUrkaQeFbpIzOsbNCuRpDYVukjMQ+WalUhSmwpdBKhvbNGsRJLyVOgiwBNL22YlmqX7nksKU6FLj7d/VqITRhQwaVh+0HFEjpgKXXq8t6p3sWbLbh0MlZSnQpce76GySnJDmZqVSFKeCl16tL3NYRYse49PnjCUvjmalUhSmwpderSnl29mb0uEWdrdImlAhS492ryySsYN6sPJI/sFHUXkqKnQpcdat3U3SyrrmXWKZiWS9KBClx7robIqQpnGpVM1K5GkBxW69EjN4QiPLanmvElDGNCnV9BxRBJChS490nOrtrKzsVXnnktaUaFLjzTvzSqGF+bykXEDg44ikjBxFbqZzTSztWZWYWa3fMh2nzEzN7PSxEUUSayqHY28XLGNy0uLycjQwVBJH50WupllAncB5wOTgNlmNukg2/UFvgq8keiQIok0v7yKDIPLSkcEHUUkoeIZoU8DKtx9g7u3APOASw6y3feB/waaEphPJKHCkSgPl1fzsQlFDCvMDTqOSELFU+jDgap2y9WxdQeY2UlAsbs//WFvZGZzzKzczMrr6uoOO6zI0Xp+dS1bGpr4P9N0m1xJP0d9UNTMMoCfATd3tq27z3X3UncvLSoqOtqPFjls97+xiaEFOZx97KCgo4gkXDyFXgO0P7drRGzdfn2BycBLZrYROA1YoAOjkmze3baXf67fxuxpJWRl6gQvST/xfKvLgPFmNtrMsoFZwIL9L7r7Lncf6O6j3H0U8DpwsbuXd0likSP0wBubyMow3YhL0lanhe7uYeAm4FlgNTDf3Vea2R1mdnFXBxRJhKbWCA8vrua84wYzKD8n6DgiXSIrno3cfSGwsMO62w6x7YyjjyWSWE8v30x9YytXnToy6CgiXUY7EqVHuO+NTYwp6s3pYwcEHUWky6jQJe2tqNnF0sp6rjx1pG6TK2lNhS5p7/43NpETyuCzJ+nKUElvKnRJa7v2tfLksve46IRhFORpzlBJbyp0SWsPlVXS2BLh6jNGBR1FpMup0CVthSNR/vTqJk4d3Z/JwwuCjiPS5VTokraeW7WVmvp9fPEjo4OOItItVOiStu5+5V2K++dyzsTBQUcR6RYqdElLy6vrKdu4k2vOGE2mJrGQHkKFLmnpj69spE+vLC7XJBbSg6jQJe1sbWjiqeXvcVnpCPrm6FRF6TlU6JJ27nt9E+Goc41OVZQeRoUuaaWpNcL9b1RyzsTBjBzQO+g4It1KhS5pZX55FTv2tvAlnaooPZAKXdJGayTKb/+xgZNH9mPa6P5BxxHpdip0SRtPL99MTf0+bpg+VndVlB5JhS5pIRp1fv3SOxwzuK8mgJYeS4UuaeHFtbWs3bqb62eMIUMXEkkPpUKXlOfu/OqldxhemMuFJwwLOo5IYFTokvJe27CdxZt2MudjYwhl6istPZe+/ZLS3J1fPLeewfm9uOKU4qDjiARKhS4p7ZWK7by5cQc3njWOnFBm0HFEAqVCl5Tl7vz8+XUMLcjR6FwEFbqksEXrt7F4005uPGscvbI0OhdRoUtKcnd+/tw6hhfmcnmpRucioEKXFPXcqq0sq6rnxrPGkZ2lr7EIqNAlBYUjUX781zWMKerNZZrAQuQAFbqknHllVWyo28stM4/Veeci7ei3QVLKnuYwv3h+HdNG9efcSZr8WaS9rKADiByOuYs2sG1PC7+/eqLuqCjSQVwjdDObaWZrzazCzG45yOtfN7NVZrbczP5uZiMTH1V6uq0NTfxu0QYuPGEoJxYXBh1HJOl0WuhmlgncBZwPTAJmm9mkDpstBUrd/QTgEeAniQ4q8oOnVxNx59szjw06ikhSimeEPg2ocPcN7t4CzAMuab+Bu7/o7o2xxdcBnXogCfVqxTb+8tZ7fHnGWIr75wUdRyQpxVPow4GqdsvVsXWHci3wzMFeMLM5ZlZuZuV1dXXxp5QerSUc5f8+uYKS/nlcP31s0HFEklZCz3Ixs6uAUuCnB3vd3ee6e6m7lxYVFSXyoyWN3f3Ku7xTt5fbL56kG3CJfIh4znKpAdpfWz0itu59zOwc4DvAdHdvTkw86elq6vdx59/Xc+6kwZx9rE5TFPkw8YzQy4DxZjbazLKBWcCC9huY2VTgt8DF7l6b+JjSE7k7tz72NgC3XdjxOLyIdNRpobt7GLgJeBZYDcx395VmdoeZXRzb7KdAH+BhM1tmZgsO8XYicXuorIpF6+q49fxjdSBUJA5xXVjk7guBhR3W3dbu+TkJziU9XE39Pn7w9GpOG9OfK0/VZQ0i8dCl/5J09u9qibrzk89MISNDV4SKxEOFLknnvtc3sWhdHbecfywlA7SrRSReKnRJKivf28X3n17N9AlFXKVdLSKHRYUuSWNPc5ivPLCUfnkhfna5drWIHC7dbVGSgrvz3cffZuP2vTxw3WkM6NMr6EgiKUcjdEkKD7xZyRPL3uOrH5/AaWMGBB1HJCWp0CVwr76zje89uZIZxxRx09njgo4jkrJU6BKojdv2csN9Sxg1sDd3zp5KpvabixwxFboEpqGplWv/VIYZ/OHqUvJzQkFHEklpKnQJxL6WCF+6p5xN2xv5zVUnM3JA76AjiaQ8neUi3a4lHOX6+xZTtmkHd86aqoOgIgmiEbp0q0jU+dpDy/jHujr+69LjuWjKsKAjiaQNFbp0m9ZIlK/PX8bTb2/mOxdMZNa0kqAjiaQV7XKRbtHUGuGmB5bw/OpavvmJY7juY2OCjiSSdlTo0uX2Noe57t5yXn1nO9+/5Dg+d/qooCOJpCUVunSp6p2NXHfvYtZt3c3Pr5jCpVNHBB1JJG2p0KXLlG3cwfV/XkxLJMrd15zC9AmaGFykK6nQJeHcnfveqOSOv6ykuF8ev7u6lLFFfYKOJZL2VOiSUDv3tvDtR5fzt1VbmT6hiDtnTaUgT1eAinQHFbokzMvrt/GNh99i+95mvvvJiXzxzNG6p7lIN1Khy1HbvqeZHzy9mseX1jCmqDe/v/pMJg8vCDqWSI+jQpcjFo5EmV9ezU+eXcPe5jD/dvY4vnzWOHJCmUFHE+mRVOhy2Nyd51Zt5SfPrqWidg/TRvXnh5dOZvzgvkFHE+nRVOgSt2jUeWFNLXe9VMHSynrGFPXmt587mfMmDcZM+8pFgqZCl041tUZ4avlm5i56h3Vb9zC8MJcfXjqZK0qLycrU7YBEkoUKXQ6ponYP896s5JEl1dQ3tjJhcB9+fsUULjxhGCEVuUjSUaHL+1TtaOQvy9/jqbc2s2pzA1kZxicmD+HKU0s4fcwA7VoRSWIq9B4uHInyVnU9/1hbx0vr6lhevQuAqSWF3HbhJC6cMpRBfXMCTiki8VCh9zAt4SirNjdQvnEHizft5JWKbTQ0hckwmFrSj2/PPJYLTxhKcf+8oKOKyGFSoaexfS0R1m7dzZrNDaze3MCqzQ0sr95FczgKwIh+ucycPITpEwbxkXEDdYm+SIpToae4ptYI1Tsb2bS97adyRyObtu9l4/ZGNm7fi3vbdr2zMzlmSF+uPHUkpaP6cfLIfgzO164UkXQSV6Gb2Uzg/wGZwO/d/ccdXu8F3AucDGwHrnD3jYmN2jNEo87upjC79rUe+Knf10JtQzO1u5up3d1E3e7m2HITOxtb3/fn+/TKoqR/HscO6cvFU4YxcWg+E4f2pbhfnu6rIpLmOi10M8sE7gLOBaqBMjNb4O6r2m12LbDT3ceZ2Szgv4EruiJwV3J3og5Rd6LuuLfNgxmOOK2RKK1RpzUcJRyN0hJ2wtEorbHXDmwTaVsXjkZpbo3S2BJmX2uUfS1h9rVGaGyJsK81wr7YY2NLhL3NYRqaWtnV2Mru5vCBUXVHoUyjqE8vivJzKBmQR+motlF2Sf88SgbkMbJ/Hv17Z+tMFJEeKp4R+jSgwt03AJjZPOASoH2hXwLcHnv+CPBLMzP3Q1XTkZtfVsXcf244ULj7yzca7VjI+5f/te592zvQcbmL5YYyycvOJDc788DznFAmg/NzmDC4LwW5IfJzQ+TnZFGQGzrwU5iXTVHfXhTmhjTKFpFDiqfQhwNV7ZargVMPtY27h81sFzAA2NZ+IzObA8wBKCk5shnf+/XO5pjBfTGDDDMyYo/W7nlGBu9fNnvf9tZh+QN/vt02oYwMQplGVmbbYygzg6zMDLIzjayMDEJZGYQy3v962zZGr6wM8rKzyA1lkhPK0MhZRLpUtx4Udfe5wFyA0tLSIxoTnztpMOdOGpzQXCIi6SCe67drgOJ2yyNi6w66jZllAQW0HRwVEZFuEk+hlwHjzWy0mWUDs4AFHbZZAFwde/5Z4IWu2H8uIiKH1ukul9g+8ZuAZ2k7bfFud19pZncA5e6+APgD8GczqwB20Fb6IiLSjeLah+7uC4GFHdbd1u55E3BZYqOJiMjh0D1QRUTShApdRCRNqNBFRNKECl1EJE1YUGcXmlkdsOkI//hAOlyFmiSU6/Ao1+FL1mzKdXiOJtdIdy862AuBFfrRMLNydy8NOkdHynV4lOvwJWs25To8XZVLu1xERNKECl1EJE2kaqHPDTrAISjX4VGuw5es2ZTr8HRJrpTchy4iIh+UqiN0ERHpQIUuIpImUr7QzexmM3MzGxh0FgAz+76ZLTezZWb2NzMbFnQmADP7qZmtiWV73MwKg84EYGaXmdlKM4uaWeCnl5nZTDNba2YVZnZL0HkAzOxuM6s1sxVBZ2nPzIrN7EUzWxX7b/jVoDMBmFmOmb1pZm/Fcv1n0JnaM7NMM1tqZk8l+r1TutDNrBg4D6gMOks7P3X3E9z9ROAp4LZOtu8uzwGT3f0EYB1wa8B59lsBfBpYFHSQdhOinw9MAmab2aRgUwFwDzAz6BAHEQZudvdJwGnAjUny99UMnO3uU4ATgZlmdlqwkd7nq8DqrnjjlC504OfAt4CkObLr7g3tFnuTJNnc/W/uHo4tvk7bzFOBc/fV7r426BwxByZEd/cWYP+E6IFy90W0zTOQVNx9s7sviT3fTVtJDQ82FXibPbHFUOwnKX4PzWwE8Eng913x/ilb6GZ2CVDj7m8FnaUjM/uhmVUBV5I8I/T2vgg8E3SIJHSwCdEDL6hUYGajgKnAGwFHAQ7s1lgG1ALPuXtS5AJ+QdsgNNoVb96tk0QfLjN7HhhykJe+A/wHbbtbut2H5XL3J939O8B3zOxW4Cbge8mQK7bNd2j7p/L93ZEp3lySusysD/Ao8O8d/oUaGHePACfGjhU9bmaT3T3QYxBmdiFQ6+6LzWxGV3xGUhe6u59zsPVmdjwwGnjLzKBt98ESM5vm7luCynUQ99M201O3FHpnuczsGuBC4OPdOefrYfx9BS2eCdGlHTML0Vbm97v7Y0Hn6cjd683sRdqOQQR9UPlM4GIzuwDIAfLN7D53vypRH5CSu1zc/W13H+Tuo9x9FG3/ND6pO8q8M2Y2vt3iJcCaoLK0Z2Yzafun3sXu3hh0niQVz4ToEmNto6k/AKvd/WdB59nPzIr2n8VlZrnAuSTB76G73+ruI2KdNQt4IZFlDila6Enux2a2wsyW07ZLKClO5QJ+CfQFnoudUvmboAMBmNmlZlYNnA48bWbPBpUldtB4/4Toq4H57r4yqDz7mdmDwGvAMWZWbWbXBp0p5kzgc8DZse/UstjoM2hDgRdjv4NltO1DT/gpgslIl/6LiKQJjdBFRNKECl1EJE2o0EVE0oQKXUQkTajQRUTShApdRCRNqNBFRNLE/weD7VlwWKyRpQAAAABJRU5ErkJggg==\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# 非关键代码,用于验证 v_func 函数图像是否与文章中\"图 4.1\" 一致\n",
"v_func = lambda x: 1.0 / (1.0 + np.exp(-x*2))\n",
"x = np.linspace(-4, 4, 300)\n",
"plt.plot(x, [v_func(v) for v in x])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def DR(encoded_x):\n",
" \"\"\"计算特征维度缩减率\"\"\"\n",
" return 1 - 1.0 * sum([1 for v in encoded_x if v]) / len(encoded_x)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9 10\n",
"0 True True True False False True True True False True 0.3\n",
"1 True False False False True True True False True True 0.4\n",
"2 True False True True True True False False True False 0.4\n",
"3 False False False False False False True True False False 0.8\n",
"4 False False True False False True False False False True 0.7\n",
"5 True True False False True False True True True False 0.4\n",
"6 False False True True False True True False True True 0.4\n",
"7 False True True False True True True False False True 0.4\n",
"8 False True True False True False True False False False 0.6\n",
"9 False True True True True False False False True True 0.4"
],
"text/html": "\n\n
\n \n \n | \n 0 | \n 1 | \n 2 | \n 3 | \n 4 | \n 5 | \n 6 | \n 7 | \n 8 | \n 9 | \n 10 | \n
\n \n \n \n | 0 | \n True | \n True | \n True | \n False | \n False | \n True | \n True | \n True | \n False | \n True | \n 0.3 | \n
\n \n | 1 | \n True | \n False | \n False | \n False | \n True | \n True | \n True | \n False | \n True | \n True | \n 0.4 | \n
\n \n | 2 | \n True | \n False | \n True | \n True | \n True | \n True | \n False | \n False | \n True | \n False | \n 0.4 | \n
\n \n | 3 | \n False | \n False | \n False | \n False | \n False | \n False | \n True | \n True | \n False | \n False | \n 0.8 | \n
\n \n | 4 | \n False | \n False | \n True | \n False | \n False | \n True | \n False | \n False | \n False | \n True | \n 0.7 | \n
\n \n | 5 | \n True | \n True | \n False | \n False | \n True | \n False | \n True | \n True | \n True | \n False | \n 0.4 | \n
\n \n | 6 | \n False | \n False | \n True | \n True | \n False | \n True | \n True | \n False | \n True | \n True | \n 0.4 | \n
\n \n | 7 | \n False | \n True | \n True | \n False | \n True | \n True | \n True | \n False | \n False | \n True | \n 0.4 | \n
\n \n | 8 | \n False | \n True | \n True | \n False | \n True | \n False | \n True | \n False | \n False | \n False | \n 0.6 | \n
\n \n | 9 | \n False | \n True | \n True | \n True | \n True | \n False | \n False | \n False | \n True | \n True | \n 0.4 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 8
}
],
"source": [
"# 非关键代码,用于测试 DR 函数\n",
"test_encoded_xs = np.ones((10, 10), dtype=bool)\n",
"mask = np.random.normal(0, 1, (10, 10)) > 0\n",
"test_encoded_xs[mask] = False\n",
"df = pd.DataFrame(test_encoded_xs)\n",
"df[10] = [DR(x) for x in test_encoded_xs]\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def CA(encode_x, data_X, data_y):\n",
" \"\"\"计算分类准确率\"\"\"\n",
" clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))\n",
" clf.fit(data_X[:,encode_x], data_y)\n",
" return clf.score(data_X[:,encode_x], data_y)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"准确率为: 0.9345794392523364, 选中的特征下标: 0-1-3-6-7-8-9\n"
]
}
],
"source": [
"# 非关键代码,用于测试 CA 函数\n",
"encode_x = np.ones(n_col-1, dtype=bool)\n",
"encode_x[np.random.normal(0, 1, n_col-1) > 0] = False\n",
"\n",
"ca = CA(encode_x, data[:,:10], data[:,10])\n",
"print('准确率为: {ca}, 选中的特征下标: {idxs_feature}'.format(ca=ca, idxs_feature='-'.join([str(i) for i, v in enumerate(encode_x) if v])))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"# 目标函数包含 CA 和 DR 两部分内容,如果需要修改分类器,请修改 CA 函数\n",
"# 说明:文章中\"图 4.2\" 曲线与给定的公式 4.2 是对不上的,所以按最曲线重新定义了个 rho 函数\n",
"def rho(l):\n",
" return 0.90 + 0.099 / (l/2000 + 1)\n",
"\n",
"def objective(encode_x, data_X, data_y):\n",
" l = len(encode_x)\n",
" return rho(l) * CA(encode_x, data_X, data_y) + (1 - rho(l)) * DR(encode_x)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Text(0, 0.5, 'Value of Rho')"
]
},
"metadata": {},
"execution_count": 12
},
{
"output_type": "display_data",
"data": {
"text/plain": "",
"image/svg+xml": "\r\n\r\n\r\n\r\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEGCAYAAAB2EqL0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAmTklEQVR4nO3deXxddZ3/8dc7N7lZmjTpki60pa1QwLJDrCIKhZ8iMg4M6CiMo+j4k3HBZfwxIzyc3zDiTxlxXxgFHUYYV0RHGUURsQgzgjRslUI3CoW2lKZ70yXr5/fHOUlv0+bmts3NTZP38/G4j5zzPeee87m9bd79nu9ZFBGYmZkNpKzUBZiZ2eHBgWFmZgVxYJiZWUEcGGZmVhAHhpmZFaS81AUMlokTJ8asWbNKXYaZ2WHlkUce2RARjYWsO2ICY9asWTQ3N5e6DDOzw4qkVYWu60NSZmZWEAeGmZkVxIFhZmYFKVpgSLpF0npJT/azXJK+KmmFpEWSTstZdrmk5enr8mLVaGZmhStmD+M7wPl5lr8RmJO+rgC+ASBpPHAt8EpgHnCtpHFFrNPMzApQtMCIiPuBTXlWuQi4LRIPAQ2SpgJvAO6JiE0RsRm4h/zBY2ZmQ6CUYxjTgBdy5lenbf2170PSFZKaJTW3tLQUrVAzMzvMB70j4uaIaIqIpsbGgq472UdrWydfvGcZj7+wZXCLMzMbYUoZGGuAGTnz09O2/tqLoqOzm6/eu5zHn99crF2YmY0IpQyMO4F3pmdLvQrYGhEvAncD50kalw52n5e2FUVVRQaA3Z3dxdqFmdmIULRbg0j6ATAfmChpNcmZTxUAEfFN4C7gAmAFsBN4d7psk6RPAQvTTV0XEfkGzw9JZXmSmbs7uoq1CzOzEaFogRERlw2wPIAP9rPsFuCWYtTVV1mZyGbK2N3hHoaZWT6H9aD3YKmsKHMPw8xsAA4MknGMNo9hmJnl5cAAqirKaHMPw8wsLwcGUFWeYXenA8PMLB8HBj1jGD4kZWaWjwODtIfhQ1JmZnk5MEgGvR0YZmb5OTBIBr19SMrMLD8HBlBZkaHNg95mZnk5MEhuD+IehplZfg4Mei7ccw/DzCwfBwY9Z0m5h2Fmlo8Dg55Bb/cwzMzycWCQHJLq7A46u9zLMDPrjwODpIcBfoiSmVk+Dgygsjx96p4PS5mZ9cuBwZ4ehm9xbmbWPwcGOc/1dg/DzKxfDgx8SMrMrBAODHIGvX0thplZvxwY7Dkk5afumZn1z4FBci8pwE/dMzPLw4EB1GTLAdjV7kNSZmb9cWAANdnkkNSO9s4SV2JmNnw5MNgTGLvafUjKzKw/DgxgTGVySMo9DDOz/jkwSAa9JfcwzMzycWAAkhiTLWdHmwPDzKw/DoxUdTbDrg4fkjIz648DIzUmm3EPw8wsDwdGqjpbzk6PYZiZ9cuBkRqTzbDTZ0mZmfWrqIEh6XxJSyWtkHT1fpbPlHSvpEWS7pM0PWfZDZIWS3pa0lclqZi1Vmcz7mGYmeVRtMCQlAFuBN4IzAUukzS3z2qfB26LiJOA64Dr0/e+GjgTOAk4AXgFcHaxagUYky13D8PMLI9i9jDmASsiYmVEtAM/BC7qs85c4Hfp9IKc5QFUAVmgEqgAXipirdS4h2FmllcxA2Ma8ELO/Oq0LdcTwCXp9MVAnaQJEfEgSYC8mL7ujoin++5A0hWSmiU1t7S0HFKxNZUODDOzfEo96H0VcLakx0gOOa0BuiQdDbwcmE4SMudKem3fN0fEzRHRFBFNjY2Nh1RIjQ9JmZnlVV7Eba8BZuTMT0/bekXEWtIehqRa4M0RsUXSe4GHIqI1XfYr4AzggWIVW5PNsLujm67uIFNW1PF1M7PDUjF7GAuBOZJmS8oClwJ35q4gaaKknhquAW5Jp58n6XmUS6og6X3sc0hqMI3peSaGn7pnZrZfRQuMiOgErgTuJvllf3tELJZ0naQL09XmA0slLQMmA59O2+8AngH+RDLO8URE/FexaoXktFqAnW0+LGVmtj/FPCRFRNwF3NWn7Z9ypu8gCYe+7+sC/raYtfU1pjINDA98m5ntV6kHvYeN6go/E8PMLB8HRso9DDOz/BwYqZ6n7rV6DMPMbL8cGKmxVUlgbN/twDAz2x8HRqquqgKA7bs7SlyJmdnw5MBI1fYcknIPw8xsvxwYqZpshkyZfEjKzKwfDoyUJGory31IysysHw6MHLWV5Wz3WVJmZvvlwMhRV1XuQ1JmZv1wYOQYW1XhQ1JmZv1wYORwD8PMrH8OjBy1VeW+0tvMrB8OjBzuYZiZ9c+BkaMuHcOIiFKXYmY27DgwctRWltPRFbR1dpe6FDOzYceBkcM3IDQz658DI4dvQGhm1j8HRo6eGxC6h2Fmti8HRo6x1UkPY5t7GGZm+3Bg5BhXkwTGlp0ODDOzvhwYOep7A6O9xJWYmQ0/DowcDdVZwD0MM7P9cWDkyJaXMSabYbMDw8xsH+WFrCQpCxyTzi6NiBH7G7WhJsuWXT4kZWbW14CBIWk+cCvwHCBghqTLI+L+olZWIg01FT4kZWa2H4X0ML4AnBcRSwEkHQP8ADi9mIWVyriaLJs96G1mto9CxjAqesICICKWARXFK6m0Gmoq2OoehpnZPgrpYTRL+jbw3XT+7UBz8UoqrYaaCvcwzMz2o5DAeD/wQeDD6fwDwL8WraISG1eTZeuuDrq7g7IylbocM7NhY8DAiIg24Ivpa8Srr66gO5L7SfVcyGdmZgWMYUg6U9I9kpZJWtnzGoriSmFcTXLxng9LmZntrZBB738j6V28BnhFzmtAks6XtFTSCklX72f5TEn3Slok6T5J03OWHSnpN5KelvSUpFkFfaJDNG5MenuQXR74NjPLVcgYxtaI+NWBblhSBrgReD2wGlgo6c6IeCpntc8Dt0XErZLOBa4H3pEuuw34dETcI6kWGJLH4NVXu4dhZrY//QaGpNPSyQWSPgf8FGjrWR4Rjw6w7XnAiohYmW7vh8BFQG5gzAU+1rMf4GfpunOB8oi4J91Xa4Gf55BNGJMExqZWB4aZWa58PYwv9JlvypkO4NwBtj0NeCFnfjXwyj7rPAFcAnwFuBiokzSB5DYkWyT9FJgN/Ba4OiK6ct8s6QrgCoAjjzxygHIKM7GuEoCW1rYB1jQzG136DYyIOGcI9n8V8HVJ7wLuB9YAXWldrwVOBZ4HfgS8i2Q8JbfGm4GbAZqammIwChqTzVBVUcaG7Q4MM7NceQe9JWUkTcyZz0p6r6SnC9j2GmBGzvz0tK1XRKyNiEsi4lTgE2nbFpLeyOMRsTIiOkkOVZ3GEJDExNpKNriHYWa2l34DQ9KlwCZgkaTfSzoPWAlcQHK190AWAnMkzU7vdnspcGeffUyU1FPDNcAtOe9tkNSYzp/L3mMfRZUEhscwzMxy5RvD+Efg9IhYkQ6APwi8JSL+q5ANR0SnpCuBu4EMcEtELJZ0HdAcEXcC84HrJQXJIakPpu/tknQVcK8kAY8A3zq4j3jgGusqeWHTzqHanZnZYSFfYLRHxApIzoiStLzQsOgREXcBd/Vp+6ec6TuAO/p57z3ASQeyv8EysbaSx57fXIpdm5kNW/kCY5Kkj+XMN+TOR8SIvVVIY22WTTva6eoOMr6flJkZkD8wvgXU5ZkfsSbWVdIdsHFHG5PqqkpdjpnZsJDvtNpPDmUhw8nE2uRajA3b2x0YZmapQu4lNer0BoZPrTUz6+XA2I/GOgeGmVlf+a7D+Ej688yhK2d46AmM9b7a28ysV74exrvTn18bikKGk9rKcuoqy1m3dXepSzEzGzbynSX1tKTlwBGSFuW0C4iIKMk1EkNlSn0VL27dVeoyzMyGjXxnSV0maQrJldoXDl1Jw8PUhmr3MMzMcuQd9I6IdRFxMvAiyTUYdcDaiFg1FMWV0tSxVax1YJiZ9RrwiXuSziZ5+t1zJIejZki6PCLuL3JtJTWlvooNrW20d3aTLffJZGZmhTyi9YvAeRGxFEDSMcAPgNOLWVipHdFQRQS8tG03M8bXlLocM7OSK+S/zhU9YQEQEcuAiuKVNDxMqa8GYN02H5YyM4PCehjNkr4NfDedfzvQXLyShocj6pNbgrzocQwzM6CwwHg/yXMqPpzOPwD8a9EqGiam9ATGFp9aa2YGBQRGRLSRjGOM2NuZ709dVQV1leXuYZiZpXz6Tx5HNFSzerN7GGZm4MDI68gJNTy/aUepyzAzGxYKDgxJo+7c0pnja3h+004iotSlmJmV3ICBIenVkp4ClqTzJ0sa8YPekPQwdnd00+K71pqZFdTD+BLwBmAjQEQ8AZxVzKKGiyPTC/ZWbdpZ4krMzEqvoENSEfFCn6auItQy7MycMAaAVRsdGGZmhVyH8YKkVwMhqQL4CPB0ccsaHqY1VFMmeN49DDOzgnoY7yO5cG8asAY4JZ0f8bLlZUytr+b5jT5TysyskAv3NpDcDmRUmjmhxmMYZmYUdnvzfwf2Oa80Iv6mKBUNM7MmjuGXi14kIpBU6nLMzEqmkDGMX+RMVwEXA2uLU87wM2dSLVt3dbChtZ3GuspSl2NmVjKFHJL6Se68pB8A/120ioaZOZPqAFixvtWBYWaj2sHcGmQOMGmwCxmujp5UC8CK9dtLXImZWWkVMoaxnWQMQ+nPdcDHi1zXsDF5bCV1leUsX99a6lLMzEqqkENSdUNRyHAliaMn17L8JQeGmY1u/R6SknRavlchG5d0vqSlklZIuno/y2dKulfSIkn3SZreZ/lYSaslff3AP9rgmTOp1j0MMxv18vUwvpBnWQDn5tuwpAxwI/B6YDWwUNKdEfFUzmqfB26LiFslnQtcD7wjZ/mngPvz7WcoHD2pltubV7NlZzsNNdlSl2NmVhL9BkZEnHOI254HrIiIlQCSfghcBOQGxlzgY+n0AuBnPQsknQ5MBn4NNB1iLYfkuCljAXjqxW28+qiJpSzFzKxkCjpLStIJkt4q6Z09rwLeNg3IvWnh6rQt1xPAJen0xUCdpAmSykh6OFcNUNcVkpolNbe0tBTyUQ7K8UckgbF4zbai7cPMbLgr5HkY1wJfS1/nADcAFw7S/q8Czpb0GHA2yb2quoAPAHdFxOp8b46ImyOiKSKaGhsbB6mkfU2orWRqfRWL124t2j7MzIa7Qq70fgtwMvBYRLxb0mTguwW8bw0wI2d+etrWKyLWkvYwJNUCb46ILZLOAF4r6QNALZCV1BoR+wycD5Xjj6jnybXuYZjZ6FXIIaldEdENdEoaC6xn7yDoz0JgjqTZkrLApcCduStImpgefgK4BrgFICLeHhFHRsQskl7IbaUMC4ATpo3lmZZWdrZ3lrIMM7OSKSQwmiU1AN8CHgEeBR4c6E0R0QlcCdxN8vyM2yNisaTrJPUc0poPLJW0jGSA+9MH/AmGyPFH1BMBT7/oK77NbHTq95CUpBuB70fEB9Kmb0r6NTA2IhYVsvGIuAu4q0/bP+VM3wHcMcA2vgN8p5D9FdMJ05KB7yfXbOX0meNKXI2Z2dDL18NYBnxe0nOSbpB0akQ8V2hYjDRTxlYxeWwljz6/udSlmJmVRL+BERFfiYgzSM5e2gjcImmJpGslHTNkFQ4TkmiaOZ7m5xwYZjY6DTiGERGrIuKzEXEqcBnwF4ySZ3r3dfrMcazZsot1W3eXuhQzsyFXyHUY5ZL+XNL3gF8BS9lzsd2o0jQrGbtoXrWpxJWYmQ29fDcffL2kW0iu0H4v8EvgqIi4NCJ+PlQFDicvnzqW6oqMD0uZ2aiU78K9a4DvA/8nIvwbEqjIlHHKjAb3MMxsVMo36H1uRHzbYbG3V71sAovXbmPzjvZSl2JmNqQO5hGto9pr5kwkAv7wzMZSl2JmNqQcGAfo5On11FWV88Dy4t0d18xsOHJgHKDyTBlnHjWRB5ZvICJKXY6Z2ZBxYByE18yZyJotu3h2w45Sl2JmNmQcGAfh7GOSZ2/8bsn6EldiZjZ0HBgHYcb4Go6bUsdvFr9U6lLMzIaMA+MgveH4KSxctYmW7W2lLsXMbEg4MA7S+SdMIQJ++7R7GWY2OjgwDtJxU+qYOaGGXz+5rtSlmJkNCQfGQZLE+SdM4X9WbGBjqw9LmdnI58A4BJecOp3O7uDnj68tdSlmZkXnwDgEx06p48Rp9fzk0dWlLsXMrOgcGIfozadNY/HabTz94rZSl2JmVlQOjEN04SnTqMiIHy18odSlmJkVlQPjEI0fk+XPTpzKHY+sZvvujlKXY2ZWNA6MQfCuM2fT2tbJTx7xWIaZjVwOjEFwyowGTj2ygVsfXEV3t+9ga2YjkwNjkLz7zNk8u2EH9/jKbzMboRwYg+SCE6Ywa0INX/ntcj8nw8xGJAfGICnPlPGhc+fw1IvbuOcp9zLMbORxYAyii045glkTavjyb5d7LMPMRhwHxiAqz5TxkdclvYyfPram1OWYmQ0qB8Ygu+jkaZw8o4Ebfr2EHW2dpS7HzGzQODAGWVmZuPbP57J+exs3LlhR6nLMzAaNA6MITjtyHJecOo1vPbCSJet8jykzGxmKGhiSzpe0VNIKSVfvZ/lMSfdKWiTpPknT0/ZTJD0oaXG67G3FrLMY/vFNcxlbVcHf/3gRnV3dpS7HzOyQFS0wJGWAG4E3AnOByyTN7bPa54HbIuIk4Drg+rR9J/DOiDgeOB/4sqSGYtVaDOPHZPnUX5zAn9Zs5ab7V5a6HDOzQ1bMHsY8YEVErIyIduCHwEV91pkL/C6dXtCzPCKWRcTydHotsB5oLGKtRXHBiVP5s5Om8qV7lvHIqs2lLsfM7JAUMzCmAbn3/F6dtuV6Argknb4YqJM0IXcFSfOALPBM3x1IukJSs6TmlpaWQSt8MH3m4hOZ2lDFld9/lE072ktdjpnZQSv1oPdVwNmSHgPOBtYAXT0LJU0F/gN4d0TsMxAQETdHRFNENDU2Ds8OSH11Bf/6V6ezsbWdj/7ocY9nmNlhq5iBsQaYkTM/PW3rFRFrI+KSiDgV+ETatgVA0ljgl8AnIuKhItZZdCdOr+eTFx3P/ctauPbOxb7XlJkdlsqLuO2FwBxJs0mC4lLgr3JXkDQR2JT2Hq4Bbknbs8B/kgyI31HEGofMZfOO5LmNO7jp9yuZOaGGK846qtQlmZkdkKL1MCKiE7gSuBt4Grg9IhZLuk7Shelq84GlkpYBk4FPp+1vBc4C3iXp8fR1SrFqHSoff8Nx/NlJU/nMXUv44cPPl7ocM7MDopFyeKSpqSmam5tLXcaAdnd08b7vPsJ9S1u44c0n8dZXzBj4TWZmRSLpkYhoKmTdUg96jzpVFRm++denc9YxjXz8p4v4j4dWlbokM7OCODBKoKoiw83vOJ1zj53E//3Zk3z210t8O3QzG/YcGCVSVZHhpneczmXzjuQb9z3DR3/0OLvauwZ+o5lZiRTzLCkbQHmmjM9cfALTx1Xz+d8sZdlL2/nGX5/O7IljSl2amdk+3MMoMUl88Jyj+c675/HStt38+df+mzufWFvqsszM9uHAGCbOPqaRX3z4tcyZXMuHf/AYH/zeo2xsbSt1WWZmvRwYw8i0hmp+/Ldn8A/nH8s9T73EeV+6n58/vsZXhpvZsODAGGbKM2V8YP7R/NeHXsO0cdV85IeP87abHuKptX4Qk5mVlgNjmDp2Sh3/+YEz+ZdLTmRFSytv+toDfPyORazevLPUpZnZKOUrvQ8DW3d28JV7l/Pdh1YRBJe+4kg+eM7RTKmvKnVpZnaYO5ArvR0Yh5G1W3bx9QUruH3hC5RJXHTKEbzntbM5bsrYUpdmZocpB8YI98KmnXzrgZX8uHk1uzq6eM3RE3nXq2cx/9hGyjM+ymhmhXNgjBJbdrbz/Yef59Y/PMdL29qYVFfJm0+fzlubZvjiPzMriANjlOno6uZ3S9Zz+8IXWLB0Pd0Bp88cx5tOmsobT5jqsQ4z65cDYxR7adtufvLoau58fC1L1m0HoGnmOC44cSqvnzuZGeNrSlyhmQ0nDgwD4JmWVu5a9CK//NOLveHxssYxzD9mEvOPbWTe7PFUVWRKXKWZlZIDw/bx7IYdLFiynvuWtfDQyo20d3ZTXZGhadY4Xjl7PPNmT+Ck6fUOELNRxoFhee1q7+KhlRv5fRoePb2PbHkZp85oYN7s8Zw8vYGTZtQzqc7jH2Yj2YEEhm9vPgpVZzOcc9wkzjluEpCcbbXwuc38ceVGHn5uEzcuWEHP85yOqK/ipDQ8Tp7ewPFHjKWhJlvC6s2sVBwYRkNNltfPnczr504GYGd7J0+u2cai1Vt4YvVWFq3ewq8Xr+tdf/LYSo6ZXMdxU+o4dspYjptSx9GTan04y2yEc2DYPmqy5cybPZ55s8f3tm3Z2c6i1VtZsm4bS9ZtZ+m67dz64CraO7sBKBMcOb6GWRPHMLvP64j6asrKVKqPY2aDxIFhBWmoyXLWMY2cdUxjb1tnVzerNu1k6brtLFm3nWdaWnm2ZQcPP7uJnTmPm82WlzFrQg0zJ4xhWkM108clr2kNNUwbV824mgokB4rZcOfAsINWninjqMZajmqs5YITp/a2RwTrt7exsmUHz23cwbMbkteqjTv4w4oN7Ojz7PKabIZpDdVMG1fNtIZqpoytYvLYKiaNrWRyOu1QMSs9B4YNOkm9v+jPOGrCXssigq27Oli9eRerN+9izZZdrNm8i9Wbd7Jmyy4ef2ELW3Z27LPNbKaMxrrKJETqqpg8tpJJY6sYPybL+DFZJtZmGT+mkvFjsoytKne4mBWBA8OGlCQaarI01GQ5YVr9ftfZ3dFFy/Y2Xtq2m/Xpz5e2tbF+225e2r6bFS2t/M8zG9i+u3O/76/IKA2SSiakgTJ+TJYJY7I0jMlSX11BfXUFDenP+uoKxlZXkPE4i1leDgwbdqoqMswYXzPgbUx2d3SxcUc7m1rb2bijjU072tnY2p609czvaOeFzTvZ2NpOa9v+A6ZHXWU5Y3NCpKFm70CpqyqntrKcMZXl1FWWU1u193R1RcY9GxvRHBh22KqqSMc+GqoLWr+ts4utOzvYumvf15a0fVtO24r1rb3TbenZYPmUiX3CpLayfK+gqclmqMkm4VKTzVCdzu+ZzlBdsae9uiLjno8NGw4MGzUqyzNMGpth0tgDv3p9d0cXrW2d7GjrZPvuzt7p1j7zPdOtuzvZ0Z7Mv7h1N61p+872zt6LIguvu2xP0KShUpUGTk02Q2V5hqqKMirLM1SmP6sO8mdleZlPgbZ+OTDMClBVkfySnlhbeUjbiQjaOrvZ1d7Fzo4udrV3srO9K2e+K51P2ne2d7Ert71jT/vG1nZWd3TR1tnF7o5u2jq62N3Z3XttzMHKZsr2Cp5seRnZTPKzIpNMV/S2KZnPWV5Zvvd88n7tNb+nfe/tJtvLUFEuKjJlVJSVkcmI8rJk3r2t0nJgmA0hSb3hM65I++juDtq7umnr6GZ3Z9f+f3Z00dbZvSds+vvZ0UV7VxJCHV3dtHd109EZ7NzVQUdn997Lcubbu7opxm3qJKgoK6M8J0SS6f23VeQsq8iUUV6mvdav2GfZ3u/Zp61MZNJtlEn7mU9CredV3s90Mp9/3eE4HubAMBthyspEVVkSSvVUlKyOzq5uOrpiT5B0de8VMrnzewInXb+zm87u5P1d6c/Oruht6+zqprM76OjqprMr6OhOfnb1tOUs6+zuZldH8rOzK2lP1tu7rbM7erc1HO7JKrEnkMrKKFNy7VOmTGS0J6gyZeL4I+r52mWnFr0mB4aZFUV5pozyTHKzy8NNT/B05YRIT/h0d0Nnd7qsOwmpnunuiN7g6uzu3mu+K6J3e73T3UFXVzddAV3dSWh152y37/yefe697pHjCzvx41AVNTAknQ98BcgA346If+mzfCZwC9AIbAL+OiJWp8suB/4xXfX/RcStxazVzKxHcmjo8Au6Yisr1oYlZYAbgTcCc4HLJM3ts9rngdsi4iTgOuD69L3jgWuBVwLzgGslFeuQr5mZFaBogUHyi35FRKyMiHbgh8BFfdaZC/wunV6Qs/wNwD0RsSkiNgP3AOcXsVYzMxtAMQNjGvBCzvzqtC3XE8Al6fTFQJ2kCQW+F0lXSGqW1NzS0jJohZuZ2b6KGRiFuAo4W9JjwNnAGqAr/1v2iIibI6IpIpoaGxsHfoOZmR20Yg56rwFm5MxPT9t6RcRa0h6GpFrgzRGxRdIaYH6f995XxFrNzGwAxexhLATmSJotKQtcCtyZu4KkiZJ6ariG5IwpgLuB8ySNSwe7z0vbzMysRIoWGBHRCVxJ8ov+aeD2iFgs6TpJF6arzQeWSloGTAY+nb53E/ApktBZCFyXtpmZWYkohsMljYOgqakpmpubS12GmdlhRdIjEdFU0LojJTAktQCrDmETE4ENg1TO4WK0febR9nnBn3m0OJTPPDMiCjpraMQExqGS1Fxoyo4Uo+0zj7bPC/7Mo8VQfeZSn1ZrZmaHCQeGmZkVxIGxx82lLqAERttnHm2fF/yZR4sh+cwewzAzs4K4h2FmZgVxYJiZWUFGfWBIOl/SUkkrJF1d6noOlKQZkhZIekrSYkkfSdvHS7pH0vL057i0XZK+mn7eRZJOy9nW5en6y9MHWPW0ny7pT+l7vqph8LBhSRlJj0n6RTo/W9If0xp/lN6OBkmV6fyKdPmsnG1ck7YvlfSGnPZh93dCUoOkOyQtkfS0pDNGwXf8d+nf6Scl/UBS1Uj7niXdImm9pCdz2or+vfa3jwFFxKh9kTwJ8BngZUCW5Hbrc0td1wF+hqnAael0HbCM5DkjNwBXp+1XA59Npy8AfgUIeBXwx7R9PLAy/TkunR6XLns4XVfpe984DD73x4DvA79I528HLk2nvwm8P53+APDNdPpS4Efp9Nz0+64EZqd/DzLD9e8EcCvwv9PpLNAwkr9jkscZPAtU53y/7xpp3zNwFnAa8GROW9G/1/72MWC9pf6HUOK/lGcAd+fMXwNcU+q6DvEz/Rx4PbAUmJq2TQWWptM3AZflrL80XX4ZcFNO+01p21RgSU77XuuV6DNOB+4FzgV+kf5j2ACU9/1eSe5ldkY6XZ6up77fdc96w/HvBFCf/vJUn/aR/B33PBNnfPq9/YLkwWoj7nsGZrF3YBT9e+1vHwO9RvshqYIe1HS4SLvhpwJ/BCZHxIvponUkN3eE/j9zvvbV+2kvpS8D/wB0p/MTgC2R3PAS9q6x93Oly7em6x/on0MpzQZagH9PD8N9W9IYRvB3HBFrSB7h/DzwIsn39ggj+3vuMRTfa3/7yGu0B8aIoeR5Ij8BPhoR23KXRfLfiBFx/rSkNwHrI+KRUtcyhMpJDlt8IyJOBXaQHEboNZK+Y4D0mPpFJGF5BDCGUfiY5qH4Xg9kH6M9MAZ8yNPhQFIFSVh8LyJ+mja/JGlqunwqsD5t7+8z52ufvp/2UjkTuFDScyTPiT8X+ArQIKnngWC5NfZ+rnR5PbCRA/9zKKXVwOqI+GM6fwdJgIzU7xjgdcCzEdESER3AT0m++5H8PfcYiu+1v33kNdoDY8CHPA136VkP/wY8HRFfzFl0J9BztsTlJGMbPe3vTM+4eBWwNe2a7vehVemybZJele7rnTnbGnIRcU1ETI+IWSTf1+8i4u3AAuAt6Wp9P2/Pn8Nb0vUjbb80PbtmNjCHZIBw2P2diIh1wAuSjk2b/hfwFCP0O049D7xKUk1aU89nHrHfc46h+F7720d+pRrUGi4vkjMPlpGcMfGJUtdzEPW/hqQ7uQh4PH1dQHL89l5gOfBbYHy6voAb08/7J6ApZ1t/A6xIX+/OaW8Cnkzf83X6DL6W8LPPZ89ZUi8j+UWwAvgxUJm2V6XzK9LlL8t5/yfSz7SUnLOChuPfCeAUoDn9nn9GcjbMiP6OgU8CS9K6/oPkTKcR9T0DPyAZo+kg6Um+Zyi+1/72MdDLtwYxM7OCjPZDUmZmViAHhpmZFcSBYWZmBXFgmJlZQRwYZmZWEAeGHXYkhaQv5MxfJemfB2nb35H0loHXPOT9/KWSu84u2M+yzym5S+vnDmK7p0i6YHCqNNubA8MOR23AJZImlrqQXDlXIBfiPcB7I+Kc/Sy7AjgpIv7+IMo4heT6goKlF4L5d4ENyH9J7HDUSfIM47/ru6BvD0FSa/pzvqTfS/q5pJWS/kXS2yU9nD4v4KiczbxOUrOkZem9q3qev/E5SQuVPIvgb3O2+4CkO0muRO5bz2Xp9p+U9Nm07Z9ILrj8t769iHQ7tcAjkt4mqVHST9L9LpR0ZrrePEkPKrkZ4R8kHZtesXwd8DZJj6fv/2dJV+Vs/0lJs9LXUkm3kVzYNUPS3+d8vk8e+NdiI92B/I/IbDi5EVgk6YYDeM/JwMuBTSTPDPh2RMxT8tCpDwEfTdebBcwDjgIWSDqa5LYKWyPiFZIqgf+R9Jt0/dOAEyLi2dydSToC+CxwOrAZ+I2kv4iI6ySdC1wVEc2574mICyW1RsQp6Ta+D3wpIv5b0pEkt4F4OckV0K+NiE5JrwM+ExFvTsOoKSKuTN//z3n+POYAl0fEQ5LOS+fnkVxRfKeksyLi/oL+ZG1UcGDYYSkitqX/O/4wsKvAty2M9JbOkp4Ben7h/wnIPTR0e0R0A8slrQSOI7k/z0k5vZd6kl+w7cDDfcMi9QrgvohoSff5PZIH5vyswHohuQnfXO15AN5YJXcmrgdulTSH5NYwFQewzR6rIuKhdPq89PVYOl9L8vkcGNbLgWGHsy8DjwL/ntPWSXqoNT0un81Z1pYz3Z0z383e/xb63i8nSP7X/aGIuDt3gaT5JLcbL5Yy4FURsbvPfr8OLIiIi5U8B+W+ft7f++eRqsqZzq1bwPURcdMhV2wjlscw7LAVEZtIHtn5npzm50gOAQFcyMH9z/svJZWl4xovI7lp3d3A+5XcSh5Jxyh5iFE+DwNnS5ooKUPyxLPfH2AtvyE5XEa631PSyXr23Kr6XTnrbyd5VG+P50gOmaHkGdCz+9nP3cDfpL0XJE2TNOkAa7URzoFhh7svALlnS32L5Jf0EySP4TyY//0/T/LL/lfA+9L/3X+bZFD7UUlPkjwGM28PPT38dTXJLbmfAB6JiAO9bfiHgaZ0IPop4H1p+w3A9ZIe61PHApJDWI9LehvJc1LGS1oMXElyd9b91fobkmekPyjpTyTP3Kjb37o2evlutWZmVhD3MMzMrCAODDMzK4gDw8zMCuLAMDOzgjgwzMysIA4MMzMriAPDzMwK8v8Bk5HNZh6iT0IAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
],
"source": [
"# rho 函数的图示\n",
"x = np.linspace(0, 100000, 10000)\n",
"plt.plot(x, rho(x))\n",
"plt.xlabel('Number of feature')\n",
"plt.ylabel('Value of Rho')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9\n",
"0 False True True True False False True False True True\n",
"1 False True False True True True False False False False\n",
"2 False True True False False False True True True False\n",
"3 False True True False False True True True True True\n",
"4 True True False True True False False True True True"
],
"text/html": "\n\n
\n \n \n | \n 0 | \n 1 | \n 2 | \n 3 | \n 4 | \n 5 | \n 6 | \n 7 | \n 8 | \n 9 | \n
\n \n \n \n | 0 | \n False | \n True | \n True | \n True | \n False | \n False | \n True | \n False | \n True | \n True | \n
\n \n | 1 | \n False | \n True | \n False | \n True | \n True | \n True | \n False | \n False | \n False | \n False | \n
\n \n | 2 | \n False | \n True | \n True | \n False | \n False | \n False | \n True | \n True | \n True | \n False | \n
\n \n | 3 | \n False | \n True | \n True | \n False | \n False | \n True | \n True | \n True | \n True | \n True | \n
\n \n | 4 | \n True | \n True | \n False | \n True | \n True | \n False | \n False | \n True | \n True | \n True | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"# 非关键代码,目标函数的测试 1\n",
"p = np.random.normal(0, 1, (20, 10)) # 测试种群\n",
"encoded_p = [encode(x) for x in p]\n",
"pd.DataFrame(encoded_p).head(5) # 编码后种群(前 5 个)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" 0 1 2 3 4 5 6 7 8 9 \\\n",
"0 False True True True False False True False True True \n",
"1 False True False True True True False False False False \n",
"2 False True True False False False True True True False \n",
"3 False True True False False True True True True True \n",
"4 True True False True True False False True True True \n",
"\n",
" fitness \n",
"0 0.742479 \n",
"1 0.710116 \n",
"2 0.695969 \n",
"3 0.746995 \n",
"4 0.924300 "
],
"text/html": "\n\n
\n \n \n | \n 0 | \n 1 | \n 2 | \n 3 | \n 4 | \n 5 | \n 6 | \n 7 | \n 8 | \n 9 | \n fitness | \n
\n \n \n \n | 0 | \n False | \n True | \n True | \n True | \n False | \n False | \n True | \n False | \n True | \n True | \n 0.742479 | \n
\n \n | 1 | \n False | \n True | \n False | \n True | \n True | \n True | \n False | \n False | \n False | \n False | \n 0.710116 | \n
\n \n | 2 | \n False | \n True | \n True | \n False | \n False | \n False | \n True | \n True | \n True | \n False | \n 0.695969 | \n
\n \n | 3 | \n False | \n True | \n True | \n False | \n False | \n True | \n True | \n True | \n True | \n True | \n 0.746995 | \n
\n \n | 4 | \n True | \n True | \n False | \n True | \n True | \n False | \n False | \n True | \n True | \n True | \n 0.924300 | \n
\n \n
\n
"
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"# 非关键代码,目标函数的测试 2\n",
"df = pd.DataFrame(encoded_p)\n",
"df['fitness'] = [objective(encoded_x, data[:,:10], data[:,10]) for encoded_x in encoded_p]\n",
"df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def real_objectives(lam, population, data_X, data_y):\n",
" \"\"\"\n",
" 适应值塑造,对应文章中的公式 4.3, 返回 lam 个塑造后的较优适应值,及各个体下标\n",
" \"\"\"\n",
" encoded_population = [encode(x) for x in population]\n",
" fitness_values = [objective(x, data_X, data_y) for x in encoded_population]\n",
" # 取前 lam 个适应值较优的个体下标,再按升序排列\n",
" selected = np.argsort(fitness_values)[::-1][:lam][::-1]\n",
"\n",
" temp = 1.0 * (fitness_values[selected[lam-1]] - fitness_values[selected[0]]) / (lam - 1) \n",
" \n",
" real_fitness_values = [0] * lam\n",
"\n",
" for i, idx in enumerate(selected):\n",
" if i == 0:\n",
" real_fitness_values[i] = - (fitness_values[selected[lam-1]] - fitness_values[selected[i]]) / 2\n",
" else:\n",
" real_fitness_values[i] = real_fitness_values[0] + i * temp\n",
" \n",
" return real_fitness_values, selected"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"([-0.21914911424187478,\n",
" -0.1793038207433521,\n",
" -0.1394585272448294,\n",
" -0.09961323374630672,\n",
" -0.059767940247784035,\n",
" -0.019922646749261363,\n",
" 0.019922646749261336,\n",
" 0.059767940247784035,\n",
" 0.0996132337463067,\n",
" 0.13945852724482938,\n",
" 0.17930382074335205,\n",
" 0.21914911424187478],\n",
" array([10, 3, 6, 2, 0, 9, 8, 4, 11, 5, 7, 1], dtype=int64))"
]
},
"metadata": {},
"execution_count": 16
}
],
"source": [
"# 非关键代码,用于测试 real_objective\n",
"real_objectives(lam=12, population=np.random.normal(0, 1, (12, 10)), data_X=data[:,:10], data_y=data[:,10])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def event_on_generation(gen, population, current_best, history_best):\n",
" best = history_best\n",
" if gen == 0:\n",
" best = current_best\n",
" print('iter: {gen}, best solution(bool): {best}, best idxs: {idxs}'.format(gen=gen, best=encode(best), idxs='-'.join([str(i) for i, v in enumerate(encode(best)) if v])))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def NES(\n",
" n_dimension: int, \n",
" objective, \n",
" n_iter, \n",
" learn_rate, \n",
" lam,\n",
" mean,\n",
" cov,\n",
" random_state=None, \n",
" event_on_genration=event_on_generation):\n",
"\n",
" logger.debug({\n",
" 'dimension': n_dimension,\n",
" 'n_iter': n_iter,\n",
" 'learn_rate': learn_rate,\n",
" })\n",
"\n",
" tf.set_random_seed(random_state)\n",
" best, best_fitness = None, 1e+10\n",
" \n",
" def get_fitness(population): \n",
" return [-objective(solution) for solution in population]\n",
"\n",
" mean = tf.Variable(mean, dtype=tf.float32)\n",
" cov = tf.Variable(cov, dtype=tf.float32)\n",
" mvn = MultivariateNormalFullCovariance(loc=mean, covariance_matrix=cov)\n",
" make_population = mvn.sample(lam)\n",
"\n",
" fitness_input = tf.placeholder(tf.float32, [lam, ])\n",
" prob_output = tf.placeholder(tf.float32, [lam, n_dimension])\n",
" loss = -tf.reduce_mean(mvn.log_prob(prob_output) * fitness_input)\n",
" train_op = tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)\n",
"\n",
" sess = tf.Session()\n",
" sess.run(tf.global_variables_initializer())\n",
"\n",
" for g in range(n_iter):\n",
" population = sess.run(make_population)\n",
"\n",
" scores = [objective(c) for c in population]\n",
" \n",
" event_on_genration(g, population, population[np.argsort(scores)[0]], best)\n",
"\n",
" fitness_values = get_fitness(population)\n",
" sess.run(train_op, {fitness_input: fitness_values, prob_output: population}) \n",
"\n",
" for i, score in enumerate([objective(x) for x in population]):\n",
" if score < best_fitness:\n",
" best, best_fitness = population[i], score\n",
"\n",
" return [best, best_fitness]"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"n_dimension = n_col - 1\n",
"n_iter = 10\n",
"learn_rate = 0.002\n",
"mu = 4\n",
"lam = 12\n",
"mean = data[:,:n_dimension].mean(axis=0)\n",
"cov = 10000 * tf.eye(n_dimension)\n",
"random_state = 1 # 固定输出\n",
"\n",
"def es_objective(x):\n",
" return -objective(encode(x), data[:,:n_dimension], data[:,n_dimension])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"2021-07-11 16:14:51.616 | DEBUG | __main__:NES:15 - {'dimension': 10, 'n_iter': 10, 'learn_rate': 0.002}\n",
"iter: 0, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 1, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 2, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 3, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 4, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 5, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n",
"iter: 6, best solution(bool): [True, True, True, True, False, True, True, True, True, False], best idxs: 0-1-2-3-5-6-7-8\n",
"iter: 7, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n",
"iter: 8, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n",
"iter: 9, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n"
]
}
],
"source": [
"best, best_fitness = NES(n_dimension, es_objective, n_iter, learn_rate, lam, mean, cov, random_state, event_on_generation)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n维度缩减率 DR: 0.5\n分类准确率 CA: 0.9532710280373832\n"
]
}
],
"source": [
"print('best solution(bool): {best}, best idxs: {idxs}'.format(best=encode(best), idxs='-'.join([str(i) for i, v in enumerate(encode(best)) if v])))\n",
"print('维度缩减率 DR: {dr}'.format(dr=DR(encode(best))))\n",
"print('分类准确率 CA: {ca}'.format(ca=CA(encode(best), data[:,:n_dimension], data[:,n_dimension])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 结论\n",
"# 原先 10 个特征,经过基于 NES 的特征选择方法,特征数缩减至 5 个,同时准确率保持在 95% 以上"
]
}
]
}