{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "EE-311\n", "======\n", "\n", "Lab 7: Decision Trees, Bagging, Boosting\n", "----------------------------------------\n", "\n", "Created by Arnaud Pannatier and Francois Marelli on 01.04.2020\n", "\n", "# Homework\n", "\n", "The file `homework.py` contains this week's homework: a set of empty functions that you must complete according to their instructions.\n", "\n", "Once the homework is complete, submit it on Moodle for grading.\n", "\n", "**Do not change the function definitions in the file!**" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## How to train your Decision Tree\n", "\n", "We want to understand how a decision tree is trained with the CART (Classification And Regression Trees) algorithm.\n", "\n", "Using the formulas from the course, implement `get_separating_values`, `compute_probability_distribution` and `compute_gini`. Together, these functions provide everything needed to choose the best split.\n", "\n", "We illustrate this on a small dataset of 20 points lying on two concentric circles, generated in the cell below." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:55.663477Z", "start_time": "2020-04-03T14:09:55.449580Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from sklearn.datasets import make_circles\n", "\n", "import importlib\n", "import homework\n", "importlib.reload(homework)\n", "\n", "X, y = make_circles(n_samples=20, noise=0.1, factor=0.3, random_state=0)\n", "\n", "plt.figure(figsize=(7, 7))\n", "plt.scatter(X[:, 0], X[:, 1], c=y)\n", "plt.axis('square')\n", "plt.xlabel(\"X1\")\n", "plt.ylabel(\"X2\")\n", "plt.grid()\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The following cell combines your functions to perform the first step of the CART algorithm on the first feature (`X1`) of the dataset: it scores every candidate split point and keeps the one with the lowest score.\n", "\n", "If everything is correct, the plot shows a vertical line at the selected split." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:55.676330Z", "start_time": "2020-04-03T14:09:55.667859Z" } }, "outputs": [], "source": [ "X1 = X[:, 0]\n", "\n", "# Create a list of all possible separating values\n", "s_x1 = homework.get_separating_values(X1)\n", "\n", "# Best split initialization\n", "min_score = np.inf\n", "split = None\n", "\n", "# Iterate over the separating values\n", "for split_point in s_x1:\n", "    # Compute empirical distributions\n", "    pl0, pr0 = homework.compute_probability_distribution(X1, y, split_point, 0)\n", "    pl1, pr1 = homework.compute_probability_distribution(X1, y, split_point, 1)\n", "\n", "    # Compute Gini impurity\n", "    gini_l = homework.compute_gini(pl0, pl1)\n", "    gini_r = homework.compute_gini(pr0, pr1)\n", "\n", "    # Count points in each split\n", "    n_l = (X1 < split_point).sum()\n", "    n_r = (X1 > split_point).sum()\n", "\n", "    # Compute the score of the split\n", "    score = homework.compute_score(n_l, gini_l, n_r, gini_r)\n", "\n", "    # Keep this split if its score is the lowest so far\n", "    if score < min_score:\n", "        min_score = score\n", "        split = split_point\n", "\n", "plt.figure(figsize=(7, 7))\n", "plt.scatter(X[:, 0], X[:, 1], c=y)\n", "plt.axis('square')\n", "plt.vlines(split, -1.5, 1.5)\n", "plt.xlabel(\"X1\")\n", "plt.ylabel(\"X2\")\n", "plt.grid()\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Test Cell" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:56.151806Z", "start_time": "2020-04-03T14:09:55.954749Z" } }, "outputs": [], "source": [ "import unittest\n", "import testing\n", "\n", "import importlib\n", "importlib.reload(testing)\n", "\n", "unittest.main(module=testing, argv=['first-arg-is-ignored'], exit=False)" ] } ],
"metadata": { "interpreter": { "hash": "775b8dc5e77b6cd469b087e562ff893173a7e39eafdcd4ff38a1582b848ee085" }, "kernelspec": { "display_name": "Python 3.9.2 64-bit ('EE-311': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }