class Truck:
    """Kinematic car-and-trailer simulator ("truck backer-upper") with optional drawing.

    The state is advanced by :meth:`step` with an explicit Euler update of size ``dt``::

        x  += s * cos(θ0) * dt
        y  += s * sin(θ0) * dt
        θ0 += s / L * tan(ϕ) * dt
        θ1 += s / d * sin(θ0 - θ1) * dt

    ``(x, y)`` is the joint between car and trailer: the trailer is drawn from
    ``(x - d cos θ1, y - d sin θ1)`` up to it, and the car is attached in front of it.

    Attributes:
        W: car and trailer width (used for drawing only).
        L: car length.
        d: trailer length (``d_1`` in the equations).
        s: signed speed; the default ``-0.1`` makes the truck reverse.
        box: ``[x_min, x_max, y_min, y_max]`` of the allowed area.
        display: whether a matplotlib figure is created and updated.
        x, y, θ0, θ1, ϕ: joint position, car angle, trailer angle, steering angle (radians).
        patches: matplotlib artists currently drawn for the truck (empty when headless).

    The class relies on names provided by the notebook's ``from matplotlib.pylab import *``
    (``cos``, ``sin``, ``tan``, ``random``, ``figure``, ``matplotlib``) and on ``π``.
    """

    def __init__(self, display=False):
        """Create a truck in a random valid pose.

        Args:
            display: if true, open the figure and draw the truck after every reset.
        """
        self.W = 1  # car and trailer width, for drawing only
        self.L = 1 * self.W  # car length
        self.d = 4 * self.L  # d_1
        self.s = -0.1  # speed
        self.display = display

        self.box = [0, 40, -10, 10]
        # Always defined, so that clear() is safe on a headless truck as well.
        self.patches = list()

        if self.display:
            self.f = figure(figsize=(10, 5), num='The truck backer-upper', facecolor='none')
            self.ax = self.f.add_axes([0.01, 0.01, 0.98, 0.98], facecolor='black')

            self.ax.axis('equal')
            b = self.box
            self.ax.axis([b[0] - 1, b[1], b[2], b[3]])
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.axhline()
            self.ax.axvline()

        self.reset()

    def reset(self, ϕ=0):
        """Sample a new random pose until it is valid, then draw it once.

        Args:
            ϕ: steering angle stored with the new pose.
        """
        self.ϕ = ϕ  # car initial steering angle

        # Rejection sampling in a loop: a recursive retry would grow the call
        # stack and redraw the same pose once per discarded sample.
        while True:
            self.θ0 = random() * 2 * π  # 0 <= θ0 < 2π
            self.θ1 = (random() - 0.5) * π / 2 + self.θ0  # -π/4 <= θ1 - θ0 < π/4
            self.x = (random() * .75 + 0.25) * self.box[1]
            self.y = (random() - 0.5) * (self.box[3] - self.box[2])
            if self.valid():
                break

        # Draw, if display is True
        if self.display:
            self.draw()

    def step(self, ϕ=0, dt=1):
        """Advance the simulation by one Euler step.

        Args:
            ϕ: steering angle (radians) applied during this step.
            dt: integration step.

        Returns:
            The new state as returned by :meth:`state`, or ``None`` (after printing
            a message) when the truck is jackknifed or off screen and cannot move.
        """
        # Check for illegal conditions
        if self.is_jackknifed():
            print('The truck is jackknifed!')
            return None

        if self.is_offscreen():
            print('The car or trailer is off screen')
            return None

        self.ϕ = ϕ
        # Snapshot the pre-step values: every update below uses the old angles.
        L, d, s, θ0, θ1 = self.L, self.d, self.s, self.θ0, self.θ1

        # Perform state update
        self.x += s * cos(θ0) * dt
        self.y += s * sin(θ0) * dt
        self.θ0 += s / L * tan(ϕ) * dt
        self.θ1 += s / d * sin(θ0 - θ1) * dt

        return self.state()

    def state(self):
        """Return ``(x, y, θ0, x_trailer, y_trailer, θ1)``."""
        return (self.x, self.y, self.θ0, *self._trailer_xy(), self.θ1)

    def _get_attributes(self):
        """Return ``(x, y, W, L, d, s, θ0, θ1, ϕ)`` for convenient unpacking."""
        return (
            self.x, self.y, self.W, self.L, self.d, self.s,
            self.θ0, self.θ1, self.ϕ
        )

    # Original (misspelled) name, kept so existing callers keep working.
    _get_atributes = _get_attributes

    def _trailer_xy(self):
        """Return the coordinates of the trailer's far end."""
        return self.x - self.d * cos(self.θ1), self.y - self.d * sin(self.θ1)

    # Original (misspelled) name, kept so existing callers keep working.
    _traler_xy = _trailer_xy

    def is_jackknifed(self):
        """Return whether car and trailer directions differ by more than 90 degrees."""
        return abs(self.θ0 - self.θ1) * 180 / π > 90

    def is_offscreen(self):
        """Return whether the car front or the trailer end lies outside ``box``."""
        # Car front: half a length of tow bar plus one full car length ahead of (x, y).
        x1 = self.x + 1.5 * self.L * cos(self.θ0)
        y1 = self.y + 1.5 * self.L * sin(self.θ0)
        x2, y2 = self._trailer_xy()

        b = self.box
        return not (
            b[0] <= x1 <= b[1] and b[2] <= y1 <= b[3] and
            b[0] <= x2 <= b[1] and b[2] <= y2 <= b[3]
        )

    def valid(self):
        """Return whether the truck is neither jackknifed nor off screen."""
        return not self.is_jackknifed() and not self.is_offscreen()

    def draw(self):
        """Redraw car and trailer in the figure; no-op when ``display`` is false."""
        if not self.display:
            return
        if self.patches:
            self.clear()
        self._draw_car()
        self._draw_trailer()
        self.f.canvas.draw()

    def clear(self):
        """Remove every artist previously drawn for the truck."""
        for p in self.patches:
            p.remove()
        self.patches = list()

    def _draw_car(self):
        """Add tow bar, car body and the two steered wheels to the axes."""
        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_attributes()
        ax = self.ax

        # Tow bar from the joint (x, y) to the rear of the car body.
        x1, y1 = x + L / 2 * cos(θ0), y + L / 2 * sin(θ0)
        bar = Line2D((x, x1), (y, y1), lw=5, color='C2', alpha=0.8)
        ax.add_line(bar)

        # Axis-aligned rectangle rotated by θ0 around the middle of its rear edge.
        car = Rectangle(
            (x1, y1 - W / 2), L, W, color='C2', alpha=0.8, transform=
            matplotlib.transforms.Affine2D().rotate_deg_around(x1, y1, θ0 * 180 / π) +
            ax.transData
        )
        ax.add_patch(car)

        # Wheels are short segments oriented along θ0 + ϕ (heading plus steering).
        x2, y2 = x1 + L / 2 ** 0.5 * cos(θ0 + π / 4), y1 + L / 2 ** 0.5 * sin(θ0 + π / 4)
        left_wheel = Line2D(
            (x2 - L / 4 * cos(θ0 + ϕ), x2 + L / 4 * cos(θ0 + ϕ)),
            (y2 - L / 4 * sin(θ0 + ϕ), y2 + L / 4 * sin(θ0 + ϕ)),
            lw=3, color='C5', alpha=1)
        ax.add_line(left_wheel)

        x3, y3 = x1 + L / 2 ** 0.5 * cos(π / 4 - θ0), y1 - L / 2 ** 0.5 * sin(π / 4 - θ0)
        right_wheel = Line2D(
            (x3 - L / 4 * cos(θ0 + ϕ), x3 + L / 4 * cos(θ0 + ϕ)),
            (y3 - L / 4 * sin(θ0 + ϕ), y3 + L / 4 * sin(θ0 + ϕ)),
            lw=3, color='C5', alpha=1)
        ax.add_line(right_wheel)

        self.patches += [car, bar, left_wheel, right_wheel]

    def _draw_trailer(self):
        """Add the trailer rectangle to the axes."""
        W, d, θ1 = self.W, self.d, self.θ1
        ax = self.ax

        # Lower-left corner of the unrotated rectangle; rotation is around the
        # middle of its rear edge, i.e. the trailer end point.
        x, y = self.x - d * cos(θ1), self.y - d * sin(θ1) - W / 2
        trailer = Rectangle(
            (x, y), d, W, color='C0', alpha=0.8, transform=
            matplotlib.transforms.Affine2D().rotate_deg_around(x, y + W/2, θ1 * 180 / π) +
            ax.transData
        )
        ax.add_patch(trailer)

        self.patches += [trailer]
# %% Instantiate the truck together with its interactive figure
truck = Truck(display=True)

# %% Take a single step with a fixed steering angle and redraw
ϕ = deg2rad(-35)  # positive left, negative right
truck.step(ϕ)
truck.draw()

# %% Draw a new random initial pose
truck.reset()

# %% Imports for the emulator
import torch
import torch.nn as nn
from torch.optim import SGD
from tqdm import tqdm

# %% Build expert data set
# Each sample pairs (steering angle, current state) with the state after one step.

episodes = 10
inputs = list()
outputs = list()
# truck = Truck(); episodes = 10_000  # uncomment for creating the data set

for episode in tqdm(range(episodes)):

    truck.reset()

    # Drive with random steering until the truck jackknifes or leaves the box.
    while truck.valid():
        initial_state = truck.state()
        ϕ = (random() - 0.5) * π / 2
        inputs.append((ϕ, *initial_state))
        outputs.append(truck.step(ϕ))
        truck.draw()

# %%
len(inputs), len(outputs)

# %% Emulator: (steering, state) -> next state
state_size = 6
steering_size = 1
hidden_units_e = 45

emulator = nn.Sequential(
    nn.Linear(steering_size + state_size, hidden_units_e),
    nn.ReLU(),
    nn.Linear(hidden_units_e, state_size)
)

optimiser_e = SGD(emulator.parameters(), lr=0.005)
criterion = nn.MSELoss()

# %% Convert the collected samples to float32 tensors
tensor_inputs = torch.tensor(inputs, dtype=torch.float32)
tensor_outputs = torch.tensor(outputs, dtype=torch.float32)

# %% Standardise; outputs reuse the state statistics (input columns 1: are the state)
mean = tensor_inputs.mean(0)
std = tensor_inputs.std(0)
tensor_inputs = (tensor_inputs - mean) / std
tensor_outputs = (tensor_outputs - mean[1:]) / std[1:]

# %% Split the data into 80:20 for train:test.
train_size = int(len(tensor_inputs) * 0.8)
test_size = len(tensor_inputs) - train_size
print(len(tensor_inputs), train_size, test_size)

train_inputs = tensor_inputs[:train_size]
train_outputs = tensor_outputs[:train_size]
test_inputs = tensor_inputs[train_size:]
test_outputs = tensor_outputs[train_size:]
assert len(test_inputs) > 0, 'Test split is empty: collect more episodes'

# %%
len(train_inputs)

# %% Emulator training: one pass of per-sample SGD in random order
for cnt, i in enumerate(torch.randperm(len(train_inputs))):
    ϕ_state = train_inputs[i]
    next_state_prediction = emulator(ϕ_state)

    next_state = train_outputs[i]
    loss = criterion(next_state_prediction, next_state)

    optimiser_e.zero_grad()
    loss.backward()
    optimiser_e.step()

    if cnt == 0 or (cnt + 1) % 1000 == 0:
        print(f'{cnt + 1:4d} / {len(train_inputs)}, {loss.item():.10f}')

# %% Test
total_loss = 0
with torch.no_grad():
    for ϕ_state, next_state in zip(test_inputs, test_outputs):
        next_state_prediction = emulator(ϕ_state)
        total_loss += criterion(next_state_prediction, next_state).item()

# Average over the number of *test* samples (not the size of the training split).
ave_test_loss = total_loss / len(test_inputs)
print(f'Test loss: {ave_test_loss:.10f}')

# %%
# Here you need to insert the code for training the controller
# by using the emulator for backpropagation

# If you succeed, feel free to send a PR
"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }