{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Tune Regression " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading required package: daltoolbox\n", "\n", "Registered S3 method overwritten by 'quantmod':\n", " method from\n", " as.zoo.data.frame zoo \n", "\n", "\n", "Attaching package: ‘daltoolbox’\n", "\n", "\n", "The following object is masked from ‘package:base’:\n", "\n", " transform\n", "\n", "\n" ] } ], "source": [ "# DAL ToolBox\n", "# version 1.0.767\n", "\n", "source(\"https://raw.githubusercontent.com/cefet-rj-dal/daltoolbox/main/jupyter.R\")\n", "\n", "#loading DAL\n", "load_library(\"daltoolbox\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset for regression analysis" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading required package: MASS\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " crim zn indus chas nox rm age \n", "[1,] \"numeric\" \"numeric\" \"numeric\" \"integer\" \"numeric\" \"numeric\" \"numeric\"\n", " dis rad tax ptratio black lstat medv \n", "[1,] \"numeric\" \"integer\" \"numeric\" \"numeric\" \"numeric\" \"numeric\" \"numeric\"\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\t\n", "\t\n", "\n", "\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\t\n", "\n", "
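<table>
<caption>A data.frame: 6 × 14</caption>
<thead>
<tr><th></th><th>crim</th><th>zn</th><th>indus</th><th>chas</th><th>nox</th><th>rm</th><th>age</th><th>dis</th><th>rad</th><th>tax</th><th>ptratio</th><th>black</th><th>lstat</th><th>medv</th></tr>
<tr><th></th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;int&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;int&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th><th>&lt;dbl&gt;</th></tr>
</thead>
<tbody>
<tr><th>1</th><td>0.00632</td><td>18</td><td>2.31</td><td>0</td><td>0.538</td><td>6.575</td><td>65.2</td><td>4.0900</td><td>1</td><td>296</td><td>15.3</td><td>396.90</td><td>4.98</td><td>24.0</td></tr>
<tr><th>2</th><td>0.02731</td><td>0</td><td>7.07</td><td>0</td><td>0.469</td><td>6.421</td><td>78.9</td><td>4.9671</td><td>2</td><td>242</td><td>17.8</td><td>396.90</td><td>9.14</td><td>21.6</td></tr>
<tr><th>3</th><td>0.02729</td><td>0</td><td>7.07</td><td>0</td><td>0.469</td><td>7.185</td><td>61.1</td><td>4.9671</td><td>2</td><td>242</td><td>17.8</td><td>392.83</td><td>4.03</td><td>34.7</td></tr>
<tr><th>4</th><td>0.03237</td><td>0</td><td>2.18</td><td>0</td><td>0.458</td><td>6.998</td><td>45.8</td><td>6.0622</td><td>3</td><td>222</td><td>18.7</td><td>394.63</td><td>2.94</td><td>33.4</td></tr>
<tr><th>5</th><td>0.06905</td><td>0</td><td>2.18</td><td>0</td><td>0.458</td><td>7.147</td><td>54.2</td><td>6.0622</td><td>3</td><td>222</td><td>18.7</td><td>396.90</td><td>5.33</td><td>36.2</td></tr>
<tr><th>6</th><td>0.02985</td><td>0</td><td>2.18</td><td>0</td><td>0.458</td><td>6.430</td><td>58.7</td><td>6.0622</td><td>3</td><td>222</td><td>18.7</td><td>394.12</td><td>5.21</td><td>28.7</td></tr>
</tbody>
</table>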
\n" ], "text/latex": [ "A data.frame: 6 × 14\n", "\\begin{tabular}{r|llllllllllllll}\n", " & crim & zn & indus & chas & nox & rm & age & dis & rad & tax & ptratio & black & lstat & medv\\\\\n", " & & & & & & & & & & & & & & \\\\\n", "\\hline\n", "\t1 & 0.00632 & 18 & 2.31 & 0 & 0.538 & 6.575 & 65.2 & 4.0900 & 1 & 296 & 15.3 & 396.90 & 4.98 & 24.0\\\\\n", "\t2 & 0.02731 & 0 & 7.07 & 0 & 0.469 & 6.421 & 78.9 & 4.9671 & 2 & 242 & 17.8 & 396.90 & 9.14 & 21.6\\\\\n", "\t3 & 0.02729 & 0 & 7.07 & 0 & 0.469 & 7.185 & 61.1 & 4.9671 & 2 & 242 & 17.8 & 392.83 & 4.03 & 34.7\\\\\n", "\t4 & 0.03237 & 0 & 2.18 & 0 & 0.458 & 6.998 & 45.8 & 6.0622 & 3 & 222 & 18.7 & 394.63 & 2.94 & 33.4\\\\\n", "\t5 & 0.06905 & 0 & 2.18 & 0 & 0.458 & 7.147 & 54.2 & 6.0622 & 3 & 222 & 18.7 & 396.90 & 5.33 & 36.2\\\\\n", "\t6 & 0.02985 & 0 & 2.18 & 0 & 0.458 & 6.430 & 58.7 & 6.0622 & 3 & 222 & 18.7 & 394.12 & 5.21 & 28.7\\\\\n", "\\end{tabular}\n" ], "text/markdown": [ "\n", "A data.frame: 6 × 14\n", "\n", "| | crim <dbl> | zn <dbl> | indus <dbl> | chas <int> | nox <dbl> | rm <dbl> | age <dbl> | dis <dbl> | rad <int> | tax <dbl> | ptratio <dbl> | black <dbl> | lstat <dbl> | medv <dbl> |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| 1 | 0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |\n", "| 2 | 0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |\n", "| 3 | 0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |\n", "| 4 | 0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |\n", "| 5 | 0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |\n", "| 6 | 0.02985 | 0 | 2.18 | 0 | 0.458 | 6.430 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |\n", "\n" ], "text/plain": [ " crim zn indus chas nox rm age dis rad tax ptratio black lstat\n", "1 0.00632 18 
2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98 \n", "2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14 \n", "3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03 \n", "4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94 \n", "5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33 \n", "6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21 \n", " medv\n", "1 24.0\n", "2 21.6\n", "3 34.7\n", "4 33.4\n", "5 36.2\n", "6 28.7" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "load_library(\"MASS\")\n", "data(Boston)\n", "print(t(sapply(Boston, class)))\n", "head(Boston)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# for better performance, the data frame can be converted to a matrix\n", "Boston <- as.matrix(Boston)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Building samples (training and testing)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# split the dataset into training and test sets by random sampling\n", "set.seed(1)\n", "sr <- sample_random()\n", "sr <- train_test(sr, Boston)\n", "boston_train <- sr$train\n", "boston_test <- sr$test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# wrap an SVM regressor for the target attribute medv in a tuner\n", "tune <- reg_tune(reg_svm(\"medv\"))\n", "# candidate hyperparameter values to search over\n", "ranges <- list(seq(0,1,0.2), cost=seq(20,100,20), kernel = c(\"radial\"))\n", "# fit the tuner on the training set using these ranges\n", "model <- fit(tune, boston_train, ranges)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model fit (evaluation on the training set)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " mse smape R2\n", "1 2.393491 0.05155025 0.9734081\n" ] } ], "source": [ "# predict on the training set and compare with the observed medv values\n", "train_prediction <- predict(model, boston_train)\n", "boston_train_predictand <- boston_train[,\"medv\"]\n", "train_eval <- evaluate(model, 
boston_train_predictand, train_prediction)\n", "print(train_eval$metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " mse smape R2\n", "1 13.61128 0.1297673 0.7738067\n" ] } ], "source": [ "# predict on the held-out test set and compare with the observed medv values\n", "test_prediction <- predict(model, boston_test)\n", "boston_test_predictand <- boston_test[,\"medv\"]\n", "test_eval <- evaluate(model, boston_test_predictand, test_prediction)\n", "print(test_eval$metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Options for other models" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# each assignment below is an alternative; keep the one matching the model passed to reg_tune\n", "\n", "# svm\n", "ranges <- list(seq(0,1,0.2), cost=seq(20,100,20), kernel = c(\"linear\", \"radial\", \"polynomial\", \"sigmoid\"))\n", "\n", "# knn\n", "ranges <- list(k=1:20)\n", "\n", "# mlp\n", "ranges <- list(size=1:10, decay=seq(0, 1, 0.1))\n", "\n", "# rf\n", "ranges <- list(mtry=1:10, ntree=1:10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "R", "language": "R", "name": "ir" }, "language_info": { "codemirror_mode": "r", "file_extension": ".r", "mimetype": "text/x-r-source", "name": "R", "pygments_lexer": "r", "version": "4.3.3" } }, "nbformat": 4, "nbformat_minor": 4 }