{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lecture 9\n", "\n", "## Growing Regression Trees" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IRdisplay::display_html('')" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "Here is the tree fitted to the toy data we used in the slides." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "library(tree)\n", "data <- data.frame(X1 = 1:4, X2 = c(1, 4, 3, 2), Y = c(1, 11, 7, 3))\n", "toytree <- tree(Y ~ ., data, minsize = 1)\n", "plot(toytree)\n", "text(toytree)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "We can also look at the splits, predicted response (column `yval`) and the loss\n", "(column `dev`) for every node." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "toytree$frame" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "In the following cell we load the Hitters data and plot histograms of the\n", "salaries and of the log-salaries." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "library(ISLR)\n", "Hitters <- na.omit(Hitters)\n", "logSalary <- log(Hitters$Salary)\n", "hist(Hitters$Salary, nclass = 20)\n", "hist(logSalary, nclass = 20)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "The histograms reveal that the salaries do not at all follow a Gaussian\n", "distribution. With the log-transformation the data does not really look\n", "Gaussian either, but at least it is slightly more symmetric than without.\n", "\n", "Let us fit a tree and plot the first 2 splits (which give 3 regions) together with the data." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "hitters.tree <- tree(log(Salary) ~ Years + Hits, data = Hitters)\n", "plot(hitters.tree)\n", "text(hitters.tree)\n", "plot(Hitters$Years, Hitters$Hits, ylab = \"Hits\", xlab = \"Years\",\n", " col = hcl.colors(18, palette = \"RdYlBu\", rev = TRUE)[10*(logSalary - 1.7)])\n", "abline(v = 4.5, col = \"red\")\n", "lines(c(4.5,30), c(117.5, 117.5), col = \"red\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pruning Regression Trees" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IRdisplay::display_html('')" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "We can simply use the function `prune.tree` with argument `best = 3` to find the\n", "best tree with 3 leaf nodes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hitters.tree.pruned3 <- prune.tree(hitters.tree, best = 3)\n", "plot(hitters.tree.pruned3)\n", "text(hitters.tree.pruned3)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "In the following we define some functions to fit the Hitters data with 9\n", "predictors and run 6-fold cross-validation. We use 6 folds because our\n", "training set has size 132, which is a multiple of 6, so every fold contains\n", "exactly 22 observations." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hitters.train <- function(train) {\n", " formula <- log(Salary) ~ Years + RBI + PutOuts + Hits + Walks + Runs + AtBat + HmRun + Assists\n", " tree(formula, Hitters, subset = train)\n", "}\n", "hitters.evaluate <- function(tree, set) {\n", " sapply(2:10, function(i) mean((log(Hitters[set,'Salary']) - predict(prune.tree(tree, best = i), Hitters[set,]))^2)) # We compute the mean squared error for all trees with 2 to 10 leaf nodes.\n", "}\n", "hitters.cv <- function(train) {\n", " res <- sapply(1:6, function(v) {\n", " idx.test <- seq((v-1)*22 + 1, v*22) # fold index\n", " this.fold.test <- train[idx.test] # validation\n", " this.fold.train <- train[-idx.test] # training\n", " tree <- hitters.train(this.fold.train)\n", " hitters.evaluate(tree, this.fold.test)\n", " })\n", " rowMeans(data.frame(res))\n", "}\n", "hitters.train.and.evaluate <- function() {\n", " train <- sample(nrow(Hitters), 132)\n", " tree <- hitters.train(train)\n", " list(train = hitters.evaluate(tree, train),\n", " test = hitters.evaluate(tree, -train),\n", " cv = hitters.cv(train),\n", " tree = tree)\n", "}\n", "set.seed(1)\n", "res <- replicate(100, hitters.train.and.evaluate()) # we run everything for 100 different training sets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `hitters.train.and.evaluate` returns training and test errors, the\n", "cross-validation estimate of the test error and the full tree itself.\n", "We can plot individual trees. To look at other trees, change the `tree.index` to\n", "another number between 1 and 100. Do you see how different the trees are,\n", "depending on which training set they were fitted on?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tree.index <- 2\n", "example_tree <- res[4, tree.index]$tree\n", "plot(example_tree, col = 'darkgreen')\n", "text(example_tree)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "To plot the results including error bars we define the function `std.plot`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "std.plot <- function(data, x = 2:10, ...) {\n", " df <- data.frame(data)\n", " m <- rowMeans(df)\n", " std <- sqrt(rowMeans((df - m)^2))\n", " points(x, m, type = \"b\", ...)\n", " arrows(x, m - std, x, m + std, length=0.05, angle=90, code=3, ...)\n", "}\n", "plot(c(), ylim = c(.1, .57), xlim = c(2, 10), xlab = \"Tree Size\", ylab = \"Mean Squared Error\")\n", "std.plot(res[1,])\n", "std.plot(res[2,], col = \"red\")\n", "std.plot(res[3,], col = \"blue\")\n", "legend(\"bottomleft\", c(\"train\", \"test\", \"CV\"), bty = 'n',\n", " col = c(\"black\", \"red\", \"blue\"), lty = 1)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "We can conclude from the plot above that the optimal tree size is around 3 or 4.\n", "Let us plot this tree." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "final.tree <- prune.tree(hitters.train(1:nrow(Hitters)), best = 3)\n", "plot(final.tree)\n", "text(final.tree)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finding the optimal tree size with cross-validation is sometimes also called\n", "hyper-parameter tuning. 
If you are interested in seeing how the modern `tidymodels`\n", "library handles hyper-parameter tuning, you can follow this\n", "[link](https://www.tidymodels.org/start/tuning/).\n", "\n", "## Classification Trees" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IRdisplay::display_html('')" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "In the following cell we fit a classification tree to the `Heart` data.\n", "The `as.factor` function is used to tell R which columns contain\n", "categorical data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Heart <- read.csv(\"http://faculty.marshall.usc.edu/gareth-james/ISL/Heart.csv\")[,-1]\n", "Heart$AHD <- as.factor(Heart$AHD)\n", "Heart$ChestPain <- as.factor(Heart$ChestPain)\n", "Heart$Thal <- as.factor(Heart$Thal)\n", "Heart$Sex <- as.factor(Heart$Sex)\n", "heart.tree <- tree(AHD ~ ., Heart)\n", "plot(heart.tree)\n", "text(heart.tree)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you are ready, please answer the [quiz questions](https://moodle.epfl.ch/mod/quiz/view.php?id=1112503).\n", "\n", "## Trees Versus Other Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IRdisplay::display_html('')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercises\n", "### Conceptual\n", "\n", "**Q1.**\n", "(a) Draw an example (of your own invention) of a partition of two-dimensional feature space that could result from recursive binary splitting. Your example should contain at least six regions. Draw a decision tree corresponding to this partition. 
Be sure to label all aspects of your figures, including the regions $R_1,R_2,...$, the cutpoints $t_1,t_2,...$, and so forth.\n", "\n", "(b) Draw an example (of your own invention) of a partition of two-dimensional\n", "feature space that could not result from recursive binary splitting. Justify\n", "why it cannot result from recursive binary splitting." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q2.**\n", "![](img/8.12.png)\n", "\n", "(a) Sketch the tree corresponding to the partition of the predictor space illustrated in the left-hand panel of the figure above. The numbers inside the boxes indicate the mean of $Y$ within each region.\n", "\n", "(b) Create a diagram similar to the left-hand panel of the figure above, using the tree illustrated in the right-hand panel of the same figure. You should divide up the predictor space into the correct regions, and indicate the mean for each region.\n", "\n", "### Applied" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q3.** In this exercise we will look at the \"Carseats\" data set. You can get information about it by loading `library(ISLR)` and running `?Carseats`. We will seek to predict \"Sales\" using trees.\n", "\n", "(a) Split the data set into a training set and a test set.\n", "\n", "(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test error (mean squared error) do you obtain?\n", "\n", "(c) Instead of treating \"Sales\" as a quantitative variable we could recode it as\n", "a qualitative variable by classifying the sales as \"low\" if they are below 5, \"medium\" if they are at least 5 but below 9, and \"high\" otherwise, i.e. we will introduce the new response variable `sales.class <- as.factor(ifelse(Carseats$Sales < 5, \"low\", ifelse(Carseats$Sales < 9, \"medium\", \"high\")))`. Fit a classification tree to the Carseats data with response `sales.class`. 
Do you get a tree similar to the one in (b)?\n", "\n", "(d) Use cross-validation to determine the optimal level of tree complexity for the tree in (c). Does pruning the tree improve the test error rate?\n", "\n", "**Q4.** (optional)\n", "Fit a classification tree to the Histopathologic Cancer Detection data set that\n", "we studied in the last exercise of sheet 7* - part 2.\n", "To tell R that the 0s and 1s in PCaml_y should be treated as values of a\n", "categorical response, you may use `PCaml_y <- as.factor(PCaml_y)`.\n", "Compare your results to the ones obtained with linear regression and\n", "convolutional networks." ] } ], "metadata": { "kernelspec": { "display_name": "R", "language": "R", "name": "ir" }, "language_info": { "codemirror_mode": "r", "file_extension": ".r", "mimetype": "text/x-r-source", "name": "R", "pygments_lexer": "r", "version": "4.0.2" } }, "nbformat": 4, "nbformat_minor": 4 }