diff --git a/src/lecture9.ipynb b/src/lecture9.ipynb
new file mode 100644
index 0000000..623ef96
--- /dev/null
+++ b/src/lecture9.ipynb
@@ -0,0 +1,410 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Lecture 9\n",
+ "\n",
+ "## Growing Regression Trees"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "IRdisplay::display_html('')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "Here is the tree on the toy data we used in the slides."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "library(tree)\n",
+ "data <- data.frame(X1 = 1:4, X2 = c(1, 4, 3, 2), Y = c(1, 11, 7, 3))\n",
+ "toytree <- tree(Y ~ ., data, minsize = 1)\n",
+ "plot(toytree)\n",
+ "text(toytree)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "We can also look at the splits, predicted response (column `yval`) and the loss\n",
+ "(column `dev`) for every node."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "toytree$frame"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "In the following cell we load the Hitters data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "outputs": [],
+ "source": [
+ "library(ISLR)\n",
+ "Hitters <- na.omit(Hitters)\n",
+ "logSalary = log(Hitters$Salary)\n",
+ "hist(Hitters$Salary, nclass = 20)\n",
+ "hist(log(Hitters$Salary), nclass = 20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "The histogram reveals that the salaries do not at all follow a Gaussian\n",
+ "distribution. With the log-transformation the data does not really look\n",
+ "Gaussian either, but at least is is slightly more symmetric than without.\n",
+ "\n",
+ "Let us fit a tree and plot the first 3 splits together with the data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "lines_to_next_cell": 2
+ },
+ "outputs": [],
+ "source": [
+ "hitters.tree <- tree(log(Salary) ~ Years + Hits, data = Hitters)\n",
+ "plot(hitters.tree)\n",
+ "text(hitters.tree)\n",
+ "plot(Hitters$Years, Hitters$Hits, ylab = \"Hits\", xlab = \"Years\",\n",
+ " col = hcl.colors(18, palette = \"RdYlBu\", rev = T)[10*(logSalary - 1.7)])\n",
+ "abline(v = 4.5, col = \"red\")\n",
+ "lines(c(4.5,30), c(117.5, 117.5), col = \"red\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Pruning Regression Tres"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "IRdisplay::display_html('')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "We can simply use the function `prune.tree` with argument `best = 3` to find the\n",
+ "best tree with 3 leaf nodes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hitters.tree.pruned3 <- prune.tree(hitters.tree, best = 3)\n",
+ "plot(hitters.tree.pruned3)\n",
+ "text(hitters.tree.pruned3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "In the following we define some functions to fit the Hitters data with 9\n",
+ "predictors and run 6-fold cross-validation. We run 6-fold cross-validation,\n",
+ "because our training data has size 132, which is a multiple of 6."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hitters.train <- function(train) {\n",
+ " formula <- log(Salary) ~ Years + RBI + PutOuts + Hits + Walks + Runs + AtBat + HmRun + Assists\n",
+ " tree(formula, Hitters, subset = train)\n",
+ "}\n",
+ "hitters.evaluate <- function(tree, set) {\n",
+ " sapply(2:10, function(i) mean((log(Hitters[set,'Salary']) - predict(prune.tree(tree, best = i), Hitters[set,]))^2)) # We compute the mean squared error for all trees with 2 to 10 leaf nodes.\n",
+ "}\n",
+ "hitters.cv <- function(train) {\n",
+ " res <- sapply(1:6, function(v) {\n",
+ " idx.test <- seq((v-1)*22 + 1, v*22) # fold index\n",
+ " this.fold.test <- train[idx.test] # validation\n",
+ " this.fold.train <- train[-idx.test] # training\n",
+ " tree <- hitters.train(this.fold.train)\n",
+ " hitters.evaluate(tree, this.fold.test)\n",
+ " })\n",
+ " rowMeans(data.frame(res))\n",
+ "}\n",
+ "hitters.train.and.evaluate <- function() {\n",
+ " train <- sample(nrow(Hitters), 132)\n",
+ " tree <- hitters.train(train)\n",
+ " list(train = hitters.evaluate(tree, train),\n",
+ " test = hitters.evaluate(tree, -train),\n",
+ " cv = hitters.cv(train),\n",
+ " tree = tree)\n",
+ "}\n",
+ "set.seed(1)\n",
+ "res <- replicate(100, hitters.train.and.evaluate()) # we run everything for 100 different training sets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The function `hitters.train.and.evaluate` returns training and test errors, the\n",
+ "cross-validation estimate of the test error and the full tree itself.\n",
+ "We can plot individual trees. To look at other trees, change the `tree.index` to\n",
+ "another number between 1 and 100. Do you see how different the trees are,\n",
+ "depending on which training set they were fitted on?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tree.index <- 2\n",
+ "example_tree <- res[4, tree.index]$tree\n",
+ "plot(example_tree, col = 'darkgreen')\n",
+ "text(example_tree)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "To plot the results including error bars we define the function `std.plot`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "std.plot <- function(data, x = 2:10, ...) {\n",
+ " df <- data.frame(data)\n",
+ " m <- rowMeans(df)\n",
+ " std <- sqrt(rowMeans((df - m)^2))\n",
+ " points(x, m, type = \"b\", ...)\n",
+ " arrows(x, m - std, x, m + std, length=0.05, angle=90, code=3, ...)\n",
+ "}\n",
+ "plot(c(), ylim = c(.1, .57), xlim = c(2, 10), xlab = \"Tree Size\", ylab = \"Mean Squared Error\")\n",
+ "std.plot(res[1,])\n",
+ "std.plot(res[2,], col = \"red\")\n",
+ "std.plot(res[3,], col = \"blue\")\n",
+ "legend(\"bottomleft\", c(\"train\", \"test\", \"CV\"), bty = 'n',\n",
+ " col = c(\"black\", \"red\", \"blue\"), lty = 1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "We can conclude from the plot above that the optimal tree size is around 3 or 4.\n",
+ "Let us plot this tree."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "final.tree <- prune.tree(hitters.train(1:nrow(Hitters)), best = 3)\n",
+ "plot(final.tree)\n",
+ "text(final.tree)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finding the optimal tree size with cross-validation is sometimes also called\n",
+ "hyper-parameter tuning. If you are interested to see how the modern `tidymodels`\n",
+ "library allows hyper-parameter tuning you can follow this\n",
+ "[link](https://www.tidymodels.org/start/tuning/).\n",
+ "\n",
+ "## Classification Trees"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "IRdisplay::display_html('')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "lines_to_next_cell": 0
+ },
+ "source": [
+ "In the following cell we fit a classification tree to the `Heart` data.\n",
+ "The `as.factor` function is used to tell R which columns contain\n",
+ "categorical data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Heart <-read.csv(\"http://faculty.marshall.usc.edu/gareth-james/ISL/Heart.csv\")[,-1]\n",
+ "Heart$AHD <- as.factor(Heart$AHD)\n",
+ "Heart$ChestPain <- as.factor(Heart$ChestPain)\n",
+ "Heart$Thal <- as.factor(Heart$Thal)\n",
+ "Heart$Sex <- as.factor(Heart$Sex)\n",
+ "heart.tree <- tree(AHD ~ ., Heart)\n",
+ "plot(heart.tree)\n",
+ "text(heart.tree)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Once you are ready, please answer the [quiz questions](https://moodle.epfl.ch/mod/quiz/view.php?id=1112503).\n",
+ "\n",
+ "## Trees Versus Other Methods"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "IRdisplay::display_html('')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Exercises\n",
+ "### Conceptual\n",
+ "\n",
+ "**Q1.**\n",
+ "(a) Draw an example (of your own invention) of a partition of two-dimensional feature space that could result from recursive binary splitting. Your example should contain at least six regions. Draw a decision tree corresponding to this partition. Be sure to label all aspects of your figures, including the regions $R_1,R_2,...$, the cutpoints $t_1,t_2,...$, and so forth.\n",
+ "\n",
+ "(b) Draw an example (of your own invention) of a partition of two-dimensional\n",
+ "feature space that could not result from recursive binary splitting. Justify,\n",
+ "why it cannot be the result from recursive binary splitting."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Q2.**\n",
+ "![](img/8.12.png)\n",
+ "\n",
+ "(a) Sketch the tree corresponding to the partition of the predictor space illustrated in the left-hand panel of the figure above. The numbers inside the boxes indicate the mean of $Y$ within each region.\n",
+ "\n",
+ "(b) Create a diagram similar to the left-hand panel of the figure above, using the tree illustrated in the right-hand panel of the same figure. You should divide up the predictor space into the correct regions, and indicate the mean for each region.\n",
+ "\n",
+ "### Applied"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Q3.** In this exercise we will look at the \"Carseats\" data set. You can get information about it by loading `library(ISLR)` and running `?Carseats`. We will seek to predict \"Sales\" using trees.\n",
+ "\n",
+ "(a) Split the data set into a training set and a test set.\n",
+ "\n",
+ "(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test error rate do you obtain?\n",
+ "\n",
+ "(c) Instead of treating \"Sales\" as a quantitative variable we could recode it as\n",
+ "a qualitative variable by classifying the sales as \"low\" if they are below 5, \"medium\" if they are below 9 and \"high\" otherwise, i.e. we will introduce the new response variable `sales.class <- as.factor(ifelse(Carseats$Sales < 5, \"low\", ifelse(Carseats$Sales < 9, \"medium\", \"high\")))`. Fit a classification tree to the Carseats data with response `sales.class`. Do you get a similar tree as in (b)?\n",
+ "\n",
+ "(d) Use cross-validation in order to determine the optimal level of tree complexity for the tree in (c). Does pruning the tree improve the test error rate?\n",
+ "\n",
+ "**Q4.** (optional)\n",
+ "Fit a classification tree to the Histopathalogic Cancer Detection data set that\n",
+ "we studied in the last exercise of sheet 7* - part 2.\n",
+ "To tell R that the 0s and 1s in PCaml_y should be treated as values of a\n",
+ "categorical response, you may use `PCaml_y <- as.factor(PCaml_y)`.\n",
+ "Compare your results to the ones obtained with linear regression and\n",
+ "convolutional networks."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "R",
+ "language": "R",
+ "name": "ir"
+ },
+ "language_info": {
+ "codemirror_mode": "r",
+ "file_extension": ".r",
+ "mimetype": "text/x-r-source",
+ "name": "R",
+ "pygments_lexer": "r",
+ "version": "4.0.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}