{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Nonlinear regression for inverse dynamics\n", "\n", "In this question, we fit a model that predicts the torques a robot must apply to make its arm reach a desired point in space. The data was collected from a SARCOS robot arm with $7$ degrees of freedom. The input vector $x \\in \\mathbb{R}^{21}$ encodes the desired position, velocity and acceleration of each of the $7$ joints ($3 \\times 7 = 21$ values). The output vector $y \\in \\mathbb{R}^7$ encodes the torques that must be applied to the joints to reach that point. The mapping from $x$ to $y$ is highly nonlinear.\n", "\n", "The data is available at [http://www.gaussianprocess.org/gpml/data/](http://www.gaussianprocess.org/gpml/data/). The first step is to read it." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 | ... | x19 | x20 | x21 | y1 | y2 | y3 | y4 | y5 | y6 | y7
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 0.019478 | -0.134218 | 0.027439 | 1.516401 | 0.300936 | 0.058259 | 0.150134 | -0.266791 | -0.237134 | -0.091272 | ... | 11.695956 | 1.210212 | -22.119289 | 50.292652 | -36.971897 | 20.937170 | 47.821712 | -0.424812 | -0.907553 | 8.090739
1 | 0.017279 | -0.137077 | 0.026999 | 1.532517 | 0.301344 | 0.058259 | 0.128653 | -0.153640 | -0.335279 | 0.006449 | ... | 14.643369 | 1.015070 | -17.048688 | 44.104164 | -28.851845 | 16.230194 | 43.194073 | -0.228739 | -1.235817 | 7.762475
2 | 0.016336 | -0.140878 | 0.027250 | 1.549670 | 0.302318 | 0.059027 | 0.104104 | -0.047313 | -0.418732 | 0.106274 | ... | 15.467628 | 0.910548 | -11.415526 | 37.354858 | -20.809343 | 12.379975 | 39.386017 | 0.244491 | -1.700880 | 7.289678
3 | 0.016273 | -0.145307 | 0.029072 | 1.566855 | 0.307628 | 0.059027 | 0.080321 | 0.053238 | -0.460963 | 0.188013 | ... | 10.309203 | 0.921360 | -5.772058 | 30.676065 | -13.963816 | 7.702940 | 36.478813 | -0.182062 | -2.143370 | 6.410800
4 | 0.017279 | -0.150051 | 0.031083 | 1.584416 | 0.314162 | 0.059027 | 0.058840 | 0.133810 | -0.462264 | 0.263975 | ... | 2.868096 | 1.059957 | -0.491542 | 25.920128 | -11.178479 | 5.643934 | 34.773911 | -1.031687 | -2.355776 | 5.792892

5 rows × 28 columns
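The preview has 28 columns: the 21 inputs `x1` to `x21` followed by the 7 torques `y1` to `y7`. Before fitting anything, the frame has to be split into an input matrix and a target matrix. The sketch below assumes exactly the column names shown in the preview, and uses a small random frame as a hypothetical stand-in for the real SARCOS data, since the loading code and file format are not shown here.

```python
import numpy as np
import pandas as pd

# Column names as they appear in the preview: 21 inputs, then 7 torques.
X_COLS = [f'x{i}' for i in range(1, 22)]
Y_COLS = [f'y{i}' for i in range(1, 8)]


def split_inputs_targets(df):
    # Return the inputs as an (n, 21) array and the torques as an (n, 7) array.
    X = df[X_COLS].to_numpy()
    Y = df[Y_COLS].to_numpy()
    return X, Y


# Hypothetical stand-in for the real data: 5 random rows with the same 28 columns.
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(5, 28)), columns=X_COLS + Y_COLS)

X, Y = split_inputs_targets(demo)
print(X.shape, Y.shape)  # (5, 21) (5, 7)
```

Selecting columns by name rather than by position (for example `df.iloc[:, :21]`) keeps the split correct even if the columns are later reordered.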

 | bandwidth | mean_test_score | mean_train_score | std_test_score | std_train_score
---|---|---|---|---|---
0 | 1 | -267.152902 | -265.932425 | 15.279617 | 3.813405
1 | 2 | -188.522188 | -187.648077 | 11.168259 | 2.790824
2 | 4 | -122.894823 | -122.112400 | 6.091327 | 1.522200
3 | 8 | -72.431290 | -71.858192 | 2.480018 | 0.620141
4 | 16 | -40.163809 | -39.825444 | 0.767083 | 0.193418
5 | 32 | -25.264711 | -25.057333 | 0.227940 | 0.056388
6 | 64 | -20.167603 | -20.004292 | 0.097739 | 0.025807
7 | 128 | -18.716779 | -18.568064 | 0.192818 | 0.049968
8 | 256 | -18.422612 | -18.277355 | 0.228230 | 0.058791
9 | 512 | -18.432984 | -18.287904 | 0.238813 | 0.061415
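The mean test score improves steadily as the bandwidth doubles and peaks at 256; 512 is marginally worse. The scores are negative, which fits scikit-learn's convention of maximising a negated error, but the scoring function itself is not shown, so treat that reading as an assumption. A minimal pandas sketch that reads the best bandwidth off the table, with the values copied from it:

```python
import pandas as pd

# Values copied from the cross-validation table above; a higher (less negative) score is better.
cv = pd.DataFrame({
    'bandwidth': [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
    'mean_test_score': [-267.152902, -188.522188, -122.894823, -72.431290,
                        -40.163809, -25.264711, -20.167603, -18.716779,
                        -18.422612, -18.432984],
    'std_test_score': [15.279617, 11.168259, 6.091327, 2.480018, 0.767083,
                       0.227940, 0.097739, 0.192818, 0.228230, 0.238813],
})

# Locate the row with the highest mean test score.
best = cv.loc[cv['mean_test_score'].idxmax()]
best_bandwidth = int(best['bandwidth'])
print(best_bandwidth)  # 256
```

The gap between 256 and 512 is about 0.01, far smaller than either setting's `std_test_score` of roughly 0.23, so the cross-validation does not clearly separate the two; either bandwidth is likely to perform similarly.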