# Slithering Snake Example

This Elastica tutorial explains how to set up a Cosserat rod simulation of a slithering snake. It is a more complex use case than the Timoshenko beam example; if you have not done so already, we strongly suggest starting with [this beam example](./1_Timoshenko_Beam.ipynb), as it covers many of the basics of setting up and running simulations with Elastica.

This slithering snake example includes gravitational forces, friction forces, and internal muscle torques. It also introduces callback functions, which log simulation data for post-processing after the simulation is over.


## Getting Started
To set up the simulation, first import the necessary classes. As with the Timoshenko beam, we import modules that make it easier to construct different simulation systems. We also need a rod class, the forces to be applied, time-stepping functions, and callback classes.

In [None]:
!pip install "pyelastica[examples,docs]"
!conda install -c conda-forge ffmpeg -y

In [None]:
import numpy as np

# import modules
from elastica.modules import BaseSystemCollection, Constraints, Forcing, CallBacks, Damping

# import rod class, damping and forces to be applied
from elastica.rod.cosserat_rod import CosseratRod
from elastica.dissipation import AnalyticalLinearDamper
from elastica.external_forces import GravityForces, MuscleTorques
from elastica.interaction import AnisotropicFrictionalPlane

# import timestepping functions
from elastica.timestepper.symplectic_steppers import PositionVerlet
from elastica.timestepper import integrate

# import call back functions
from elastica.callback_functions import CallBackBaseClass
from collections import defaultdict

## Initialize System and Add Rod
First, we define a simulator class that combines all of the imported modules and create an instance of it. We then generate a rod and add it to the simulation.

In [None]:
class SnakeSimulator(BaseSystemCollection, Constraints, Forcing, CallBacks, Damping):
    pass


snake_sim = SnakeSimulator()

# Define rod parameters
n_elem = 50
start = np.array([0.0, 0.0, 0.0])
direction = np.array([0.0, 0.0, 1.0])
normal = np.array([0.0, 1.0, 0.0])
base_length = 0.35
base_radius = base_length * 0.011
base_area = np.pi * base_radius ** 2
density = 1000
nu = 2e-3
E = 1e6
poisson_ratio = 0.5
shear_modulus = E / (poisson_ratio + 1.0)

# Create rod
shearable_rod = CosseratRod.straight_rod(
    n_elem,
    start,
    direction,
    normal,
    base_length,
    base_radius,
    density,
    youngs_modulus=E,
    shear_modulus=shear_modulus,
)

# Add rod to the snake system
snake_sim.append(shearable_rod)
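
The cell above derives several quantities from a handful of inputs. The sketch below recomputes a few of them as a sanity check: the rod radius, the element length, and the total mass of a uniform cylinder. It also prints, for comparison only, the textbook isotropic relation $G = E / (2(1 + \nu))$ next to the expression the example uses; for $\nu = 0.5$ the example's value is twice the isotropic one. The sketch does not change the simulation.

```python
import numpy as np

# Rod parameters copied from the cell above
n_elem = 50
base_length = 0.35
base_radius = base_length * 0.011
density = 1000.0
E = 1e6
poisson_ratio = 0.5

base_area = np.pi * base_radius**2
element_length = base_length / n_elem
total_mass = density * base_area * base_length  # uniform solid cylinder

# Shear modulus as written in this example versus the isotropic-material relation
shear_modulus_example = E / (poisson_ratio + 1.0)
shear_modulus_isotropic = E / (2.0 * (1.0 + poisson_ratio))

print(f"radius = {base_radius * 1e3:.3f} mm")
print(f"element length = {element_length * 1e3:.1f} mm")
print(f"total mass = {total_mass * 1e3:.2f} g")
print(f"G (example) = {shear_modulus_example:.3e} Pa")
print(f"G (isotropic) = {shear_modulus_isotropic:.3e} Pa")
```

With these inputs the snake is a thin rod (radius about 3.85 mm) weighing roughly 16 g.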


## Adding Damping
With the rod added to the simulator, we can add damping to it. We use the `.dampen()` method to tell `snake_sim` to damp `shearable_rod` with the `AnalyticalLinearDamper` dissipation (damping) model.

We also need to define the `damping_constant` and the simulation `time_step` and pass them to the `.using()` method.

In [None]:
dt = 1e-4
snake_sim.dampen(shearable_rod).using(
    AnalyticalLinearDamper,
    damping_constant=nu,
    time_step=dt,
)
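
To get a feel for how strong this damping is, the sketch below assumes that the translational part of `AnalyticalLinearDamper` multiplies nodal velocities by $e^{-\nu \Delta t}$ every step, i.e. the exact solution of $\dot v = -\nu v$ over one step. That form is an assumption here; check the PyElastica source for the exact (including rotational) update.

```python
import numpy as np

damping_constant = 2e-3  # `nu` in the notebook
dt = 1e-4
final_time = 10.0        # 5 periods of 2.0 s, as set later in the notebook

# Assumed per-step factor: exact decay of dv/dt = -nu * v over one time step
per_step_factor = np.exp(-damping_constant * dt)
n_steps = round(final_time / dt)

# Applying the factor n_steps times equals exp(-nu * final_time)
cumulative_factor = per_step_factor**n_steps

print(f"per-step factor: {per_step_factor:.10f}")
print(f"velocity retained after {final_time} s (damping alone): {cumulative_factor:.4f}")
```

Under this assumption, damping alone would remove only about 2% of a velocity over the whole 10 s run, so it mainly suppresses numerical noise rather than shaping the gait.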

## Add Forces to Rod
With our rod added to the system, we need to specify the forces acting on it. For every force, the pattern is `system_name.add_forcing_to(name_of_rod).using(type_of_force, **kwargs)`, where `**kwargs` are the parameters specific to each type of force.
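
The `add_forcing_to(...).using(...)` chain is a builder: the first call registers a target and the second records which class to use and with which keyword arguments. The plain-Python sketch below shows one way such a chain can defer object construction until a `finalize()` call. `ToySimulator`, `_ForcingSpec`, and `ToyGravity` are hypothetical names for illustration; PyElastica's internals may be organized differently.

```python
from dataclasses import dataclass


class _ForcingSpec:
    """Hypothetical stand-in for the object returned by add_forcing_to()."""

    def __init__(self, rod_name):
        self.rod_name = rod_name
        self.force_cls = None
        self.kwargs = {}

    def using(self, force_cls, **kwargs):
        # Record the class and its keyword arguments; nothing is built yet
        self.force_cls = force_cls
        self.kwargs = kwargs
        return self


class ToySimulator:
    """Hypothetical simulator that builds forces only when finalize() runs."""

    def __init__(self):
        self._specs = []
        self.forces = []

    def add_forcing_to(self, rod_name):
        spec = _ForcingSpec(rod_name)
        self._specs.append(spec)
        return spec

    def finalize(self):
        # Instantiate every registered force with its stored keyword arguments
        self.forces = [(s.rod_name, s.force_cls(**s.kwargs)) for s in self._specs]


@dataclass
class ToyGravity:
    acc_gravity: tuple


sim = ToySimulator()
sim.add_forcing_to("rod").using(ToyGravity, acc_gravity=(0.0, -9.80665, 0.0))
sim.finalize()
print(sim.forces)
```

Deferring construction this way lets a simulator validate or reorder all registered forces in one place before the first time step.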

### Gravity
The first force to add is gravity. We specify the gravitational acceleration as a vector, which sets both its magnitude and the direction in which it acts (here, the negative y direction).

In [None]:
# Add gravitational forces
gravitational_acc = -9.80665
snake_sim.add_forcing_to(shearable_rod).using(
    GravityForces, acc_gravity=np.array([0.0, gravitational_acc, 0.0])
)
print("Gravity now acting on shearable rod")
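
Gravity acts on the rod's nodes. The sketch below assumes that `GravityForces` adds mass times `acc_gravity` at every node, and that node masses are lumped by giving each node half the mass of each adjacent element (our understanding of how PyElastica distributes mass; treat it as an assumption). The forces are stored as a `(3, n_nodes)` array, matching the layout of the rod's collections.

```python
import numpy as np

n_elem = 50
base_length = 0.35
base_radius = base_length * 0.011
density = 1000.0
acc_gravity = np.array([0.0, -9.80665, 0.0])

# Mass of each element (uniform cylinder)
element_mass = density * np.pi * base_radius**2 * (base_length / n_elem)

# Assumed lumping: each element contributes half its mass to each of its two nodes
node_mass = np.zeros(n_elem + 1)
node_mass[:-1] += 0.5 * element_mass
node_mass[1:] += 0.5 * element_mass

# Gravity force on every node; row k holds the k-th Cartesian component
gravity_forces = np.outer(acc_gravity, node_mass)

print(gravity_forces.shape)  # (3, 51)
print(gravity_forces.sum(axis=1))  # total weight, acting along -y
```

The end nodes carry half the load of interior nodes, but the total always equals the rod's weight.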

### Muscle Torques
A snake generates torque throughout its body through muscle activations. Although these activations are generated internally by the snake, it is simpler to treat them as applied external torques, which lets us add them to the rod in the same way as the external forces.

You may notice that the muscle torque parameters appear to have special values. These are optimized coefficients for a snake gait. For information about how to do this optimization, see the [snake optimization example script](../ContinuumSnakeCase/continuum_snake.py).

In [None]:
# Define muscle torque parameters
period = 2.0
wave_length = 1.0
b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3])

# Add muscle torques to the rod
snake_sim.add_forcing_to(shearable_rod).using(
    MuscleTorques,
    base_length=base_length,
    b_coeff=b_coeff,
    period=period,
    wave_number=2.0 * np.pi / (wave_length),
    phase_shift=0.0,
    rest_lengths=shearable_rod.rest_lengths,
    ramp_up_time=period,
    direction=normal,
    with_spline=True,
)
print("Muscle torques added to the rod")
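
The muscle torque parameters describe a traveling wave along the body. As we understand `MuscleTorques`, the torque magnitude about `direction` has the form $A(s, t) = r(t)\,\beta(s)\,\sin(\omega t - k s + \phi)$ with $\omega = 2\pi/T$ (`period`), $k$ = `wave_number`, $\phi$ = `phase_shift`, a linear ramp $r(t) = \min(1, t / t_\text{ramp})$, and an amplitude profile $\beta(s)$ built from `b_coeff` over the normalized arc length $s$. Treat this form as an assumption. The library fits a spline through `b_coeff` when `with_spline=True`; this sketch substitutes linear interpolation over evenly spaced control points.

```python
import numpy as np

period = 2.0
wave_length = 1.0
wave_number = 2.0 * np.pi / wave_length
ramp_up_time = period
phase_shift = 0.0
b_coeff = np.array([3.4e-3, 3.3e-3, 4.2e-3, 2.6e-3, 3.6e-3, 3.5e-3])

# Normalized arc length at the 49 interior nodes of a 50-element rod
s = np.linspace(0.0, 1.0, 51)[1:-1]

# Assumption: linear interpolation stands in for the library's spline
beta = np.interp(s, np.linspace(0.0, 1.0, b_coeff.size), b_coeff)


def torque_magnitude(time):
    """Sketch of A(s, t) = ramp(t) * beta(s) * sin(omega * t - k * s + phi)."""
    ramp = min(1.0, time / ramp_up_time)
    omega = 2.0 * np.pi / period
    return ramp * beta * np.sin(omega * time - wave_number * s + phase_shift)


for t in (0.5, 2.0, 3.0):
    print(f"t = {t:.1f} s, max |A| = {np.abs(torque_magnitude(t)).max():.2e}")
```

The ramp explains why the snake needs about one period to get going: the torques only reach full strength at `t = ramp_up_time`.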

### Anisotropic Friction Forces
The last force to add is friction between the snake and the ground. Snakes exhibit anisotropic friction: the friction coefficient differs depending on whether the body slides forward, backward, or sideways. You can also define separate static and kinetic friction coefficients; the small velocity threshold `slip_velocity_tol` sets the transition between static and kinetic friction.

In [None]:
# Define friction force parameters
origin_plane = np.array([0.0, -base_radius, 0.0])
normal_plane = normal
slip_velocity_tol = 1e-8
froude = 0.1
mu = base_length / (period * period * np.abs(gravitational_acc) * froude)
kinetic_mu_array = np.array(
    [1.0 * mu, 1.5 * mu, 2.0 * mu]
)  # [forward, backward, sideways]
static_mu_array = 2 * kinetic_mu_array

# Add friction forces to the rod
snake_sim.add_forcing_to(shearable_rod).using(
    AnisotropicFrictionalPlane,
    k=1.0,
    nu=1e-6,
    plane_origin=origin_plane,
    plane_normal=normal_plane,
    slip_velocity_tol=slip_velocity_tol,
    static_mu_array=static_mu_array,
    kinetic_mu_array=kinetic_mu_array,
)
print("Friction forces added to the rod")
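
The friction scale `mu` comes from a Froude number, $\mathrm{Fr} = L / (\mu\, g\, T^2)$, solved for $\mu$ exactly as in the cell above. The sketch below evaluates it with this notebook's values and lists the resulting direction-dependent coefficients. The dictionary layout is only for readability; the library itself takes the arrays in `[forward, backward, sideways]` order.

```python
base_length = 0.35
period = 2.0
g = 9.80665
froude = 0.1

# Fr = L / (mu * g * T^2), rearranged for mu
mu = base_length / (period**2 * g * froude)

kinetic_mu = {"forward": 1.0 * mu, "backward": 1.5 * mu, "sideways": 2.0 * mu}
static_mu = {name: 2.0 * value for name, value in kinetic_mu.items()}

print(f"mu = {mu:.4f}")
for name in kinetic_mu:
    print(f"{name:>9}: kinetic {kinetic_mu[name]:.4f}, static {static_mu[name]:.4f}")
```

Sideways sliding is the hardest, which is what lets a lateral undulation push the body forward.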

## Add Callback Function
The simulation is now set up, but before running it we want to define a callback function. A callback function records time-series data throughout the simulation. Without one, you only have access to the final configuration of the system; if you want to analyze how the system evolves over time, it is critical to record the appropriate quantities.

To create a callback function, subclass `CallBackBaseClass`. You then choose which state quantities to record by appending them to the `self.callback_params` dictionary, and how often to save the data by setting `step_skip`.

In [None]:
# Add call backs
class ContinuumSnakeCallBack(CallBackBaseClass):
    """
    Call back function for continuum snake
    """

    def __init__(self, step_skip: int, callback_params: dict):
        CallBackBaseClass.__init__(self)
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step: int):

        if current_step % self.every == 0:

            self.callback_params["time"].append(time)
            self.callback_params["step"].append(current_step)
            self.callback_params["position"].append(system.position_collection.copy())
            self.callback_params["velocity"].append(system.velocity_collection.copy())
            self.callback_params["avg_velocity"].append(
                system.compute_velocity_center_of_mass()
            )

            self.callback_params["center_of_mass"].append(
                system.compute_position_center_of_mass()
            )
            self.callback_params["curvature"].append(system.kappa.copy())

        return


pp_list = defaultdict(list)
snake_sim.collect_diagnostics(shearable_rod).using(
    ContinuumSnakeCallBack, step_skip=100, callback_params=pp_list
)
print("Callback function added to the simulator")
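
The `step_skip` logic and the `.copy()` calls can be exercised without Elastica. The sketch below uses a hypothetical `FakeRod` (exposing only `position_collection`) and a hypothetical `EveryNthRecorder` that mirrors the recording pattern of `ContinuumSnakeCallBack`, driven by a hand-written loop instead of the real integrator.

```python
from collections import defaultdict

import numpy as np


class FakeRod:
    """Hypothetical stand-in exposing one attribute the callback reads."""

    def __init__(self):
        self.position_collection = np.zeros((3, 4))


class EveryNthRecorder:
    """Hypothetical recorder mirroring the step_skip logic of ContinuumSnakeCallBack."""

    def __init__(self, step_skip, callback_params):
        self.every = step_skip
        self.callback_params = callback_params

    def make_callback(self, system, time, current_step):
        if current_step % self.every == 0:
            self.callback_params["time"].append(time)
            # copy() matters: without it every entry would alias the same live array
            self.callback_params["position"].append(system.position_collection.copy())


rod = FakeRod()
records = defaultdict(list)
recorder = EveryNthRecorder(step_skip=100, callback_params=records)

dt = 1e-4
for step in range(1001):
    rod.position_collection[2] += dt  # pretend the rod drifts along z
    recorder.make_callback(rod, step * dt, step)

print(len(records["time"]))  # 11 snapshots: steps 0, 100, ..., 1000
```

If the `.copy()` were dropped, every stored snapshot would show the final position, which is an easy mistake to miss until post-processing.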

With the callback function added, we can finalize the system and define the remaining time-stepping parameters: the final time, the total number of steps, and the time-stepping algorithm. The time step `dt` was already set when adding damping.

In [None]:
snake_sim.finalize()

final_time = 5.0 * period
total_steps = int(final_time / dt)
print("Total steps", total_steps)

timestepper = PositionVerlet()
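
A quick check of the step count and the recording cadence helps when choosing `final_time` and `step_skip`. The cell above converts `final_time / dt` with `int()`, which truncates; that can silently drop a step when floating-point division lands just below an integer, so the sketch below uses `round()` instead. It also reports how many simulated seconds separate consecutive callback records.

```python
import math

dt = 1e-4              # time step set in the damping cell
period = 2.0
final_time = 5.0 * period
step_skip = 100        # callback cadence used for pp_list

# int() truncates toward zero; round() picks the nearest integer
print(int(0.3 / 0.1), round(0.3 / 0.1))  # 2 3

total_steps = round(final_time / dt)
record_interval = step_skip * dt          # simulated seconds between records
records_per_period = period / record_interval

print(total_steps, record_interval, records_per_period)
```

With these settings the callback stores roughly 200 snapshots per gait period, which is plenty for the velocity and curvature plots below.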

Now all that is left is to run the simulation. With the default parameters, the simulation takes about 2-3 minutes to complete.

In [None]:
integrate(timestepper, snake_sim, final_time, total_steps)

## Post-Process Data
With the simulation complete, we can analyze the results. Because we added a callback function, we can see how the snake evolves over time: all of the recorded data is stored in the `pp_list` dictionary. Here we use it to compute and plot the velocity of the snake in the forward, lateral, and normal directions with the pre-written analysis function `compute_projected_velocity`.

In the plot, you can see that it takes about one period for the snake to begin moving, after which it reaches a steady gait within 2-3 periods. The normal (vertical, y) velocity stays near zero because the snake is actuated only in the horizontal plane.

In [None]:
def compute_projected_velocity(plot_params: dict, period):
    import numpy as np

    time_per_period = np.array(plot_params["time"]) / period
    avg_velocity = np.array(plot_params["avg_velocity"])
    center_of_mass = np.array(plot_params["center_of_mass"])

    # Compute the rod velocity along the rod's direction of travel. This is needed
    # because once the snake starts moving it heads in an arbitrary direction, which
    # need not be the initial tangent of the rod. Projecting the snake velocity onto
    # its new heading and the perpendicular (roll) direction gives the correct forward
    # and lateral speeds; after this projection the lateral velocity should oscillate
    # between positive and negative values with zero mean.

    # Number of recorded steps in one period (rounded to guard against float error).
    period_step = int(round(1.0 / (time_per_period[-1] - time_per_period[-2])))
    number_of_period = int(time_per_period[-1])

    # Mean displacement of the center of mass over each full period
    center_of_mass_averaged_over_one_period = np.zeros((number_of_period - 2, 3))
    for i in range(1, number_of_period - 2):
        # displacement of the center of mass over one period, averaged within it
        center_of_mass_averaged_over_one_period[i - 1] = np.mean(
            center_of_mass[(i + 1) * period_step : (i + 2) * period_step]
            - center_of_mass[(i + 0) * period_step : (i + 1) * period_step],
            axis=0,
        )
    # Average the rod directions over multiple periods and get the direction of the rod.
    direction_of_rod = np.mean(center_of_mass_averaged_over_one_period, axis=0)
    direction_of_rod /= np.linalg.norm(direction_of_rod, ord=2)

    # Compute the projected rod velocity in the direction of the rod
    velocity_mag_in_direction_of_rod = np.einsum(
        "ji,i->j", avg_velocity, direction_of_rod
    )
    velocity_in_direction_of_rod = np.einsum(
        "j,i->ji", velocity_mag_in_direction_of_rod, direction_of_rod
    )

    # Get the lateral or roll velocity of the rod after subtracting its projected
    # velocity in the direction of rod
    velocity_in_rod_roll_dir = avg_velocity - velocity_in_direction_of_rod

    # Compute the average velocity over the simulation; this can be used when optimizing
    # the snake for the fastest forward velocity. We skip the first two periods because
    # the muscle torques ramp up during the first period.
    average_velocity_over_simulation = np.mean(
        velocity_in_direction_of_rod[period_step * 2 :], axis=0
    )

    return (
        velocity_in_direction_of_rod,
        velocity_in_rod_roll_dir,
        average_velocity_over_simulation[2],
        average_velocity_over_simulation[0],
    )


def compute_and_plot_velocity(plot_params: dict, period):
    from matplotlib import pyplot as plt
    from matplotlib.colors import to_rgb

    time_per_period = np.array(plot_params["time"]) / period
    avg_velocity = np.array(plot_params["avg_velocity"])

    [
        velocity_in_direction_of_rod,
        velocity_in_rod_roll_dir,
        _,
        _,
    ] = compute_projected_velocity(plot_params, period)

    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    plt.rcParams.update({"font.size": 16})
    ax = fig.add_subplot(111)
    ax.grid(which="minor", color="k", linestyle="--")
    ax.grid(which="major", color="k", linestyle="-")
    ax.plot(
        time_per_period[:], velocity_in_direction_of_rod[:, 2], "r-", label="forward"
    )
    ax.plot(
        time_per_period[:],
        velocity_in_rod_roll_dir[:, 0],
        c=to_rgb("xkcd:bluish"),
        label="lateral",
    )
    ax.plot(time_per_period[:], avg_velocity[:, 1], "k-", label="normal")
    ax.set_ylabel("Velocity [m/s]", fontsize=16)
    # The horizontal axis is time divided by the gait period
    ax.set_xlabel("Time [periods]", fontsize=16)
    fig.legend(prop={"size": 20})
    plt.show()
    plt.close(plt.gcf())


compute_and_plot_velocity(pp_list, period)
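
The projection inside `compute_projected_velocity` splits each center-of-mass velocity $\mathbf{v}$ into a forward part $(\mathbf{v}\cdot\hat{\mathbf{d}})\,\hat{\mathbf{d}}$ along the heading $\hat{\mathbf{d}}$ and a remainder (lateral/roll). The sketch below applies the same split to a few synthetic velocity records, assuming a heading along +z; in the notebook the heading is instead estimated from the center-of-mass displacement over whole periods. `np.outer` here does the same job as the `"j,i->ji"` einsum.

```python
import numpy as np

# Synthetic center-of-mass velocities, one row per record (shape: n_records x 3)
avg_velocity = np.array([
    [0.01, 0.0, 0.05],
    [-0.01, 0.0, 0.05],
    [0.02, 0.0, 0.04],
])

# Assumed heading: unit vector along +z
direction_of_rod = np.array([0.0, 0.0, 1.0])

# Forward component: (v . d) d for every record
forward_mag = avg_velocity @ direction_of_rod
forward = np.outer(forward_mag, direction_of_rod)

# Lateral/roll component: whatever remains after removing the forward part
lateral = avg_velocity - forward

print(forward_mag)
print(lateral[:, 0])
```

By construction the lateral part has no component along the heading, so a steady serpentine gait shows up as a positive forward speed plus a zero-mean lateral oscillation.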

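We can plot the curvature along the snake at different time instances and compare it with the stereotypical snake curvature function $7\cos(2 \pi s - 0.80)$, where $s$ is the arc length normalized by the rod length and the curvature is made dimensionless by multiplying by the rod length.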

In [None]:
def plot_curvature(
    plot_params: dict,
    rest_lengths,
    period,
):
    from matplotlib import pyplot as plt
    from matplotlib.colors import to_rgb

    # Normalized arc length at the interior nodes, where curvature is defined
    s = np.cumsum(rest_lengths)
    L0 = s[-1]
    s = s / L0
    s = s[:-1].copy()
    x = np.linspace(0, 1, 100)
    curvature = np.array(plot_params["curvature"])
    time = np.array(plot_params["time"])
    peak_time = period * 0.125
    dt = time[1] - time[0]  # interval between recorded snapshots
    peak_idx = int(round(peak_time / dt))
    plt.rcParams.update({"font.size": 16})
    fig = plt.figure(figsize=(10, 8), frameon=True, dpi=150)
    ax = fig.add_subplot(111)
    try:
        # Dimensionless curvature (kappa * L0) at eight instants during the second period
        for i in range(peak_idx * 8, peak_idx * 8 * 2, peak_idx):
            ax.plot(s, curvature[i, 0, :] * L0, "k")
    except IndexError:
        print("Simulation time not long enough to plot curvature")
    ax.plot(
        x, 7 * np.cos(2 * np.pi * x - 0.80), "--", label="stereotypical snake curvature"
    )
    ax.set_ylabel(r"$\kappa$", fontsize=16)
    ax.set_xlabel("s", fontsize=16)
    ax.set_xlim(0, 1)
    ax.set_ylim(-10, 10)
    fig.legend(prop={"size": 16})
    plt.show()

    plt.close(plt.gcf())


plot_curvature(pp_list, shearable_rod.rest_lengths, period)
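
The arc-length axis in `plot_curvature` comes from the rest lengths. A Cosserat rod with `n_elem` elements stores its curvature `kappa` on the `n_elem - 1` interior nodes (the Voronoi points between elements), which is why the normalized cumulative lengths drop their last entry. The sketch below reproduces that bookkeeping for a uniform rod like the one in this example, together with the reference curve drawn in the plot.

```python
import numpy as np

n_elem = 50
base_length = 0.35
rest_lengths = np.full(n_elem, base_length / n_elem)  # uniform rod, as in this example

# Cumulative arc length at the end of each element, normalized to [0, 1]
s = np.cumsum(rest_lengths)
L0 = s[-1]
s = s / L0

# kappa has n_elem - 1 entries along the rod, so drop the final node (s = 1)
s_voronoi = s[:-1]

# Reference curve from the plot: dimensionless curvature kappa * L0
reference = 7.0 * np.cos(2.0 * np.pi * s_voronoi - 0.80)

print(s_voronoi.shape, s_voronoi[0], s_voronoi[-1])
```

Multiplying `kappa` by `L0` makes the comparison independent of the snake's absolute size.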

### Make Video of Snake Gait
Because we saved data on the snake's behavior, we can make a video of its movement. The easiest way is to plot the snake's position at each recorded time and then stitch these plots together into a video.

Note: matplotlib requires ffmpeg to create a video. More info on ffmpeg [here](https://www.ffmpeg.org/).

In [None]:
from IPython.display import Video


def plot_video_2D(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):
    from matplotlib import pyplot as plt
    import matplotlib.animation as manimation

    t = np.array(plot_params["time"])
    positions_over_time = np.array(plot_params["position"])
    total_time = int(np.around(t[..., -1], 1))
    total_frames = fps * total_time
    # Snapshots to advance per frame; keep it at least 1 so range() never gets a zero step
    step = max(1, round(len(t) / total_frames))

    print("creating video -- this can take a few minutes")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
    writer = FFMpegWriter(fps=fps, metadata=metadata)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.axis("equal")
    # Top-down view: z (forward) on the horizontal axis, x (lateral) on the vertical axis
    rod_lines_2d = ax.plot(
        positions_over_time[0][2], positions_over_time[0][0], linewidth=3
    )[0]
    ax.set_xlim([0 - margin, 3 + margin])
    ax.set_ylim([-1.5 - margin, 1.5 + margin])
    with writer.saving(fig, video_name, dpi=100):
        with plt.style.context("seaborn-v0_8-whitegrid"):
            for time in range(1, len(t), step):
                rod_lines_2d.set_xdata(positions_over_time[time][2])
                rod_lines_2d.set_ydata(positions_over_time[time][0])

                writer.grab_frame()
    plt.close(fig)


filename_video = "continuum_snake.mp4"
plot_video_2D(pp_list, video_name=filename_video, margin=0.2, fps=125)

Video("continuum_snake.mp4")
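
The video functions pick a stride through the recorded snapshots so that the requested `fps` roughly matches the simulated duration. The sketch below isolates that calculation in a hypothetical helper, `frame_stride`, using an assumed record count of about 1000 over 10 s. The `max(1, ...)` guard matters when `fps * total_time` exceeds twice the number of records: plain rounding would then yield a stride of 0, and `range()` rejects a zero step.

```python
def frame_stride(n_records, total_time, fps):
    """Hypothetical helper: recorded snapshots to advance per video frame (at least 1)."""
    total_frames = fps * total_time
    # Without max(1, ...) a high fps or a sparse record would give a stride of 0
    return max(1, round(n_records / total_frames))


# About 1000 records over 10 s (assumed count for this notebook's settings)
print(frame_stride(1000, 10, 15))   # 7
print(frame_stride(1000, 10, 125))  # 1
print(frame_stride(100, 10, 125))   # 1 (would be 0 without the guard)
```

Once the stride bottoms out at 1, the video simply becomes shorter than the simulated time rather than failing.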

Finally, you can also plot the position of the snake from a 3D perspective. This is most helpful if your simulation involves more than planar motion.

In [None]:
from IPython.display import Video


def plot_video(plot_params: dict, video_name="video.mp4", margin=0.2, fps=15):
    from matplotlib import pyplot as plt
    import matplotlib.animation as manimation
    from mpl_toolkits import mplot3d

    t = np.array(plot_params["time"])
    positions_over_time = np.array(plot_params["position"])
    total_time = int(np.around(t[..., -1], 1))
    total_frames = fps * total_time
    # Snapshots to advance per frame; keep it at least 1 so range() never gets a zero step
    step = max(1, round(len(t) / total_frames))
    print("creating video -- this can take a few minutes")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(title="Movie Test", artist="Matplotlib", comment="Movie support!")
    writer = FFMpegWriter(fps=fps, metadata=metadata)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    ax.set_xlim(0 - margin, 3 + margin)
    ax.set_ylim(-1.5 - margin, 1.5 + margin)
    ax.set_zlim(0, 1)
    ax.view_init(elev=20, azim=-80)
    # Plot axes: z (forward), x (lateral), y (vertical)
    rod_lines_3d = ax.plot(
        positions_over_time[0][2],
        positions_over_time[0][0],
        positions_over_time[0][1],
        linewidth=3,
    )[0]
    with writer.saving(fig, video_name, dpi=100):
        with plt.style.context("seaborn-v0_8-whitegrid"):
            for time in range(1, len(t), step):
                rod_lines_3d.set_xdata(positions_over_time[time][2])
                rod_lines_3d.set_ydata(positions_over_time[time][0])
                rod_lines_3d.set_3d_properties(positions_over_time[time][1])

                writer.grab_frame()
    plt.close(fig)


filename_video = "continuum_snake_3d.mp4"
plot_video(pp_list, video_name=filename_video, margin=0.2, fps=60)

Video("continuum_snake_3d.mp4")