{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "# Slithering Snake Example\n", "\n", "This Elastica tutorial explains how to set up a Cosserat rod simulation of a slithering snake. It is a more complex use case than the Timoshenko Beam example. If you have not done so, we strongly suggest you start with [this beam example](./1_Timoshenko_Beam.ipynb), as it covers many of the basics of setting up and running simulations with Elastica. \n", "\n", "This slithering snake example includes gravitational forces, friction forces, and internal muscle torques. It also introduces callback functions, which log simulation data for post-processing after the simulation is over. \n", "\n", "\n", "## Getting Started\n", "To set up the simulation, the first thing you need to do is import the necessary classes. As with the Timoshenko beam, we need to import modules that make it easier to construct different simulation systems. We also need to import a rod class, all the necessary forces to be applied, timestepping functions, and callback classes. " ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install \"pyelastica[examples,docs]\"\n", "!conda install -c conda-forge ffmpeg -y" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# import modules\n", "from elastica.modules import BaseSystemCollection, Constraints, Forcing, CallBacks, Damping\n", "\n", "# import rod class, damping and forces to be applied\n", "from elastica.rod.cosserat_rod import CosseratRod\n", "from elastica.dissipation import AnalyticalLinearDamper\n", "from elastica.external_forces import GravityForces, MuscleTorques\n", "from elastica.interaction import AnisotropicFrictionalPlane\n", "\n", "# import timestepping functions\n", "from elastica.timestepper.symplectic_steppers import PositionVerlet\n", "from elastica.timestepper import integrate\n", "\n", "# import call back functions\n", "from elastica.callback_functions import CallBackBaseClass\n", "from collections import defaultdict" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Initialize System and Add Rod\n", "The first thing to do is initialize the simulator class by combining all the imported modules. After initializing, we will generate a rod and add it to the simulation. " ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks, Damping):\n", "    pass\n", "\n", "\n", "snake_sim = SnakeSimulator()\n", "\n", "# Define rod parameters\n", "n_elem = 50\n", "start = np.array([0.0, 0.0, 0.0])\n", "direction = np.array([0.0, 0.0, 1.0])\n", "normal = np.array([0.0, 1.0, 0.0])\n", "base_length = 0.35\n", "base_radius = base_length * 0.011\n", "base_area = np.pi * base_radius ** 2\n", "density = 1000\n", "nu = 2e-3\n", "E = 1e6\n", "poisson_ratio = 0.5\n", "shear_modulus = E / (poisson_ratio + 1.0)\n", "\n", "# Create rod\n", "shearable_rod = CosseratRod.straight_rod(\n", "    n_elem,\n", "    start,\n", "    direction,\n", "    normal,\n", "    base_length,\n", "    base_radius,\n", "    density,\n", "    youngs_modulus=E,\n", "    shear_modulus=shear_modulus,\n", ")\n", "\n", "# Add rod to the snake system\n", "snake_sim.append(shearable_rod)\n" ] },
{ "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "## Adding Damping\n", "With the rod added to the simulator, we can add damping to the rod. We do this with the `.dampen()` method and the `AnalyticalLinearDamper`: the `snake_sim` simulator is told to `dampen` the `shearable_rod` object using the `AnalyticalLinearDamper` dissipation (damping) model.\n", "\n", "We also need to define the `damping_constant` and the simulation `time_step` and pass them to the `.using()` method." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "dt = 1e-4\n", "snake_sim.dampen(shearable_rod).using(\n", "    AnalyticalLinearDamper,\n", "    damping_constant=nu,\n", "    time_step=dt,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Add Forces to Rod\n", "With our rod added to the system, we need to specify the relevant forces that will be acting on the rod. For all the forces, the method of adding forces is `system_name.add_forcing_to(name_of_rod).using(type_of_force, **kwargs)` where `**kwargs` are the parameters specific to each type of force. \n", "\n", "### Gravity\n", "The first force to add is gravity. We specify the strength of gravity and also the direction it is pointing. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Add gravitational forces\n", "gravitational_acc = -9.80665\n", "snake_sim.add_forcing_to(shearable_rod).using(\n", "    GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])\n", ")\n", "print(\"Gravity now acting on shearable rod\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Muscle Torques\n", "A snake generates torque throughout its body through muscle activations. While these muscle activations are generated internally by the snake, it is simpler to treat them as applied external forces, allowing us to apply them to the rod in the same manner as the other external forces. \n", "\n", "You may notice that the muscle torque parameters appear to have special values. These are optimized coefficients for a snake gait. For information about how to do this optimization, see the [snake optimization example script](../ContinuumSnakeCase/continuum_snake.py)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Define muscle torque parameters\n", "period = 2.0\n", "wave_length = 1.0\n", "b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3])\n", "\n", "# Add muscle torques to the rod\n", "snake_sim.add_forcing_to(shearable_rod).using(\n", "    MuscleTorques,\n", "    base_length=base_length,\n", "    b_coeff=b_coeff,\n", "    period=period,\n", "    wave_number=2.0 * np.pi / (wave_length),\n", "    phase_shift=0.0,\n", "    rest_lengths=shearable_rod.rest_lengths,\n", "    ramp_up_time=period,\n", "    direction=normal,\n", "    with_spline=True,\n", ")\n", "print(\"Muscle torques added to the rod\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Anisotropic Friction Forces\n", "The last force that needs to be added is the friction force between the snake and the ground. Snakes exhibit anisotropic friction, where the friction coefficient differs between directions. You can also define both static and kinetic friction coefficients. This is accomplished by defining a small velocity threshold `slip_velocity_tol` that marks the transition between static and kinetic friction. 
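
One way to picture the static/kinetic switch is a one-dimensional toy model along the rod axis. The sketch below is only an illustration under simplifying assumptions: the helper name `friction_force_1d` and its arguments are hypothetical, it handles a single contact point with (forward, backward) coefficient pairs, and it is not the `AnisotropicFrictionalPlane` implementation.

```python
import numpy as np


def friction_force_1d(velocity, applied_force, normal_force,
                      kinetic_mu, static_mu, slip_velocity_tol=1e-8):
    # Hypothetical helper: kinetic_mu and static_mu are (forward, backward) pairs.
    if abs(velocity) > slip_velocity_tol:
        # Sliding: kinetic friction opposes the direction of motion.
        mu = kinetic_mu[0] if velocity > 0 else kinetic_mu[1]
        return -np.sign(velocity) * mu * normal_force
    # Sticking: static friction cancels the applied force up to its limit.
    mu = static_mu[0] if applied_force > 0 else static_mu[1]
    limit = mu * normal_force
    return -np.clip(applied_force, -limit, limit)


# Made-up coefficients that mirror the forward/backward ratios used in this tutorial.
mu = 0.1
kinetic = (1.0 * mu, 1.5 * mu)
static = (2.0 * mu, 3.0 * mu)
sliding = friction_force_1d(0.5, 0.0, 1.0, kinetic, static)
sticking = friction_force_1d(0.0, 0.05, 1.0, kinetic, static)
```

In this toy model a point moving faster than the threshold always feels the kinetic coefficient for its direction of travel, while a point below the threshold stays put until the applied force exceeds the static limit.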
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Define friction force parameters\n", "origin_plane = np.array([0.0, -base_radius, 0.0])\n", "normal_plane = normal\n", "slip_velocity_tol = 1e-8\n", "froude = 0.1\n", "mu = base_length / (period * period * np.abs(gravitational_acc) * froude)\n", "kinetic_mu_array = np.array(\n", " [1.0 * mu, 1.5 * mu, 2.0 * mu]\n", ") # [forward, backward, sideways]\n", "static_mu_array = 2 * kinetic_mu_array\n", "\n", "# Add friction forces to the rod\n", "snake_sim.add_forcing_to(shearable_rod).using(\n", " AnisotropicFrictionalPlane,\n", " k=1.0,\n", " nu=1e-6,\n", " plane_origin=origin_plane,\n", " plane_normal=normal_plane,\n", " slip_velocity_tol=slip_velocity_tol,\n", " static_mu_array=static_mu_array,\n", " kinetic_mu_array=kinetic_mu_array,\n", ")\n", "print(\"Friction forces added to the rod\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Add Callback Function\n", "The simulation is now setup, but before it is run, we want to define a callback function. A callback function allows us to record time-series data throughout the simulation. If you do not define a callback function, you will only have access to the final configuration of the system. If you want to be able to analyze how the system evolves over time, it is critical that you record the appropriate quantities. \n", "\n", "To create a callback function, begin with the `CallBackBaseClass`. You can then define which state quantities you wish to record by having them appended to the `self.callback_params` dictionary as well as how often you wish to save the data by defining `skip_step`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Add call backs\n", "class ContinuumSnakeCallBack(CallBackBaseClass):\n", "    \"\"\"\n", "    Call back function for continuum snake\n", "    \"\"\"\n", "\n", "    def __init__(self, step_skip: int, callback_params: dict):\n", "        CallBackBaseClass.__init__(self)\n", "        self.every = step_skip\n", "        self.callback_params = callback_params\n", "\n", "    def make_callback(self, system, time, current_step: int):\n", "\n", "        if current_step % self.every == 0:\n", "\n", "            self.callback_params[\"time\"].append(time)\n", "            self.callback_params[\"step\"].append(current_step)\n", "            self.callback_params[\"position\"].append(system.position_collection.copy())\n", "            self.callback_params[\"velocity\"].append(system.velocity_collection.copy())\n", "            self.callback_params[\"avg_velocity\"].append(\n", "                system.compute_velocity_center_of_mass()\n", "            )\n", "\n", "            self.callback_params[\"center_of_mass\"].append(\n", "                system.compute_position_center_of_mass()\n", "            )\n", "            self.callback_params[\"curvature\"].append(system.kappa.copy())\n", "\n", "            return\n", "\n", "\n", "pp_list = defaultdict(list)\n", "snake_sim.collect_diagnostics(shearable_rod).using(\n", "    ContinuumSnakeCallBack, step_skip=100, callback_params=pp_list\n", ")\n", "print(\"Callback function added to the simulator\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "With the callback function added, we can now finalize the system and also define the time stepping parameters of the simulation such as the time step, final time, and time stepping algorithm to use. 
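
To make the roles of the time step, final time, and stepping algorithm concrete, here is a minimal sketch of a position Verlet loop for a single degree of freedom. It is a toy written for this explanation, assuming a unit-frequency harmonic oscillator; the function `position_verlet` is hypothetical and is not the Elastica `PositionVerlet` class.

```python
import numpy as np


def position_verlet(x, v, acceleration, dt, n_steps):
    # Toy position Verlet loop for one degree of freedom (illustrative only).
    for _ in range(n_steps):
        x = x + 0.5 * dt * v          # first half step in position
        v = v + dt * acceleration(x)  # full step in velocity at the midpoint
        x = x + 0.5 * dt * v          # second half step in position
    return x, v


final_time = 2.0 * np.pi  # one period of a unit-frequency oscillator
dt = 1e-3
n_steps = int(round(final_time / dt))
x_end, v_end = position_verlet(1.0, 0.0, lambda x: -x, dt, n_steps)
```

After one full period the oscillator should return close to its starting state, which is a quick sanity check that the chosen step size is small enough for the dynamics being resolved.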
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "snake_sim.finalize()\n", "\n", "final_time = 5.0 * period\n", "total_steps = int(final_time / dt)\n", "print(\"Total steps\", total_steps)\n", "\n", "timestepper = PositionVerlet()" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Now all that is left is to run the simulation. Using the default parameters the simulation takes about 2-3 minutes to complete. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "integrate(timestepper, snake_sim, final_time, total_steps)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## Post-Process Data\n", "With the simulation complete, we want to analyze the simulation. Because we added a callback function, we can analyze how the snake evolves over time. All of the data from the callback function is located in the `pp_list` dictionary. Here we will use this information to compute and plot the velocity of the snake in the forward, lateral, and normal directions. We do this by using a pre-written analysis function `compute_projected_velocity`.\n", "\n", "In the plotted graph, you can see that it takes about one period for the snake to begin moving before rapidly reaching a steady gait over just 2-3 periods. We also see that the normal velocity is zero since we are only actuating the snake in a 2D plane. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def compute_projected_velocity(plot_params: dict, period):\n", " import numpy as np\n", "\n", " time_per_period = np.array(plot_params[\"time\"]) / period\n", " avg_velocity = np.array(plot_params[\"avg_velocity\"])\n", " center_of_mass = np.array(plot_params[\"center_of_mass\"])\n", "\n", " # Compute rod velocity in rod direction. We need to compute that because,\n", " # after snake starts to move it chooses an arbitrary direction, which does not\n", " # have to be initial tangent direction of the rod. Thus we need to project the\n", " # snake velocity with respect to its new tangent and roll direction, after that\n", " # we will get the correct forward and lateral speed. After this projection\n", " # lateral velocity of the snake has to be oscillating between + and - values with\n", " # zero mean.\n", "\n", " # Number of steps in one period.\n", " period_step = int(1.0 / (time_per_period[-1] - time_per_period[-2]))\n", " number_of_period = int(time_per_period[-1])\n", "\n", " # Center of mass position averaged in one period\n", " center_of_mass_averaged_over_one_period = np.zeros((number_of_period - 2, 3))\n", " for i in range(1, number_of_period - 2):\n", " # position of center of mass averaged over one period\n", " center_of_mass_averaged_over_one_period[i - 1] = np.mean(\n", " center_of_mass[(i + 1) * period_step : (i + 2) * period_step]\n", " - center_of_mass[(i + 0) * period_step : (i + 1) * period_step],\n", " axis=0,\n", " )\n", " # Average the rod directions over multiple periods and get the direction of the rod.\n", " direction_of_rod = np.mean(center_of_mass_averaged_over_one_period, axis=0)\n", " direction_of_rod /= np.linalg.norm(direction_of_rod, ord=2)\n", "\n", " # Compute the projected rod velocity in the direction of the rod\n", " velocity_mag_in_direction_of_rod = np.einsum(\n", " \"ji,i->j\", avg_velocity, 
direction_of_rod\n", " )\n", " velocity_in_direction_of_rod = np.einsum(\n", " \"j,i->ji\", velocity_mag_in_direction_of_rod, direction_of_rod\n", " )\n", "\n", " # Get the lateral or roll velocity of the rod after subtracting its projected\n", " # velocity in the direction of rod\n", " velocity_in_rod_roll_dir = avg_velocity - velocity_in_direction_of_rod\n", "\n", " # Compute the average velocity over the simulation, this can be used for optimizing snake\n", " # for fastest forward velocity. We start after first period, because of the ramping up happens\n", " # in first period.\n", " average_velocity_over_simulation = np.mean(\n", " velocity_in_direction_of_rod[period_step * 2 :], axis=0\n", " )\n", "\n", " return (\n", " velocity_in_direction_of_rod,\n", " velocity_in_rod_roll_dir,\n", " average_velocity_over_simulation[2],\n", " average_velocity_over_simulation[0],\n", " )\n", "\n", "\n", "def compute_and_plot_velocity(plot_params: dict, period):\n", " from matplotlib import pyplot as plt\n", " from matplotlib.colors import to_rgb\n", "\n", " time_per_period = np.array(plot_params[\"time\"]) / period\n", " avg_velocity = np.array(plot_params[\"avg_velocity\"])\n", "\n", " [\n", " velocity_in_direction_of_rod,\n", " velocity_in_rod_roll_dir,\n", " _,\n", " _,\n", " ] = compute_projected_velocity(plot_params, period)\n", "\n", " fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)\n", " plt.rcParams.update({\"font.size\": 16})\n", " ax = fig.add_subplot(111)\n", " ax.grid(which=\"minor\", color=\"k\", linestyle=\"--\")\n", " ax.grid(which=\"major\", color=\"k\", linestyle=\"-\")\n", " ax.plot(\n", " time_per_period[:], velocity_in_direction_of_rod[:, 2], \"r-\", label=\"forward\"\n", " )\n", " ax.plot(\n", " time_per_period[:],\n", " velocity_in_rod_roll_dir[:, 0],\n", " c=to_rgb(\"xkcd:bluish\"),\n", " label=\"lateral\",\n", " )\n", " ax.plot(time_per_period[:], avg_velocity[:, 1], \"k-\", label=\"normal\")\n", " ax.set_ylabel(\"Velocity [m/s]\", 
fontsize=16)\n", " ax.set_xlabel(\"Time [s]\", fontsize=16)\n", " fig.legend(prop={\"size\": 20})\n", " plt.show()\n", " plt.close(plt.gcf())\n", "\n", "\n", "compute_and_plot_velocity(pp_list, period)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } }, "source": [ "We can plot the curvature along the snake at different time instance and compare it with the sterotypical snake curvature function $7cos(2 \\pi s)$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def plot_curvature(\n", " plot_params: dict,\n", " rest_lengths,\n", " period,\n", "):\n", " from matplotlib import pyplot as plt\n", " from matplotlib.colors import to_rgb\n", "\n", " s = np.cumsum(rest_lengths)\n", " L0 = s[-1]\n", " s = s / L0\n", " s = s[:-1].copy()\n", " x = np.linspace(0, 1, 100)\n", " curvature = np.array(plot_params[\"curvature\"])\n", " time = np.array(plot_params[\"time\"])\n", " peak_time = period * 0.125\n", " dt = time[1] - time[0]\n", " peak_idx = int(peak_time / (dt))\n", " plt.rcParams.update({\"font.size\": 16})\n", " fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)\n", " ax = fig.add_subplot(111)\n", " try:\n", " for i in range(peak_idx * 8, peak_idx * 8 * 2, peak_idx):\n", " ax.plot(s, curvature[i, 0, :] * L0, \"k\")\n", " except:\n", " print(\"Simulation time not long enough to plot curvature\")\n", " ax.plot(\n", " x, 7 * np.cos(2 * np.pi * x - 0.80), \"--\", label=\"stereotypical snake curvature\"\n", " )\n", " ax.set_ylabel(r\"$\\kappa$\", fontsize=16)\n", " ax.set_xlabel(\"s\", fontsize=16)\n", " ax.set_xlim(0, 1)\n", " ax.set_ylim(-10, 10)\n", " fig.legend(prop={\"size\": 16})\n", " plt.show()\n", "\n", " plt.close(plt.gcf())\n", "\n", "\n", "plot_curvature(pp_list, shearable_rod.rest_lengths, period)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Make Video of Snake Gait\n", 
"Because we saved data of the snake's behavior, we can make a video of its movement. The easiest way to do this is to do this is to plot the snake's position at each time that the data was recorded and then stitch these plots together to form a video. \n", "\n", "note: ffmpeg is required for matplotlib to be able to create a video. More info on ffmepg [here](https://www.ffmpeg.org/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from IPython.display import Video\n", "from tqdm import tqdm\n", "\n", "\n", "def plot_video_2D(plot_params: dict, video_name=\"video.mp4\", margin=0.2, fps=15):\n", " from matplotlib import pyplot as plt\n", " import matplotlib.animation as manimation\n", "\n", " t = np.array(plot_params[\"time\"])\n", " positions_over_time = np.array(plot_params[\"position\"])\n", " total_time = int(np.around(t[..., -1], 1))\n", " total_frames = fps * total_time\n", " step = round(len(t) / total_frames)\n", "\n", " print(\"creating video -- this can take a few minutes\")\n", " FFMpegWriter = manimation.writers[\"ffmpeg\"]\n", " metadata = dict(title=\"Movie Test\", artist=\"Matplotlib\", comment=\"Movie support!\")\n", " writer = FFMpegWriter(fps=fps, metadata=metadata)\n", "\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " plt.axis(\"equal\")\n", " rod_lines_2d = ax.plot(\n", " positions_over_time[0][2], positions_over_time[0][0], linewidth=3\n", " )[0]\n", " ax.set_xlim([0 - margin, 3 + margin])\n", " ax.set_ylim([-1.5 - margin, 1.5 + margin])\n", " with writer.saving(fig, video_name, dpi=100):\n", " with plt.style.context(\"seaborn-v0_8-whitegrid\"):\n", " for time in range(1, len(t), step):\n", " rod_lines_2d.set_xdata(positions_over_time[time][2])\n", " rod_lines_2d.set_ydata(positions_over_time[time][0])\n", "\n", " writer.grab_frame()\n", " plt.close(fig)\n", "\n", "\n", "filename_video = \"continuum_snake.mp4\"\n", "plot_video_2D(pp_list, 
video_name=filename_video, margin=0.2, fps=125)\n", "\n", "Video(\"continuum_snake.mp4\")" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Finally, you can also plot the position of the snake from a 3D perspective. This is most helpful is you have a simulation that consists of more than planar motion. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from IPython.display import Video\n", "\n", "\n", "def plot_video(plot_params: dict, video_name=\"video.mp4\", margin=0.2, fps=15):\n", " from matplotlib import pyplot as plt\n", " import matplotlib.animation as manimation\n", " from mpl_toolkits import mplot3d\n", "\n", " t = np.array(plot_params[\"time\"])\n", " positions_over_time = np.array(plot_params[\"position\"])\n", " total_time = int(np.around(t[..., -1], 1))\n", " total_frames = fps * total_time\n", " step = round(len(t) / total_frames)\n", " print(\"creating video -- this can take a few minutes\")\n", " FFMpegWriter = manimation.writers[\"ffmpeg\"]\n", " metadata = dict(title=\"Movie Test\", artist=\"Matplotlib\", comment=\"Movie support!\")\n", " writer = FFMpegWriter(fps=fps, metadata=metadata)\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111, projection=\"3d\")\n", " ax.set_xlim(0 - margin, 3 + margin)\n", " ax.set_ylim(-1.5 - margin, 1.5 + margin)\n", " ax.set_zlim(0, 1)\n", " ax.view_init(elev=20, azim=-80)\n", " rod_lines_3d = ax.plot(\n", " positions_over_time[0][2],\n", " positions_over_time[0][0],\n", " positions_over_time[0][1],\n", " linewidth=3,\n", " )[0]\n", " with writer.saving(fig, video_name, dpi=100):\n", " with plt.style.context(\"seaborn-v0_8-whitegrid\"):\n", " for time in range(1, len(t), step):\n", " rod_lines_3d.set_xdata(positions_over_time[time][2])\n", " rod_lines_3d.set_ydata(positions_over_time[time][0])\n", " rod_lines_3d.set_3d_properties(positions_over_time[time][1])\n", "\n", " 
writer.grab_frame()\n", " plt.close(fig)\n", "\n", "\n", "filename_video = \"continuum_snake_3d.mp4\"\n", "plot_video(pp_list, video_name=filename_video, margin=0.2, fps=60)\n", "\n", "Video(\"continuum_snake_3d.mp4\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.12 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "vscode": { "interpreter": { "hash": "69b409ac0f88f35940b70a7cc6f81becc62e64d51a49713aaa6660602c575037" } } }, "nbformat": 4, "nbformat_minor": 2 }