{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "EE-311\n", "======\n", "\n", "Lab 7: Decision Trees, Bagging, Boosting\n", "----------------------------------------\n", "\n", "Created by Arnaud Pannatier and Francois Marelli on 01.04.2020\n", "\n", "# Homework\n", "\n", "The file `homework.py` contains this week's homework: empty functions that you must complete according to the instructions.\n", "\n", "Once the homework is complete, submit it on Moodle for grading.\n", "\n", "**Do not change the function definitions in the file!**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to train your Decision Tree\n", "\n", "We want to understand how a decision tree is trained with the CART algorithm.\n", "\n", "You will need to implement `get_separating_values`, `compute_probability_distribution` and `compute_gini` using the formulas from the course. Together, they provide all the information needed to choose the best split.\n", "\n", "We will illustrate this on a simple circle dataset, generated in the cell below."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:55.663477Z", "start_time": "2020-04-03T14:09:55.449580Z" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from sklearn.datasets import make_circles\n", "\n", "import importlib\n", "import homework\n", "importlib.reload(homework)\n", "\n", "X, y = make_circles(n_samples=20, noise=0.1, factor=0.3, random_state=0)\n", "\n", "plt.figure(figsize=(7, 7))\n", "plt.scatter(X[:, 0], X[:, 1], c=y)\n", "plt.axis('square')\n", "plt.xlabel(\"X1\")\n", "plt.ylabel(\"X2\")\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cell combines all your functions to perform the first step of the CART algorithm along the first dimension of our dataset.\n", "\n", "If everything is correct, you will see a vertical line marking the selected split." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:55.676330Z", "start_time": "2020-04-03T14:09:55.667859Z" } }, "outputs": [], "source": [ "X1 = X[:, 0]\n", "\n", "# Create a list of all candidate separating values\n", "s_x1 = homework.get_separating_values(X1)\n", "\n", "# Best split initialization\n", "min_score = np.inf\n", "split = None\n", "\n", "# Iterate over the candidate separating values\n", "for split_point in s_x1:\n", "    # Compute the empirical class distributions on each side of the split\n", "    pl0, pr0 = homework.compute_probability_distribution(X1, y, split_point, 0)\n", "    pl1, pr1 = homework.compute_probability_distribution(X1, y, split_point, 1)\n", "\n", "    # Compute the Gini impurity of the left and right regions\n", "    gini_l = homework.compute_gini(pl0, pl1)\n", "    gini_r = homework.compute_gini(pr0, pr1)\n", "\n", "    # Count the points on each side of the split\n", "    n_l = (X1 < split_point).sum()\n", "    n_r = (X1 > split_point).sum()\n", "\n", "    # Compute the score of the split\n", "    score = homework.compute_score(n_l, gini_l, n_r, gini_r)\n", "\n", "    # Update the best split if the score is lower\n", "    if score < min_score:\n", "        min_score = score\n", "        split = split_point\n", "\n", "plt.figure(figsize=(7, 7))\n", "plt.scatter(X[:, 0], X[:, 1], c=y)\n", "plt.axis('square')\n", "plt.vlines(split, -1.5, 1.5)\n", "plt.xlabel(\"X1\")\n", "plt.ylabel(\"X2\")\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test Cell" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2020-04-03T14:09:56.151806Z", "start_time": "2020-04-03T14:09:55.954749Z" } }, "outputs": [], "source": [ "import unittest\n", "import testing\n", "\n", "import importlib\n", "importlib.reload(testing)\n", "\n", "unittest.main(module=testing, argv=['first-arg-is-ignored'], exit=False)" ] } ], "metadata": { "interpreter": { "hash": "775b8dc5e77b6cd469b087e562ff893173a7e39eafdcd4ff38a1582b848ee085" }, "kernelspec": { "display_name": "Python 3.9.2 64-bit ('EE-311': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }