{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "metadata": false,
     "name": "#%%\n"
    }
   },
   "source": [
    "# ID3 decision tree from scratch\n",
    "\n",
    "Builds an ID3 decision tree on a toy ten-row dataset that predicts whether a dish is eaten (`Eat`) from its `Taste`, `Temperature` and `Texture`, then uses the tree to classify one of the rows.\n",
    "\n",
    "This code is heavily derived from [Decision Tree from a Scratch](https://medium.com/@rakendd/decision-tree-from-scratch-9e23bcfb4928) (ID3 algorithm)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and toy dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "metadata": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Machine epsilon for float64, kept as a small guard value (e.g. against log2(0)).\n",
    "eps = np.finfo(float).eps\n",
    "\n",
    "# Ten labelled dishes; 'Eat' is the target column the tree learns to predict.\n",
    "dataset = {'Taste': ['Salty', 'Spicy', 'Spicy', 'Spicy', 'Spicy', 'Sweet', 'Salty', 'Sweet', 'Spicy', 'Salty'],\n",
    "           'Temperature': ['Hot', 'Hot', 'Hot', 'Cold', 'Hot', 'Cold', 'Cold', 'Hot', 'Cold', 'Hot'],\n",
    "           'Texture': ['Soft', 'Soft', 'Hard', 'Hard', 'Hard', 'Soft', 'Soft', 'Soft', 'Soft', 'Hard'],\n",
    "           'Eat': ['No', 'No', 'Yes', 'No', 'Yes', 'Yes', 'No', 'Yes', 'Yes', 'Yes']}\n",
    "\n",
    "df = pd.DataFrame(dataset, columns=['Taste', 'Temperature', 'Texture', 'Eat'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entropy\n",
    "\n",
    "The entropy of the target over a table $S$ is $H(S) = -\\sum_{c} p_c \\log_2 p_c$, where $p_c$ is the fraction of rows whose label is $c$. It is 0 for a pure table and 1 bit for an even two-class split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "metadata": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def find_entropy(df, target='Eat'):\n",
    "    \"\"\"Return the Shannon entropy (in bits) of the ``target`` column of ``df``.\n",
    "\n",
    "    Only classes that actually occur contribute, so a pure table (one class)\n",
    "    has entropy 0.0 and the labels need not be 'Yes'/'No'.\n",
    "    \"\"\"\n",
    "    probs = df[target].value_counts(normalize=True)  # class -> fraction of rows\n",
    "    # Each term -p*log2(p) is >= 0; max() also turns the -0.0 of a pure table into 0.0.\n",
    "    return max(0.0, float(-(probs * np.log2(probs)).sum()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exploring the data\n",
    "\n",
    "Quick sanity checks: counting labels and filtering rows with a boolean mask, the two operations the entropy helpers are built on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['Eat'].value_counts()['Yes']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    Salty\n",
       "6    Salty\n",
       "9    Salty\n",
       "Name: Taste, dtype: object"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.loc[df['Taste'] == 'Salty', 'Taste']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
"outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(df[df['Eat'] == 'Yes'])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "0.9709505944546686" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "find_entropy(df)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [], "source": [ "def find_entropy_attribute(df, attribute):\n", " entropy_sum = 0\n", " for variable in df[attribute].unique():\n", " entropy = 0\n", " for target_variable in ('Yes', 'No'):\n", " num = len(df[attribute][df[attribute] == variable][df['Eat'] == target_variable])\n", " den = len(df[attribute][df[attribute] == variable])\n", " fraction = num / (den + eps)\n", " entropy += -fraction * np.log2(fraction + eps)\n", " entropy_weights = den / len(df)\n", " entropy_sum += -entropy_weights * entropy\n", " return abs(entropy_sum)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [], "source": [ "def build_tree(df):\n", " IG = []\n", " for key in df.keys()[:-1]:\n", " IG.append(find_entropy(df) - find_entropy_attribute(df, key))\n", " node = df.keys()[:-1][np.argmax(IG)] # Get attribute with maximum information gain\n", "\n", " tree = {}\n", " tree[node] = {}\n", "\n", " for value in np.unique(df[node]):\n", " sub_table = df[df[node] == value].reset_index(drop=True)\n", " cls, counts = np.unique(sub_table['Eat'], return_counts=True)\n", "\n", " if len(counts) == 1: # Checking purity of subset\n", " tree[node][value] = cls[0]\n", " else:\n", " tree[node][value] = build_tree(sub_table) # Calling the function recursively\n", "\n", " return tree" 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'Taste': {'Salty': {'Texture': {'Hard': 'Yes', 'Soft': 'No'}},\n", " 'Spicy': {'Temperature': {'Cold': {'Texture': {'Hard': 'No',\n", " 'Soft': 'Yes'}},\n", " 'Hot': {'Texture': {'Hard': 'Yes',\n", " 'Soft': 'No'}}}},\n", " 'Sweet': 'Yes'}}\n" ] } ], "source": [ "tree = build_tree(df)\n", "\n", "import pprint\n", "pprint.pprint(tree)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [], "source": [ "def predict(inst, tree):\n", " for nodes in tree.keys():\n", " value = inst[nodes]\n", " tree = tree[nodes][value]\n", " prediction = 0\n", "\n", " if type(tree) is dict:\n", " prediction = predict(inst, tree)\n", " else:\n", " prediction = tree\n", " break\n", "\n", " return prediction" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "Taste Salty\n", "Temperature Cold\n", "Texture Soft\n", "Name: 6, dtype: object" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inst = df.iloc[6][:-1]\n", "inst" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "pycharm": {} }, "outputs": [ { "data": { "text/plain": [ "'No'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.iloc[6][-1]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "pycharm": { "is_executing": false, "metadata": false, "name": "#%%\n" } }, "outputs": [ { "data": { "text/plain": [ "'No'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prediction = predict(inst, tree)\n", "prediction" ] } ], "metadata": { 
"kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" }, "stem_cell": { "cell_type": "raw", "metadata": { "pycharm": { "metadata": false } }, "source": "" }, "pycharm": { "stem_cell": { "cell_type": "raw", "source": [], "metadata": { "collapsed": false } } } }, "nbformat": 4, "nbformat_minor": 1 }