{ "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.13" }, "orig_nbformat": 4, "kernelspec": { "name": "python3", "display_name": "Python 3.6.13 64-bit ('pyes': conda)" }, "interpreter": { "hash": "5febc64283966fe46fc465ad8ab242eb0484fa28506faf7951a93cb8efa4edfd" } }, "nbformat": 4, "nbformat_minor": 2, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# 基于单一种群的特征选择方法\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "import tensorflow as tf\n", "import numpy as np\n", "import pyes\n", "import pandas as pd\n", "import math\n", "import matplotlib.pyplot as plt\n", "from tensorflow.python.framework.errors_impl import InvalidArgumentError\n", "from tensorflow.contrib.distributions import MultivariateNormalFullCovariance\n", "from loguru import logger\n", "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.svm import SVC\n", "\n", "tf.logging.set_verbosity(tf.logging.ERROR)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = np.genfromtxt('../exampleData/glass.data', delimiter=',')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 10\n", "0 1.0 1.52101 13.64 4.49 1.10 71.78 0.06 8.75 0.00 0.0 1.0\n", "1 2.0 1.51761 13.89 3.60 1.36 72.73 0.48 7.83 0.00 0.0 1.0\n", "2 3.0 1.51618 13.53 3.55 1.54 72.99 0.39 7.78 0.00 0.0 1.0\n", "3 4.0 1.51766 13.21 3.69 1.29 72.61 0.57 8.22 0.00 0.0 1.0\n", "4 5.0 1.51742 13.27 3.62 1.24 73.08 0.55 8.07 0.00 0.0 1.0\n", ".. ... ... ... ... ... ... ... ... ... ... 
...\n", "209 210.0 1.51623 14.14 0.00 2.88 72.61 0.08 9.18 1.06 0.0 7.0\n", "210 211.0 1.51685 14.92 0.00 1.99 73.06 0.00 8.40 1.59 0.0 7.0\n", "211 212.0 1.52065 14.36 0.00 2.02 73.42 0.00 8.44 1.64 0.0 7.0\n", "212 213.0 1.51651 14.38 0.00 1.94 73.61 0.00 8.48 1.57 0.0 7.0\n", "213 214.0 1.51711 14.23 0.00 2.08 73.36 0.00 8.62 1.67 0.0 7.0\n", "\n", "[214 rows x 11 columns]" ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
012345678910
01.01.5210113.644.491.1071.780.068.750.000.01.0
12.01.5176113.893.601.3672.730.487.830.000.01.0
23.01.5161813.533.551.5472.990.397.780.000.01.0
34.01.5176613.213.691.2972.610.578.220.000.01.0
45.01.5174213.273.621.2473.080.558.070.000.01.0
....................................
209210.01.5162314.140.002.8872.610.089.181.060.07.0
210211.01.5168514.920.001.9973.060.008.401.590.07.0
211212.01.5206514.360.002.0273.420.008.441.640.07.0
212213.01.5165114.380.001.9473.610.008.481.570.07.0
213214.01.5171114.230.002.0873.360.008.621.670.07.0
\n

214 rows × 11 columns

\n
" }, "metadata": {}, "execution_count": 3 } ], "source": [ "# 数据概览\n", "pd.DataFrame(data)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "number of rows: 214, number of column: 11\n" ] } ], "source": [ "n_row, n_col = data.shape\n", "print('number of rows: {n_row}, number of column: {n_col}'.format(n_row=n_row, n_col=n_col))" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "def encode(solution):\n", " \"\"\"用于将个体编码成二进制形式,用于选择特征\"\"\"\n", " v_func = lambda x: 1.0 / (1.0 + np.exp(-x*2))\n", " return [True if v_func(x) > 0.5 else False for x in solution]" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "[]" ] }, "metadata": {}, "execution_count": 51 }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-07-11T16:12:15.830556\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.3.4, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "# 非关键代码,用于验证 v_func 函数图像是否与文章中\"图 4.1\" 一致\n", "v_func = lambda x: 1.0 / (1.0 + np.exp(-x*2))\n", "x = np.linspace(-4, 4, 300)\n", "plt.plot(x, [v_func(v) for v in x])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def DR(encoded_x):\n", " \"\"\"计算特征维度缩减率\"\"\"\n", " return 1 - 
1.0 * sum([1 for v in encoded_x if v]) / len(encoded_x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 10\n", "0 True True True False False True True True False True 0.3\n", "1 True False False False True True True False True True 0.4\n", "2 True False True True True True False False True False 0.4\n", "3 False False False False False False True True False False 0.8\n", "4 False False True False False True False False False True 0.7\n", "5 True True False False True False True True True False 0.4\n", "6 False False True True False True True False True True 0.4\n", "7 False True True False True True True False False True 0.4\n", "8 False True True False True False True False False False 0.6\n", "9 False True True True True False False False True True 0.4" ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
012345678910
0TrueTrueTrueFalseFalseTrueTrueTrueFalseTrue0.3
1TrueFalseFalseFalseTrueTrueTrueFalseTrueTrue0.4
2TrueFalseTrueTrueTrueTrueFalseFalseTrueFalse0.4
3FalseFalseFalseFalseFalseFalseTrueTrueFalseFalse0.8
4FalseFalseTrueFalseFalseTrueFalseFalseFalseTrue0.7
5TrueTrueFalseFalseTrueFalseTrueTrueTrueFalse0.4
6FalseFalseTrueTrueFalseTrueTrueFalseTrueTrue0.4
7FalseTrueTrueFalseTrueTrueTrueFalseFalseTrue0.4
8FalseTrueTrueFalseTrueFalseTrueFalseFalseFalse0.6
9FalseTrueTrueTrueTrueFalseFalseFalseTrueTrue0.4
\n
" }, "metadata": {}, "execution_count": 8 } ], "source": [ "# 非关键代码,用于测试 DR 函数\n", "test_encoded_xs = np.ones((10, 10), dtype=bool)\n", "mask = np.random.normal(0, 1, (10, 10)) > 0\n", "test_encoded_xs[mask] = False\n", "df = pd.DataFrame(test_encoded_xs)\n", "df[10] = [DR(x) for x in test_encoded_xs]\n", "df" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def CA(encode_x, data_X, data_y):\n", " \"\"\"计算分类准确率\"\"\"\n", " clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))\n", " clf.fit(data_X[:,encode_x], data_y)\n", " return clf.score(data_X[:,encode_x], data_y)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "准确率为: 0.9345794392523364, 选中的特征下标: 0-1-3-6-7-8-9\n" ] } ], "source": [ "# 非关键代码,用于测试 CA 函数\n", "encode_x = np.ones(n_col-1, dtype=bool)\n", "encode_x[np.random.normal(0, 1, n_col-1) > 0] = False\n", "\n", "ca = CA(encode_x, data[:,:10], data[:,10])\n", "print('准确率为: {ca}, 选中的特征下标: {idxs_feature}'.format(ca=ca, idxs_feature='-'.join([str(i) for i, v in enumerate(encode_x) if v])))" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "# 目标函数包含 CA 和 DR 两部分内容,如果需要修改分类器,请修改 CA 函数\n", "# 说明:文章中\"图 4.2\" 曲线与给定的公式 4.2 是对不上的,所以按最曲线重新定义了个 rho 函数\n", "def rho(l):\n", " return 0.90 + 0.099 / (l/2000 + 1)\n", "\n", "def objective(encode_x, data_X, data_y):\n", " l = len(encode_x)\n", " return rho(l) * CA(encode_x, data_X, data_y) + (1 - rho(l)) * DR(encode_x)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Text(0, 0.5, 'Value of Rho')" ] }, "metadata": {}, "execution_count": 12 }, { "output_type": "display_data", "data": { "text/plain": "
", "image/svg+xml": "\r\n\r\n\r\n\r\n \r\n \r\n \r\n \r\n 2021-07-11T15:49:47.658519\r\n image/svg+xml\r\n \r\n \r\n Matplotlib v3.3.4, https://matplotlib.org/\r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n \r\n\r\n", "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "# rho 函数的图示\n", "x = np.linspace(0, 100000, 10000)\n", "plt.plot(x, rho(x))\n", "plt.xlabel('Number of feature')\n", 
"plt.ylabel('Value of Rho')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " 0 1 2 3 4 5 6 7 8 9\n", "0 False True True True False False True False True True\n", "1 False True False True True True False False False False\n", "2 False True True False False False True True True False\n", "3 False True True False False True True True True True\n", "4 True True False True True False False True True True" ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
0123456789
0FalseTrueTrueTrueFalseFalseTrueFalseTrueTrue
1FalseTrueFalseTrueTrueTrueFalseFalseFalseFalse
2FalseTrueTrueFalseFalseFalseTrueTrueTrueFalse
3FalseTrueTrueFalseFalseTrueTrueTrueTrueTrue
4TrueTrueFalseTrueTrueFalseFalseTrueTrueTrue
\n
" }, "metadata": {}, "execution_count": 13 } ], "source": [ "# 非关键代码,目标函数的测试 1\n", "p = np.random.normal(0, 1, (20, 10)) # 测试种群\n", "encoded_p = [encode(x) for x in p]\n", "pd.DataFrame(encoded_p).head(5) # 编码后种群(前 5 个)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 \\\n", "0 False True True True False False True False True True \n", "1 False True False True True True False False False False \n", "2 False True True False False False True True True False \n", "3 False True True False False True True True True True \n", "4 True True False True True False False True True True \n", "\n", " fitness \n", "0 0.742479 \n", "1 0.710116 \n", "2 0.695969 \n", "3 0.746995 \n", "4 0.924300 " ], "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
0123456789fitness
0FalseTrueTrueTrueFalseFalseTrueFalseTrueTrue0.742479
1FalseTrueFalseTrueTrueTrueFalseFalseFalseFalse0.710116
2FalseTrueTrueFalseFalseFalseTrueTrueTrueFalse0.695969
3FalseTrueTrueFalseFalseTrueTrueTrueTrueTrue0.746995
4TrueTrueFalseTrueTrueFalseFalseTrueTrueTrue0.924300
\n
" }, "metadata": {}, "execution_count": 14 } ], "source": [ "# 非关键代码,目标函数的测试 2\n", "df = pd.DataFrame(encoded_p)\n", "df['fitness'] = [objective(encoded_x, data[:,:10], data[:,10]) for encoded_x in encoded_p]\n", "df.head(5)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def real_objectives(lam, population, data_X, data_y):\n", " \"\"\"\n", " 适应值塑造,对应文章中的公式 4.3, 返回 lam 个塑造后的较优适应值,及各个体下标\n", " \"\"\"\n", " encoded_population = [encode(x) for x in population]\n", " fitness_values = [objective(x, data_X, data_y) for x in encoded_population]\n", " # 取前 lam 个适应值较优的个体下标,再按升序排列\n", " selected = np.argsort(fitness_values)[::-1][:lam][::-1]\n", "\n", " temp = 1.0 * (fitness_values[selected[lam-1]] - fitness_values[selected[0]]) / (lam - 1) \n", " \n", " real_fitness_values = [0] * lam\n", "\n", " for i, idx in enumerate(selected):\n", " if i == 0:\n", " real_fitness_values[i] = - (fitness_values[selected[lam-1]] - fitness_values[selected[i]]) / 2\n", " else:\n", " real_fitness_values[i] = real_fitness_values[0] + i * temp\n", " \n", " return real_fitness_values, selected" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "([-0.21914911424187478,\n", " -0.1793038207433521,\n", " -0.1394585272448294,\n", " -0.09961323374630672,\n", " -0.059767940247784035,\n", " -0.019922646749261363,\n", " 0.019922646749261336,\n", " 0.059767940247784035,\n", " 0.0996132337463067,\n", " 0.13945852724482938,\n", " 0.17930382074335205,\n", " 0.21914911424187478],\n", " array([10, 3, 6, 2, 0, 9, 8, 4, 11, 5, 7, 1], dtype=int64))" ] }, "metadata": {}, "execution_count": 16 } ], "source": [ "# 非关键代码,用于测试 real_objective\n", "real_objectives(lam=12, population=np.random.normal(0, 1, (12, 10)), data_X=data[:,:10], data_y=data[:,10])" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "def event_on_generation(gen, 
population, current_best, history_best):\n",
"    \"\"\"Print the best individual as a boolean mask and as selected indices.\n",
"\n",
"    On generation 0 `current_best` is reported, afterwards `history_best`.\n",
"    \"\"\"\n",
"    best = current_best if gen == 0 else history_best\n",
"    mask = encode(best)\n",
"    print('iter: {gen}, best solution(bool): {best}, best idxs: {idxs}'.format(gen=gen, best=mask, idxs='-'.join([str(i) for i, v in enumerate(mask) if v])))"
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [
"def NES(\n",
"        n_dimension: int,\n",
"        objective,\n",
"        n_iter,\n",
"        learn_rate,\n",
"        lam,\n",
"        mean,\n",
"        cov,\n",
"        random_state=None,\n",
"        event_on_genration=event_on_generation):\n",
"    \"\"\"Minimise `objective` with a Gaussian search distribution.\n",
"\n",
"    Each generation samples `lam` individuals from a multivariate normal\n",
"    whose mean and covariance are TensorFlow variables, scores them with\n",
"    `objective` (lower is better) and takes one gradient-descent step on\n",
"    -mean(log_prob(individual) * fitness) with fitness = -score.\n",
"\n",
"    Args:\n",
"        n_dimension: length of one individual.\n",
"        objective: callable mapping one individual to a score to minimise.\n",
"        n_iter: number of generations.\n",
"        learn_rate: learning rate of the gradient-descent optimizer.\n",
"        lam: number of individuals sampled per generation.\n",
"        mean, cov: initial mean vector and covariance matrix.\n",
"        random_state: seed handed to `tf.set_random_seed`.\n",
"        event_on_genration: callback invoked once per generation with\n",
"            (generation, population, best of this generation, best of the\n",
"            previous generations - None during the first one). The spelling\n",
"            of the parameter name is kept for keyword callers.\n",
"\n",
"    Returns:\n",
"        [best, best_fitness]: the lowest-scoring individual seen and its\n",
"        score; `best` stays None if no score was below 1e+10.\n",
"    \"\"\"\n",
"    logger.debug({\n",
"        'dimension': n_dimension,\n",
"        'n_iter': n_iter,\n",
"        'learn_rate': learn_rate,\n",
"    })\n",
"\n",
"    tf.set_random_seed(random_state)\n",
"    best, best_fitness = None, 1e+10\n",
"\n",
"    mean = tf.Variable(mean, dtype=tf.float32)\n",
"    cov = tf.Variable(cov, dtype=tf.float32)\n",
"    mvn = MultivariateNormalFullCovariance(loc=mean, covariance_matrix=cov)\n",
"    make_population = mvn.sample(lam)\n",
"\n",
"    fitness_input = tf.placeholder(tf.float32, [lam, ])\n",
"    prob_output = tf.placeholder(tf.float32, [lam, n_dimension])\n",
"    loss = -tf.reduce_mean(mvn.log_prob(prob_output) * fitness_input)\n",
"    train_op = tf.train.GradientDescentOptimizer(learn_rate).minimize(loss)\n",
"\n",
"    # The context manager closes the session even if the objective or the\n",
"    # callback raises; the previous code never closed it.\n",
"    with tf.Session() as sess:\n",
"        sess.run(tf.global_variables_initializer())\n",
"\n",
"        for g in range(n_iter):\n",
"            population = sess.run(make_population)\n",
"\n",
"            # Score every individual exactly once per generation: the same\n",
"            # scores feed the callback, the training step and the best-so-far\n",
"            # update (they used to be recomputed three times).\n",
"            scores = [objective(c) for c in population]\n",
"\n",
"            event_on_genration(g, population, population[np.argsort(scores)[0]], best)\n",
"\n",
"            # Fitness fed to the loss is the negated score.\n",
"            fitness_values = [-score for score in scores]\n",
"            sess.run(train_op, {fitness_input: fitness_values, prob_output: population})\n",
"\n",
"            for i, score in enumerate(scores):\n",
"                if score < best_fitness:\n",
"                    best, best_fitness = population[i], score\n",
"\n",
"    return [best, best_fitness]"
] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [
"n_dimension = n_col - 1\n",
"n_iter = 10\n",
"learn_rate = 0.002\n",
"mu = 4  # NOTE(review): not referenced by any other cell of this notebook\n",
"lam = 12\n",
"# NOTE(review): in the data overview column 0 runs 1.0, 2.0, ..., 214.0, so it\n",
"# looks like a sample Id rather than a measurement, yet it is part of the\n",
"# search space here. Consider restricting the features to the remaining columns.\n",
"mean = data[:,:n_dimension].mean(axis=0)\n",
"cov = 10000 * tf.eye(n_dimension)\n",
"random_state = 1  # fixed seed for reproducible output\n",
"\n",
"def es_objective(x):\n",
"    \"\"\"Score handed to NES: the negated `objective` of the mask encoded from `x`.\"\"\"\n",
"    return -objective(encode(x), data[:,:n_dimension], data[:,n_dimension])"
] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "2021-07-11 16:14:51.616 | DEBUG | __main__:NES:15 - {'dimension': 10, 'n_iter': 10, 'learn_rate': 0.002}\n", "iter: 0, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 1, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 2, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 3, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 4, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 5, best solution(bool): [True, True, False, True, True, True, False, False, False, True], best idxs: 0-1-3-4-5-9\n", "iter: 6, best solution(bool): [True, True, True, True, False, True, True, True, True, False], best idxs: 0-1-2-3-5-6-7-8\n", "iter: 7, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n", "iter: 8, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n", "iter: 9, best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n" ] } ], "source": [ "best, best_fitness = NES(n_dimension, es_objective, 
n_iter, learn_rate, lam, mean, cov, random_state, event_on_generation)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "best solution(bool): [True, True, False, False, False, True, True, True, False, False], best idxs: 0-1-5-6-7\n维度缩减率 DR: 0.5\n分类准确率 CA: 0.9532710280373832\n" ] } ], "source": [ "print('best solution(bool): {best}, best idxs: {idxs}'.format(best=encode(best), idxs='-'.join([str(i) for i, v in enumerate(encode(best)) if v])))\n", "print('维度缩减率 DR: {dr}'.format(dr=DR(encode(best))))\n", "print('分类准确率 CA: {ca}'.format(ca=CA(encode(best), data[:,:n_dimension], data[:,n_dimension])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 结论\n", "# 原先 10 个特征,经过基于 NES 的特征选择方法,特征数缩减至 5 个,同时准确率保持在 95% 以上\n", "# 注意:CA 在同一份数据上训练并评估,这里的 95% 是训练集准确率,并非留出集或交叉验证得到的泛化准确率" ] } ] }