{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lecture 9\n",
"\n",
"## Growing Regression Trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"Here is the tree on the toy data we used in the slides."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"library(tree)\n",
"data <- data.frame(X1 = 1:4, X2 = c(1, 4, 3, 2), Y = c(1, 11, 7, 3))\n",
"toytree <- tree(Y ~ ., data, minsize = 1)\n",
"plot(toytree)\n",
"text(toytree)"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"We can also look at the splits, predicted response (column `yval`) and the loss\n",
"(column `dev`) for every node."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toytree$frame"
]
},
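{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check we can call `predict` on the training data: every\n",
"observation is predicted by the mean response (`yval`) of its leaf."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predict(toytree, data) # leaf means; compare with the yval column above"
]
},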
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"In the following cell we load the Hitters data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 0
},
"outputs": [],
"source": [
"library(ISLR)\n",
"Hitters <- na.omit(Hitters)\n",
"logSalary <- log(Hitters$Salary)\n",
"hist(Hitters$Salary, nclass = 20)\n",
"hist(logSalary, nclass = 20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"The histogram reveals that the salaries do not at all follow a Gaussian\n",
"distribution. With the log-transformation the data still does not really look\n",
"Gaussian, but at least it is slightly more symmetric than before.\n",
"\n",
"Let us fit a tree and plot the first 3 splits together with the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"hitters.tree <- tree(log(Salary) ~ Years + Hits, data = Hitters)\n",
"plot(hitters.tree)\n",
"text(hitters.tree)\n",
"plot(Hitters$Years, Hitters$Hits, ylab = \"Hits\", xlab = \"Years\",\n",
" col = hcl.colors(18, palette = \"RdYlBu\", rev = T)[10*(logSalary - 1.7)])\n",
"abline(v = 4.5, col = \"red\")\n",
"lines(c(4.5,30), c(117.5, 117.5), col = \"red\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pruning Regression Trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"We can simply use the function `prune.tree` with argument `best = 3` to find the\n",
"best tree with 3 leaf nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hitters.tree.pruned3 <- prune.tree(hitters.tree, best = 3)\n",
"plot(hitters.tree.pruned3)\n",
"text(hitters.tree.pruned3)"
]
},
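{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling `prune.tree` without the `best` argument returns the full\n",
"cost-complexity sequence: the subtree sizes, their deviances and the\n",
"corresponding values of the complexity parameter `k` (often denoted\n",
"$\\alpha$). The following sketch tabulates this sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pruning.sequence <- prune.tree(hitters.tree) # full cost-complexity path\n",
"cbind(size = pruning.sequence$size, dev = pruning.sequence$dev, k = pruning.sequence$k)"
]
},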
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"In the following we define some functions to fit the Hitters data with 9\n",
"predictors and to run 6-fold cross-validation. We use 6 folds because our\n",
"training data has size 132, which is a multiple of 6."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hitters.train <- function(train) {\n",
" formula <- log(Salary) ~ Years + RBI + PutOuts + Hits + Walks + Runs + AtBat + HmRun + Assists\n",
" tree(formula, Hitters, subset = train)\n",
"}\n",
"hitters.evaluate <- function(tree, set) {\n",
" sapply(2:10, function(i) mean((log(Hitters[set,'Salary']) - predict(prune.tree(tree, best = i), Hitters[set,]))^2)) # We compute the mean squared error for all trees with 2 to 10 leaf nodes.\n",
"}\n",
"hitters.cv <- function(train) {\n",
" res <- sapply(1:6, function(v) {\n",
" idx.test <- seq((v-1)*22 + 1, v*22) # fold index\n",
" this.fold.test <- train[idx.test] # validation\n",
" this.fold.train <- train[-idx.test] # training\n",
" tree <- hitters.train(this.fold.train)\n",
" hitters.evaluate(tree, this.fold.test)\n",
" })\n",
" rowMeans(data.frame(res))\n",
"}\n",
"hitters.train.and.evaluate <- function() {\n",
" train <- sample(nrow(Hitters), 132)\n",
" tree <- hitters.train(train)\n",
" list(train = hitters.evaluate(tree, train),\n",
" test = hitters.evaluate(tree, -train),\n",
" cv = hitters.cv(train),\n",
" tree = tree)\n",
"}\n",
"set.seed(1)\n",
"res <- replicate(100, hitters.train.and.evaluate()) # we run everything for 100 different training sets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function `hitters.train.and.evaluate` returns training and test errors, the\n",
"cross-validation estimate of the test error and the full tree itself.\n",
"We can plot individual trees. To look at other trees, change the `tree.index` to\n",
"another number between 1 and 100. Do you see how different the trees are,\n",
"depending on which training set they were fitted on?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tree.index <- 2\n",
"example_tree <- res[4, tree.index]$tree\n",
"plot(example_tree, col = 'darkgreen')\n",
"text(example_tree)"
]
},
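{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sketch for a single replicate, we can also read off the tree size\n",
"with the smallest cross-validation error (the entries of the `cv` vector\n",
"correspond to tree sizes 2 to 10)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cv.errors <- res[3, tree.index]$cv # CV curve of the chosen replicate\n",
"which.min(cv.errors) + 1 # + 1 because the first entry belongs to size 2"
]
},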
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"To plot the results including error bars we define the function `std.plot`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"std.plot <- function(data, x = 2:10, ...) {\n",
" df <- data.frame(data)\n",
" m <- rowMeans(df)\n",
" std <- sqrt(rowMeans((df - m)^2))\n",
" points(x, m, type = \"b\", ...)\n",
" arrows(x, m - std, x, m + std, length=0.05, angle=90, code=3, ...)\n",
"}\n",
"plot(c(), ylim = c(.1, .57), xlim = c(2, 10), xlab = \"Tree Size\", ylab = \"Mean Squared Error\")\n",
"std.plot(res[1,])\n",
"std.plot(res[2,], col = \"red\")\n",
"std.plot(res[3,], col = \"blue\")\n",
"legend(\"bottomleft\", c(\"train\", \"test\", \"CV\"), bty = 'n',\n",
" col = c(\"black\", \"red\", \"blue\"), lty = 1)"
]
},
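{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the `tree` package also provides `cv.tree`, which runs\n",
"cross-validation along the pruning sequence directly; here is a minimal sketch\n",
"with 6 folds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cv.res <- cv.tree(hitters.tree, K = 6) # 6-fold CV along the pruning sequence\n",
"plot(cv.res$size, cv.res$dev, type = \"b\",\n",
"     xlab = \"Tree Size\", ylab = \"Total Deviance\")"
]
},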
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"We can conclude from the plot above that the optimal tree size is around 3 or 4.\n",
"Let us plot this tree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final.tree <- prune.tree(hitters.train(1:nrow(Hitters)), best = 3)\n",
"plot(final.tree)\n",
"text(final.tree)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finding the optimal tree size with cross-validation is sometimes also called\n",
"hyper-parameter tuning. If you are interested in seeing how the modern\n",
"`tidymodels` library supports hyper-parameter tuning, you can follow this\n",
"[link](https://www.tidymodels.org/start/tuning/).\n",
"\n",
"## Classification Trees"
]
},
{
"cell_type": "markdown",
"metadata": {
"lines_to_next_cell": 0
},
"source": [
"In the following cell we fit a classification tree to the `Heart` data.\n",
"The `as.factor` function is used to tell R which columns contain\n",
"categorical data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Heart <- read.csv(\"http://faculty.marshall.usc.edu/gareth-james/ISL/Heart.csv\")[,-1]\n",
"Heart$AHD <- as.factor(Heart$AHD)\n",
"Heart$ChestPain <- as.factor(Heart$ChestPain)\n",
"Heart$Thal <- as.factor(Heart$Thal)\n",
"Heart$Sex <- as.factor(Heart$Sex)\n",
"heart.tree <- tree(AHD ~ ., Heart)\n",
"plot(heart.tree)\n",
"text(heart.tree)"
]
},
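{
"cell_type": "markdown",
"metadata": {},
"source": [
"To judge the fit we can look at the confusion matrix on the training data\n",
"(keeping in mind that the training error is optimistic). The helper variable\n",
"`HeartC` is introduced here only to drop rows with missing values before\n",
"predicting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"HeartC <- na.omit(Heart) # drop rows with missing values\n",
"heart.pred <- predict(heart.tree, HeartC, type = \"class\")\n",
"table(predicted = heart.pred, actual = HeartC$AHD) # training confusion matrix\n",
"mean(heart.pred != HeartC$AHD) # training misclassification rate"
]
},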
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once you are ready, please answer the [quiz questions](https://moodle.epfl.ch/mod/quiz/view.php?id=1112503).\n",
"\n",
"## Trees Versus Other Methods"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercises\n",
"### Conceptual\n",
"\n",
"**Q1.**\n",
"(a) Draw an example (of your own invention) of a partition of two-dimensional feature space that could result from recursive binary splitting. Your example should contain at least six regions. Draw a decision tree corresponding to this partition. Be sure to label all aspects of your figures, including the regions $R_1,R_2,...$, the cutpoints $t_1,t_2,...$, and so forth.\n",
"\n",
"(b) Draw an example (of your own invention) of a partition of two-dimensional\n",
"feature space that could not result from recursive binary splitting. Justify\n",
"why it cannot be the result of recursive binary splitting."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Q2.**\n",
"![](img/8.12.png)\n",
"\n",
"(a) Sketch the tree corresponding to the partition of the predictor space illustrated in the left-hand panel of the figure above. The numbers inside the boxes indicate the mean of $Y$ within each region.\n",
"\n",
"(b) Create a diagram similar to the left-hand panel of the figure above, using the tree illustrated in the right-hand panel of the same figure. You should divide up the predictor space into the correct regions, and indicate the mean for each region.\n",
"\n",
"### Applied"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Q3.** In this exercise we will look at the \"Carseats\" data set. You can get information about it by loading `library(ISLR)` and running `?Carseats`. We will seek to predict \"Sales\" using trees.\n",
"\n",
"(a) Split the data set into a training set and a test set.\n",
"\n",
"(b) Fit a regression tree to the training set. Plot the tree, and interpret the results. What test error rate do you obtain?\n",
"\n",
"(c) Instead of treating \"Sales\" as a quantitative variable we could recode it as\n",
"a qualitative variable by classifying the sales as \"low\" if they are below 5, \"medium\" if they are below 9 and \"high\" otherwise, i.e. we will introduce the new response variable `sales.class <- as.factor(ifelse(Carseats$Sales < 5, \"low\", ifelse(Carseats$Sales < 9, \"medium\", \"high\")))`. Fit a classification tree to the Carseats data with response `sales.class`. Do you get a similar tree as in (b)?\n",
"\n",
"(d) Use cross-validation in order to determine the optimal level of tree complexity for the tree in (c). Does pruning the tree improve the test error rate?\n",
"\n",
"**Q4.** (optional)\n",
"Fit a classification tree to the Histopathologic Cancer Detection data set that\n",
"we studied in the last exercise of sheet 7* - part 2.\n",
"To tell R that the 0s and 1s in PCaml_y should be treated as values of a\n",
"categorical response, you may use `PCaml_y <- as.factor(PCaml_y)`.\n",
"Compare your results to the ones obtained with linear regression and\n",
"convolutional networks."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "R",
"language": "R",
"name": "ir"
},
"language_info": {
"codemirror_mode": "r",
"file_extension": ".r",
"mimetype": "text/x-r-source",
"name": "R",
"pygments_lexer": "r",
"version": "4.0.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}