{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Decision Trees from scratch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import csv\n", "from pathlib import Path\n", "from copy import deepcopy\n", "from typing import List, Tuple, Dict, NamedTuple, Any\n", "from collections import Counter, defaultdict" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Ensure that we have a `data` directory we use to store downloaded data\n", "!mkdir -p data\n", "data_dir: Path = Path('data')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2020-02-23 10:52:54-- https://raw.githubusercontent.com/husnainfareed/Simple-Naive-Bayes-Weather-Prediction/c75b2fa747956ee9b5f9da7b2fc2865be04c618c/new_dataset.csv\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.0.133, 151.101.64.133, ...\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 373 [text/plain]\n", "Saving to: ‘data/golf.csv’\n", "\n", "golf.csv 100%[===================>] 373 --.-KB/s in 0s \n", "\n", "2020-02-23 10:52:55 (11.5 MB/s) - ‘data/golf.csv’ saved [373/373]\n", "\n" ] } ], "source": [ "# Downloading the \"Golf\" data set\n", "!wget -O \"data/golf.csv\" -nc -P data https://raw.githubusercontent.com/husnainfareed/Simple-Naive-Bayes-Weather-Prediction/c75b2fa747956ee9b5f9da7b2fc2865be04c618c/new_dataset.csv" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Outlook,Temp,Humidity,Windy,Play\n", "Rainy,Hot,High,f,no\n", "Rainy,Hot,High,t,no\n", "Overcast,Hot,High,f,yes\n", "Sunny,Mild,High,f,yes\n" ] } ], "source": [ "!head -n 5 data/golf.csv" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Create the Python path pointing to the `golf.csv` file\n", "golf_data_path: Path = data_dir / 'golf.csv'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Evey entry in our data set is represented as a `DataPoint`\n", "class DataPoint(NamedTuple):\n", " outlook: str\n", " temp: str\n", " humidity: str\n", " windy: bool\n", " play: bool" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Open the file, iterate over every row, create a `DataPoint` and append it to a list\n", "data_points: List[DataPoint] = []\n", "\n", "with open(golf_data_path) as csv_file:\n", " reader = csv.reader(csv_file, delimiter=',')\n", " next(reader, None)\n", " for row in reader:\n", " outlook: str = row[0].lower()\n", " temp: str = row[1].lower()\n", " humidty: str = row[2].lower()\n", " windy: bool = True if row[3].lower() == 't' else False\n", " play: bool = True if row[4].lower() == 'yes' else False\n", " data_point: DataPoint = DataPoint(outlook, temp, humidty, windy, play)\n", " data_points.append(data_point)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[DataPoint(outlook='rainy', temp='hot', humidity='high', windy=False, play=False),\n", " DataPoint(outlook='rainy', temp='hot', humidity='high', windy=True, play=False),\n", " DataPoint(outlook='overcast', temp='hot', humidity='high', windy=False, play=True),\n", " DataPoint(outlook='sunny', temp='mild', humidity='high', windy=False, play=True),\n", " DataPoint(outlook='sunny', temp='cool', humidity='normal', windy=False, play=True)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_points[:5]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Calculate the Gini impurity for a list of values\n", "# See: https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity\n", "def gini(data: List[Any]) -> float:\n", " counter: Counter = Counter(data)\n", " classes: List[Any] = list(counter.keys())\n", " num_items: int = len(data)\n", " result: float = 0\n", " item: Any\n", " for item in classes:\n", " p_i: float = counter[item] / num_items\n", " result += p_i * (1 - p_i)\n", " return result\n", "\n", "assert gini(['one', 'one']) == 0\n", "assert gini(['one', 'two']) == 0.5\n", "assert gini(['one', 'two', 'one', 'two']) == 0.5\n", "assert 0.8 < gini(['one', 'two', 'three', 'four', 'five']) < 0.81" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Helper function to filter down a list of data points by a `feature` and its `value`\n", "def filter_by_feature(data_points: List[DataPoint], *args) -> List[DataPoint]:\n", " result: List[DataPoint] = deepcopy(data_points)\n", " for arg in args:\n", " feature: str = arg[0]\n", " value: Any = arg[1]\n", " result = [data_point for data_point in result if getattr(data_point, feature) == value]\n", " return result\n", "\n", "assert len(filter_by_feature(data_points, ('outlook', 'sunny'))) == 5\n", "assert len(filter_by_feature(data_points, ('outlook', 'sunny'), ('temp', 'mild'))) == 3\n", "assert len(filter_by_feature(data_points, ('outlook', 'sunny'), ('temp', 'mild'), ('humidity', 'high'))) == 2" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Helper function to extract the values the `feature` in question can assume\n", "def feature_values(data_points: List[DataPoint], feature: str) -> List[Any]:\n", " return list(set([getattr(dp, feature) for dp in data_points]))\n", "\n", "assert feature_values(data_points, 'outlook').sort() == ['sunny', 'overcast', 'rainy'].sort()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Calculate the weighted sum of the Gini impurities for the `feature` in question\n", "def gini_for_feature(data_points: List[DataPoint], feature: str, label: str = 'play') -> float:\n", " total: int = len(data_points)\n", " # Distinct values the `feature` in question can assume\n", " dist_values: List[Any] = feature_values(data_points, feature)\n", " # Calculate all the Gini impurities for every possible value a `feature` can assume\n", " ginis: Dict[str, float] = defaultdict(float)\n", " ratios: Dict[str, float] = defaultdict(float)\n", " for value in dist_values:\n", " filtered: List[DataPoint] = filter_by_feature(data_points, (feature, value))\n", " labels: List[Any] = [getattr(dp, label) for dp in filtered]\n", " ginis[value] = gini(labels)\n", " # We use the ratio when we compute the weighted sum later on\n", " ratios[value] = len(labels) / total\n", " # Calculate the weighted sum of the `feature` in question\n", " weighted_sum: float = sum([ratios[key] * value for key, value in ginis.items()])\n", " return weighted_sum\n", "\n", "assert 0.34 < gini_for_feature(data_points, 'outlook') < 0.35\n", "assert 0.44 < gini_for_feature(data_points, 'temp') < 0.45\n", "assert 0.36 < gini_for_feature(data_points, 'humidity') < 0.37\n", "assert 0.42 < gini_for_feature(data_points, 'windy') < 0.43" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# NOTE: We can't use type hinting here due to cyclic dependencies\n", "\n", "# A `Node` has a `value` and optional out `Edge`s\n", "class Node:\n", " def __init__(self, value):\n", " self._value = value\n", " self._edges = []\n", "\n", " def __repr__(self):\n", " if len(self._edges):\n", " return f'{self._value} --> {self._edges}'\n", " else:\n", " return f'{self._value}'\n", " \n", " @property\n", " def value(self):\n", " return self._value\n", "\n", " def add_edge(self, edge):\n", " self._edges.append(edge)\n", " \n", " def find_edge(self, value):\n", " return next(edge for edge in self._edges if edge.value == value)\n", "\n", "# An `Edge` has a value and points to a `Node`\n", "class Edge:\n", " def __init__(self, value):\n", " self._value = value\n", " self._node = None\n", "\n", " def __repr__(self):\n", " return f'{self._value} --> {self._node}'\n", " \n", " @property\n", " def value(self):\n", " return self._value\n", " \n", " @property\n", " def node(self):\n", " return self._node\n", " \n", " @node.setter\n", " def node(self, node):\n", " self._node = node" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Recursively build a tree via the CART algorithm based on our list of data points\n", "def build_tree(data_points: List[DataPoint], features: List[str], label: str = 'play') -> Node:\n", " # Ensure that the `features` list doesn't include the `label`\n", " features.remove(label) if label in features else None\n", "\n", " # Compute the weighted Gini impurity for each `feature` given that we'd split the tree at the `feature` in question\n", " weighted_sums: Dict[str, float] = defaultdict(float)\n", " for feature in features:\n", " weighted_sums[feature] = gini_for_feature(data_points, feature)\n", "\n", " # If all the weighted Gini impurities are 0.0 we create a final `Node` (leaf) with the given `label`\n", " weighted_sum_vals: List[float] = list(weighted_sums.values())\n", " if (float(0) in weighted_sum_vals and len(set(weighted_sum_vals)) == 1):\n", " label = getattr(data_points[0], 'play')\n", " return Node(label) \n", " \n", " # The `Node` with the most minimal weighted Gini impurity is the one we should use for splitting\n", " min_feature = min(weighted_sums, key=weighted_sums.get)\n", " node: Node = Node(min_feature)\n", " \n", " # Remove the `feature` we've processed from the list of `features` which still need to be processed\n", " reduced_features: List[str] = deepcopy(features)\n", " reduced_features.remove(min_feature)\n", "\n", " # Next up we build the `Edge`s which are the values our `min_feature` can assume\n", " for value in feature_values(data_points, min_feature):\n", " # Create a new `Edge` which contains a potential `value` of our `min_feature`\n", " edge: Edge = Edge(value)\n", " # Add the `Edge` to our `Node`\n", " node.add_edge(edge)\n", " # Filter down the data points we'll use next since we've just processed the set which includes our `min_feature`\n", " reduced_data_points: List[DataPoint] = filter_by_feature(data_points, (min_feature, value))\n", " # This `Edge` points to the new `Node` (subtree) we'll create through recursion\n", " edge.node = build_tree(reduced_data_points, reduced_features)\n", "\n", " # Return the `Node` (our `min_feature`)\n", " return node" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "outlook --> [overcast --> True, sunny --> windy --> [False --> True, True --> False], rainy --> humidity --> [normal --> True, high --> False]]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a new tree based on the loaded data points\n", "features: List[str] = list(DataPoint._fields)\n", "\n", "tree: Node = build_tree(data_points, features)\n", "tree" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Traverse the tree based on the query trying to find a leaf with the prediction\n", "def predict(tree: Node, query: List[Tuple[str, str]]) -> Any:\n", " node: Node = deepcopy(tree)\n", " for item in query:\n", " feature: str = item[0]\n", " value: Any = item[1]\n", " if node.value != feature:\n", " continue\n", " edge: Edge = node.find_edge(value)\n", " if not edge:\n", " raise Exception(f'Edge with value \"{value}\" not found on Node \"{node}\"')\n", " node: Node = edge.node\n", " return node\n", "\n", "assert predict(tree, [('outlook', 'overcast')]) != True\n", "assert predict(tree, [('outlook', 'sunny'), ('windy', False)]) != True\n", "assert predict(tree, [('outlook', 'sunny'), ('windy', True)]) != False\n", "assert predict(tree, [('outlook', 'rainy'), ('humidity', 'high')]) != False\n", "assert predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')]) != True\n", "assert predict(tree, [('outlook', 'rainy'), ('windy', True), ('humidity', 'normal')]) != True" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }