"""Train or evaluate a PPO policy on RoboPianist with optional APR style rewards.

With ``--eval_only`` a pretrained policy is rolled out once with video
recording and musical (F1) plus hand-naturalness metrics are printed.
Otherwise PPO is trained for ``total_iters`` iterations (optionally with an
APR discriminator contributing a style reward), and a final recorded
evaluation is run with the best checkpoint.
"""

import os
import pickle
import random
import sys
import time
from copy import copy
from dataclasses import asdict, dataclass, replace
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import tyro
from mujoco_utils import composer_utils
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv

import lr_scheduler
from apr import APRConfig, APRModule, FEATURE_JOINTS
from apr.apr_utils import _ANGLE_RANGES, _ANGLE_MINS, _ANGLE_MAXS, NUM_ANGLE_FEATURES
from wrappers.apr_wrapper import _extract_hand_features, _build_joint_name_map
from apr_callback import APRCallback
from utils import PROJECT_ROOT, make_envs
from wrappers.apr_wrapper import APRGymWrapper

REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

import torch

try:
    import wandb
except ImportError:
    wandb = None

if torch.cuda.is_available():
    # Initialise CUDA eagerly in the parent process.
    torch.cuda.init()
    _ = torch.zeros(1, device="cuda")


@dataclass(frozen=True)
class Args:
    """Command-line configuration, parsed by ``tyro``.

    Not every field is read by this script: many task settings are hard-coded
    in ``get_env_with_apr``. All fields are still recorded in the W&B run
    config via ``asdict``.
    """

    root_dir: str = "./robopianist_apr/experiments"
    seed: int = 42
    max_steps: int = 1_000_000
    warmstart_steps: int = 5_000
    log_interval: int = 1_000
    eval_interval: int = 10_000
    eval_episodes: int = 1
    batch_size: int = 256
    discount: float = 0.99
    tqdm_bar: bool = False
    replay_capacity: int = 1_000_000
    project: str = "robopianist-apr"
    entity: str = ""
    name: str = ""
    tags: str = ""
    notes: str = ""
    mode: str = "online"
    use_wandb: bool = True
    environment_name: str = "RoboPianist"
    n_steps_lookahead: int = 10
    trim_silence: bool = False
    gravity_compensation: bool = False
    reduced_action_space: bool = False
    control_timestep: float = 0.05
    stretch_factor: float = 1.0
    shift_factor: int = 0
    wrong_press_termination: bool = False
    disable_fingering_reward: bool = False
    disable_forearm_reward: bool = False
    disable_colorization: bool = False
    disable_hand_collisions: bool = True
    primitive_fingertip_collisions: bool = False
    frame_stack: int = 4
    clip: bool = True
    record_dir: Optional[Path] = None
    record_every: int = 1
    record_resolution: Tuple[int, int] = (480, 640)
    camera_id: Optional[str | int] = "piano/back"
    action_reward_observation: bool = False
    deepmimic: bool = False
    mimic_task: str = "GlimpseOfUs"
    midi_start_from: int = 0
    residual_action: bool = True
    use_note_trajectory: bool = True
    mimic_z_axis: bool = False
    rsi: bool = False
    curriculum: bool = False
    eval_only: bool = False
    num_envs: int = 24
    pretrained: Optional[Path] = None
    initial_lr: float = 3e-4
    lr_decay_rate: float = 0.999
    residual_factor: float = 0.03
    n_steps: int = 512
    total_iters: int = 2000
    use_apr: bool = True
    w_G: float = 0.7
    w_S: float = 0.3
    w_gp: float = 10.0
    lr_D: float = 1e-4
    apr_batch_size: int = 256
    apr_buffer_size: int = 10_000_000
    apr_update_freq: int = 1
    apr_n_updates: int = 1
    apr_warmup_iters: int = 0
    apr_style_reward_scale: float = 0.25
    apr_expert_dir: str = ""
    apr_expert_songs: str = "improvised"
    apr_mask_z: bool = False
    apr_mask_vel: bool = False


def get_inner_env(env, attr_name: str):
    """Recursively find an inner environment exposing ``attr_name``.

    Unwraps through ``.env`` (gym-style) or ``._environment`` (dm_env-style)
    until an object with ``attr_name`` is found.

    Args:
        env: The wrapped environment object.
        attr_name: Attribute name to search for.

    Returns:
        The first inner environment that has ``attr_name``; otherwise ``None``.
    """
    current = env
    while current is not None:
        if hasattr(current, attr_name):
            return current
        if hasattr(current, "env"):
            current = current.env
        elif hasattr(current, "_environment"):
            current = current._environment
        else:
            break
    return None


def compute_cpsi(joint_angles: np.ndarray, q_min: np.ndarray, q_max: np.ndarray) -> float:
    """Continuous Posture Strain Index.

    Each angle is mapped linearly so that ``q_min`` -> -1, the mid-range -> 0
    and ``q_max`` -> +1. The index is the mean squared mapped value.

    Args:
        joint_angles: (T, D) raw joint angles in radians.
        q_min: (D,) joint lower limits.
        q_max: (D,) joint upper limits.

    Returns:
        cPSI value; in [0, 1] when all angles lie within their limits.
        Lower = more neutral posture. Returns 0.0 for empty input.
    """
    if joint_angles.shape[0] == 0:
        return 0.0
    span = q_max - q_min
    # Zero-width ranges would divide by zero; treat them as unit span.
    span = np.where(span < 1e-8, 1.0, span)
    deviation = (2.0 * joint_angles - (q_max + q_min)) / span
    return float(np.mean(deviation ** 2))


# Finger prefixes (first, middle, ring, little) used to index FEATURE_JOINTS.
# Within each finger the code treats J3 as MCP, J2 as PIP and J1 as DIP.
_FINGER_PREFIXES = ("FF", "MF", "RF", "LF")

_PIP_DIP_PAIRS = [
    (FEATURE_JOINTS.index(f"{f}J2"), FEATURE_JOINTS.index(f"{f}J1"))
    for f in _FINGER_PREFIXES
]

# Expected DIP/PIP coupling ratio used by compute_bse.
_BSE_K = 2.0 / 3.0

_MCP_PIP_DIP_TRIPLES = [
    (
        FEATURE_JOINTS.index(f"{f}J3"),
        FEATURE_JOINTS.index(f"{f}J2"),
        FEATURE_JOINTS.index(f"{f}J1"),
    )
    for f in _FINGER_PREFIXES
]


def compute_bse(joint_angles: np.ndarray) -> float:
    """Biomechanical Synergy Error.

    Measures deviation from the natural tendon coupling: q_DIP ≈ (2/3) * q_PIP.

    Args:
        joint_angles: (T, D) raw joint angles in radians, columns ordered by
            FEATURE_JOINTS.

    Returns:
        BSE value (mean squared coupling residual per frame and finger).
        Lower = more natural coupling. Returns 0.0 for empty input.
    """
    T = joint_angles.shape[0]
    if T == 0:
        return 0.0
    total = 0.0
    for pip_idx, dip_idx in _PIP_DIP_PAIRS:
        q_pip = joint_angles[:, pip_idx]
        q_dip = joint_angles[:, dip_idx]
        total += np.sum((q_dip - _BSE_K * q_pip) ** 2)
    return float(total / (T * len(_PIP_DIP_PAIRS)))


def compute_fac(joint_angles: np.ndarray, lam: float = 0.1) -> float:
    """Finger Arc Continuity.

    Penalises "Z-shaped" or "S-shaped" finger profiles via the discrete
    Laplacian (second-order difference) across MCP -> PIP -> DIP. It also
    penalises hyperextension (negative angles):

        FAC = (1/(T*F)) Σ_t Σ_f (q_MCP - 2*q_PIP + q_DIP)^2
            + (λ/(T*F)) Σ_t Σ_f Σ_{j∈{MCP,PIP,DIP}} max(-q_j, 0)^2

    Args:
        joint_angles: (T, D) raw joint angles in radians, columns ordered by
            FEATURE_JOINTS.
        lam: penalty weight for negative (hyperextended) joint angles.

    Returns:
        FAC value. Lower = smoother finger arc. Returns 0.0 for empty input.
    """
    T = joint_angles.shape[0]
    if T == 0:
        return 0.0
    F = len(_MCP_PIP_DIP_TRIPLES)
    laplacian_sum = 0.0
    penalty_sum = 0.0
    for mcp_idx, pip_idx, dip_idx in _MCP_PIP_DIP_TRIPLES:
        q_mcp = joint_angles[:, mcp_idx]
        q_pip = joint_angles[:, pip_idx]
        q_dip = joint_angles[:, dip_idx]
        laplacian_sum += np.sum((q_mcp - 2.0 * q_pip + q_dip) ** 2)
        for q in (q_mcp, q_pip, q_dip):
            penalty_sum += np.sum(np.maximum(-q, 0.0) ** 2)
    return float(laplacian_sum / (T * F) + lam * penalty_sum / (T * F))


def _to_numpy(x) -> np.ndarray:
    """Return ``x`` as a NumPy array, detaching torch tensors to host memory first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _compute_gap(expert_angles: np.ndarray, policy_angles: np.ndarray) -> dict:
    """Gap between expert and policy angle arrays (both normalized).

    Per-joint absolute difference of the means. It is reported both in the
    normalized space and in degrees, converting via ``_ANGLE_RANGES``.
    """
    expert_mean = expert_angles.mean(axis=0)
    policy_mean = policy_angles.mean(axis=0)
    gap_norm = np.abs(expert_mean - policy_mean)
    gap_rad = gap_norm * _ANGLE_RANGES
    gap_deg = np.degrees(gap_rad)
    return {
        "mean_gap_norm": float(gap_norm.mean()),
        "mean_gap_deg": float(gap_deg.mean()),
        "per_joint_gap_norm": gap_norm,
        "per_joint_gap_deg": gap_deg,
    }


def compute_angle_gap_apr(apr_module) -> Optional[dict]:
    """Compute per-joint and mean angle gap between expert and policy buffers.

    Only the filled prefix of each buffer (``expert_size`` / ``policy_size``
    rows) is used. Both buffers are converted to NumPy on the host before
    comparison, whether they hold torch tensors (CPU or CUDA) or arrays.

    Returns:
        A single-hand aggregated dict with normalized gap, degree gap, and
        per-joint breakdowns; ``None`` if either buffer is empty.
    """
    buf = apr_module.buffer
    if buf.expert_size == 0 or buf.policy_size == 0:
        return None
    expert_angles = _to_numpy(buf.expert_phi_s[: buf.expert_size, :NUM_ANGLE_FEATURES])
    policy_angles = _to_numpy(buf.policy_phi_s[: buf.policy_size, :NUM_ANGLE_FEATURES])
    return _compute_gap(expert_angles, policy_angles)


def get_env_with_apr(args, record_dir: Optional[Path] = None):
    """Build an APR-enabled training environment.

    NOTE(review): several task settings are hard-coded, and the corresponding
    ``Args`` fields are ignored. These include ``control_timestep``,
    ``n_steps_lookahead`` for the task, and ``gravity_compensation`` (True
    here, while the ``Args`` default is False). Confirm this is intended.

    Args:
        args: Runtime arguments.
        record_dir: Optional output directory for video recording.

    Returns:
        A gym-compatible environment with APR wrappers enabled when requested.
    """
    import dm_env_wrappers as wrappers
    import robopianist.wrappers as robopianist_wrappers
    import wrappers as pianomime_wrappers
    import piano_with_shadow_hands_res
    from robopianist import music

    left_hand_action_list = np.load(
        PROJECT_ROOT / f"dataset/high_level_trajectories/{args.mimic_task}_left_hand_action_list.npy"
    )
    right_hand_action_list = np.load(
        PROJECT_ROOT / f"dataset/high_level_trajectories/{args.mimic_task}_right_hand_action_list.npy"
    )
    length = left_hand_action_list.shape[0]
    # Trim silence only for demonstrations with 500 <= length < 600.
    trim = not (length >= 600 or length < 500)

    task_kwargs = dict(
        change_color_on_activation=True,
        wrong_press_termination=args.wrong_press_termination,
        trim_silence=trim,
        control_timestep=0.05,
        disable_hand_collisions=True,
        disable_forearm_reward=True,
        disable_fingering_reward=False,
        midi_start_from=0,
        n_steps_lookahead=10,
        gravity_compensation=True,
        reduced_action_space=False,
        residual_factor=args.residual_factor,
    )
    if args.use_note_trajectory:
        # pickle is only safe on trusted files; this loads a project dataset file.
        with open(PROJECT_ROOT / f"dataset/notes/{args.mimic_task}.pkl", "rb") as f:
            note_traj = pickle.load(f)
        task = piano_with_shadow_hands_res.PianoWithShadowHandsResidual(
            note_trajectory=note_traj,
            curriculum=args.curriculum,
            **task_kwargs,
        )
    else:
        task = piano_with_shadow_hands_res.PianoWithShadowHandsResidual(
            midi=music.load(args.mimic_task),
            **task_kwargs,
        )

    env = composer_utils.Environment(
        recompile_physics=False, task=task, strip_singleton_obs_buffer_dim=True
    )

    if args.deepmimic:
        env = pianomime_wrappers.DeepMimicWrapper(
            environment=env,
            demonstrations_lh=left_hand_action_list,
            demonstrations_rh=right_hand_action_list,
            remove_goal_observation=False,
            mimic_z_axis=args.mimic_z_axis,
            n_steps_lookahead=args.n_steps_lookahead,
        )
    if args.residual_action:
        env = pianomime_wrappers.ResidualWrapper(
            environment=env,
            demonstrations_lh=left_hand_action_list,
            demonstrations_rh=right_hand_action_list,
            demo_ctrl_timestep=0.05,
            rsi=args.rsi,
        )
    if args.use_apr:
        from wrappers.apr_wrapper import APRWrapper

        env = APRWrapper(environment=env)

    if record_dir is not None:
        env = robopianist_wrappers.PianoSoundVideoWrapper(
            environment=env,
            record_dir=record_dir,
            record_every=args.record_every,
            camera_id=args.camera_id,
            height=args.record_resolution[0],
            width=args.record_resolution[1],
        )
        env = wrappers.EpisodeStatisticsWrapper(environment=env, deque_size=args.record_every)
        env = robopianist_wrappers.MidiEvaluationWrapper(
            environment=env, deque_size=args.record_every
        )
    else:
        env = wrappers.EpisodeStatisticsWrapper(environment=env, deque_size=1)
        env = robopianist_wrappers.MidiEvaluationWrapper(environment=env, deque_size=1)

    if args.action_reward_observation:
        env = wrappers.ObservationActionRewardWrapper(env)
    env = wrappers.ConcatObservationWrapper(env)
    if args.frame_stack > 1:
        env = wrappers.FrameStackingWrapper(env, num_frames=args.frame_stack, flatten=True)
    env = wrappers.CanonicalSpecWrapper(env, clip=args.clip)
    env = wrappers.SinglePrecisionWrapper(env)
    env = wrappers.DmControlWrapper(env)
    env = robopianist_wrappers.Dm2GymWrapper(env)
    if args.use_apr:
        env = APRGymWrapper(env)
    return env


def _find_saved_checkpoint(base: Path) -> Optional[Path]:
    """Return where an SB3 checkpoint saved at ``base`` actually lives, or ``None``.

    ``BaseAlgorithm.save`` appends ``.zip`` when ``base`` has no suffix, so
    both the bare path and its ``.zip`` variant are checked.
    """
    for candidate in (base, base.with_name(base.name + ".zip")):
        if candidate.is_file():
            return candidate
    return None


def _run_eval_only(args: Args) -> None:
    """Roll out a pretrained policy once with video recording and print metrics.

    Raises:
        ValueError: if ``args.pretrained`` is not set.
    """
    if args.pretrained is None:
        raise ValueError("--pretrained is required for --eval_only")
    eval_args = replace(args, rsi=False)
    record_dir = Path("./videos")
    record_dir.mkdir(parents=True, exist_ok=True)
    record_env = get_env_with_apr(eval_args, record_dir=str(record_dir))
    model = PPO.load(str(args.pretrained))

    dm_env = get_inner_env(record_env, "physics")
    lh_map = _build_joint_name_map(dm_env.physics, "lh_") if dm_env is not None else {}
    rh_map = _build_joint_name_map(dm_env.physics, "rh_") if dm_env is not None else {}

    obs, _ = record_env.reset()
    total_reward = 0.0
    steps = 0
    all_hand_angles: list = []
    while True:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = record_env.step(action)
        total_reward += reward
        steps += 1
        if dm_env is not None:
            a_l, _, _ = _extract_hand_features(
                dm_env.physics,
                lh_map,
                dm_env.task.left_hand.fingertip_sites,
                "lh_shadow_hand/wrist_site",
            )
            a_r, _, _ = _extract_hand_features(
                dm_env.physics,
                rh_map,
                dm_env.task.right_hand.fingertip_sites,
                "rh_shadow_hand/wrist_site",
            )
            # Left and right hands are pooled into one sample set.
            all_hand_angles.append(a_l.copy())
            all_hand_angles.append(a_r.copy())
        if terminated or truncated:
            break

    music_env = get_inner_env(record_env, "get_musical_metrics")
    metrics = music_env.get_musical_metrics() if music_env is not None else {}
    f1 = metrics.get("f1", 0.0)
    print(f"\nEval done: {steps} steps, reward={total_reward:.2f}, F1={f1:.4f}")
    print(f"Musical metrics: {metrics}")

    if all_hand_angles:
        # NOTE(review): the metrics below expect raw joint angles in radians.
        # Confirm that _extract_hand_features returns raw angles and not the
        # [0, 1]-normalized values stored in the APR buffers.
        angles = np.array(all_hand_angles)
        cpsi = compute_cpsi(angles, _ANGLE_MINS, _ANGLE_MAXS)
        bse = compute_bse(angles)
        fac = compute_fac(angles)
        print("\n--- Naturalness Metrics ---")
        print(f"cPSI (lower=more neutral): hand_avg={cpsi:.4f}")
        print(f"BSE (lower=more natural): hand_avg={bse:.4f}")
        print(f"FAC (lower=smoother arc): hand_avg={fac:.4f}")

    video_env = get_inner_env(record_env, "latest_filename")
    if video_env is not None:
        print(f"\nVideo saved: {video_env.latest_filename}")
    record_env.close()


def _setup_apr(args: Args, experiment_dir: Path):
    """Create the APR module (with expert data loaded) and its training callback.

    This also applies the optional velocity / fingertip-Z feature masks.
    It sets module-level flags in ``wrappers.apr_wrapper`` and zeroes the
    matching expert feature columns.

    Returns:
        ``(apr_module, apr_callback)``.
    """
    apr_config = APRConfig(
        w_G=args.w_G,
        w_S=args.w_S,
        w_gp=args.w_gp,
        batch_size_disc=args.apr_batch_size,
        replay_buffer_size=args.apr_buffer_size,
        lr_D=args.lr_D,
        gamma=args.discount,
        style_reward_scale=args.apr_style_reward_scale,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    apr_module = APRModule(apr_config)
    expert_dir = args.apr_expert_dir or str(REPO_ROOT / "dataapr")
    song_names = (
        [s.strip() for s in args.apr_expert_songs.split(",") if s.strip()]
        if args.apr_expert_songs
        else None
    )
    apr_module.load_expert_retarget_data(expert_dir, song_names=song_names, verbose=True)

    if args.apr_mask_z or args.apr_mask_vel:
        import wrappers.apr_wrapper as _aw

        buf = apr_module.buffer
        if args.apr_mask_vel:
            _aw.MASK_VELOCITY = True
            if buf.expert_size > 0:
                buf.expert_phi_s[:, 18:36] = 0.0
                buf.expert_phi_s_next[:, 18:36] = 0.0
            print("[APR] Velocity features MASKED (cols 18:36 zeroed)")
        if args.apr_mask_z:
            _aw.MASK_FINGERTIP_Z = True
            if buf.expert_size > 0:
                buf.expert_phi_s[:, 36:41] = 0.0
                buf.expert_phi_s_next[:, 36:41] = 0.0
            print("[APR] Fingertip Z features MASKED (cols 36:41 zeroed)")

    print("\n" + "=" * 60)
    print(" APR Configuration Summary")
    print("=" * 60)
    print(f" Expert transitions : {apr_module.expert_buffer_size} (per-hand, pooled)")
    print(f" Expert songs : {song_names or 'ALL'}")
    print(f" Expert dir : {expert_dir}")
    print(f" Feature dim (hand) : {apr_config.hand_feature_dim}")
    print(f" Disc input dim : {apr_config.disc_input_dim}")
    print(f" Hidden dims : {apr_config.hidden_dims}")
    print(f" w_G (task) : {apr_config.w_G}")
    print(f" w_S (style) : {apr_config.w_S}")
    print(f" w_gp (grad pen) : {apr_config.w_gp}")
    print(f" lr_D : {apr_config.lr_D}")
    print(f" style_reward_scale : {apr_config.style_reward_scale}")
    print(f" batch_size_disc : {apr_config.batch_size_disc}")
    print(f" replay_buffer_size : {apr_config.replay_buffer_size}")
    print(f" gamma : {apr_config.gamma}")
    print(f" device : {apr_config.device}")
    print(f" warmup_iters : {args.apr_warmup_iters}")
    print(f" update_freq : {args.apr_update_freq}")
    print(f" n_disc_updates : {args.apr_n_updates}")
    buf = apr_module.buffer
    if buf.expert_size > 0:
        e_angles = _to_numpy(buf.expert_phi_s[: buf.expert_size, :NUM_ANGLE_FEATURES])
        print("\n Expert angle stats (norm [0,1]):")
        print(f" mean = {e_angles.mean(axis=0).round(3).tolist()}")
        print(f" std = {e_angles.std(axis=0).round(3).tolist()}")
    print("=" * 60 + "\n")

    apr_callback = APRCallback(
        apr_module=apr_module,
        verbose=0,
        update_discriminator_every=args.apr_update_freq,
        n_disc_updates=args.apr_n_updates,
        log_interval=10,
        save_path=str(experiment_dir / "apr_checkpoints"),
        warmup_iters=args.apr_warmup_iters,
        total_iters=args.total_iters,
    )
    return apr_module, apr_callback


def main(args: Args) -> None:
    """Entry point: run eval-only mode, or train PPO (optionally with APR).

    Raises:
        ValueError: on an invalid W&B ``mode`` or a missing ``--pretrained``
            in eval-only mode.
        ImportError: if W&B logging is requested but ``wandb`` is missing.
        FileNotFoundError: if no checkpoint is available for final evaluation.
    """
    if args.name:
        run_name = args.name
    else:
        apr_suffix = "-APR" if args.use_apr else ""
        run_name = f"PPO{apr_suffix}-{args.mimic_task}-{args.seed}-{time.time()}"
    experiment_dir = Path(args.root_dir) / run_name
    experiment_dir.mkdir(parents=True, exist_ok=True)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.eval_only:
        _run_eval_only(args)
        return

    if args.mode == "offline":
        raise ValueError("W&B offline mode has been removed. Use mode='online' or mode='disabled'.")
    if args.mode not in {"online", "disabled"}:
        raise ValueError(f"Invalid W&B mode: {args.mode}. Use 'online' or 'disabled'.")
    use_wandb = args.use_wandb and args.mode == "online"
    if use_wandb:
        if wandb is None:
            raise ImportError(
                "use_wandb=True but the wandb package is not installed; "
                "install wandb or set use_wandb=False."
            )
        # Entity/project come from the environment, not from args.entity/args.project.
        wandb.init(
            entity=os.environ.get("WANDB_ENTITY"),
            project=os.environ.get("WANDB_PROJECT", "robopianist-apr"),
            config=asdict(args),
            name=run_name,
            mode="online",
        )
        print(f"W&B run: {wandb.run.url}")

    apr_module = None
    apr_callback = None
    if args.use_apr:
        # Must run before SubprocVecEnv(start_method="fork") below. The feature
        # mask flags are module globals, and workers only inherit state that
        # exists at fork time.
        apr_module, apr_callback = _setup_apr(args, experiment_dir)

    eval_args = replace(args, rsi=False)
    eval_env = get_env_with_apr(eval_args, record_dir=None)

    def make_env():
        env = get_env_with_apr(args)
        return Monitor(env)

    vec_env = SubprocVecEnv(
        [make_envs(make_env, i) for i in range(args.num_envs)], start_method="fork"
    )

    lr_scheduler_instance = lr_scheduler.LR_Scheduler(
        initial_lr=args.initial_lr,
        decay_rate=args.lr_decay_rate,
    )
    policy_kwargs = dict(
        activation_fn=torch.nn.GELU, net_arch=dict(pi=[1024, 256], vf=[1024, 256])
    )
    model = PPO(
        "MlpPolicy",
        vec_env,
        n_epochs=10,
        n_steps=args.n_steps,
        batch_size=1024,
        learning_rate=lr_scheduler_instance.lr_schedule,
        policy_kwargs=policy_kwargs,
        verbose=0,
        gamma=args.discount,
        clip_range=0.2,
    )
    if args.pretrained is not None:
        custom_objects = {"learning_rate": lr_scheduler_instance.lr_schedule}
        model = PPO.load(args.pretrained, env=vec_env, custom_objects=custom_objects)

    best_ckpt = Path(f"./robopianist_apr/ckpts/{run_name}_best")
    best_f1 = -np.inf
    callbacks = [apr_callback] if apr_callback is not None else []
    callback_list = CallbackList(callbacks) if callbacks else None

    print(f"Training: {args.total_iters} iters, APR={args.use_apr}, Song={args.mimic_task}")
    try:
        for i in range(args.total_iters):
            # One PPO rollout/update cycle per iteration.
            model.learn(
                total_timesteps=args.n_steps * args.num_envs,
                progress_bar=True,
                reset_num_timesteps=False,
                callback=callback_list,
            )
            iter_log: dict = {}
            disc: dict = {}
            if args.use_apr and apr_module is not None and apr_callback is not None:
                disc = apr_callback.last_disc_loss_dict or {}
                style_r = (
                    apr_callback._style_rewards_history[-1]
                    if apr_callback._style_rewards_history
                    else 0.0
                )
                w_G_cur = apr_callback.current_w_G
                w_S_cur = apr_callback.current_w_S
                angle_info = compute_angle_gap_apr(apr_module)
                gap_deg = angle_info["mean_gap_deg"] if angle_info else 0.0
                gap_norm = angle_info["mean_gap_norm"] if angle_info else 0.0
                print(
                    f"[iter {i}/{args.total_iters}] "
                    f"style_r={style_r:.4f} "
                    f"D_loss={disc.get('loss_D', 0):.4f} "
                    f"eAcc={disc.get('expert_acc', 0):.1%} "
                    f"pAcc={disc.get('policy_acc', 0):.1%} "
                    f"w_G={w_G_cur:.3f} w_S={w_S_cur:.3f} "
                    f"gap_hand_avg={gap_deg:.2f}\u00b0"
                )
                iter_log.update({
                    "apr/style_reward": style_r,
                    "apr/w_G": w_G_cur,
                    "apr/w_S": w_S_cur,
                    "apr/angle_gap_deg": gap_deg,
                    "apr/angle_gap_norm": gap_norm,
                })
                if disc:
                    iter_log.update({
                        "apr/disc_loss": disc.get("loss_D", 0),
                        "apr/d_expert": disc.get("d_expert_mean", 0),
                        "apr/d_policy": disc.get("d_policy_mean", 0),
                        "apr/expert_acc": disc.get("expert_acc", 0),
                        "apr/policy_acc": disc.get("policy_acc", 0),
                        "apr/loss_gp": disc.get("loss_gp", 0),
                    })
                if angle_info:
                    for j, jname in enumerate(FEATURE_JOINTS):
                        iter_log[f"apr/joint_gap_deg/{jname}"] = float(
                            angle_info["per_joint_gap_deg"][j]
                        )

            if i % 50 == 0:
                print(f"[iter {i}] Evaluating...", end=" ", flush=True)
                obs, _ = eval_env.reset()
                eval_steps = 0
                while True:
                    action, _state = model.predict(obs, deterministic=True)
                    obs, reward, terminated, truncated, info = eval_env.step(action)
                    eval_steps += 1
                    if terminated or truncated:
                        break
                print(f"done ({eval_steps} steps)")
                music_env = get_inner_env(eval_env, "get_musical_metrics")
                f1 = 0.0
                if music_env is not None:
                    f1 = music_env.get_musical_metrics().get("f1", 0.0)
                iter_log["eval/f1"] = float(f1)
                print(
                    f"[iter {i}/{args.total_iters}] F1: {f1:.4f} "
                    f"| Expert Acc: {disc.get('expert_acc', 0):.2%} "
                    f"| Policy Acc: {disc.get('policy_acc', 0):.2%}"
                )
                if f1 > best_f1:
                    print(f" -> New best F1: {best_f1:.4f} -> {f1:.4f}")
                    best_f1 = f1
                    # SB3 may append ".zip"; see _find_saved_checkpoint.
                    model.save(f"./robopianist_apr/ckpts/{run_name}_best")
                    if args.use_apr and apr_module is not None:
                        apr_module.save(experiment_dir / "apr_checkpoints" / "apr_best.pt")

            if use_wandb and iter_log:
                wandb.log(iter_log, step=i)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")

    print("=" * 60)
    print("Final Evaluation (with video recording)")
    print("=" * 60)
    # Resolve the checkpoint before building the recording env so a missing
    # checkpoint does not leave an unclosed environment behind.
    saved_best = _find_saved_checkpoint(best_ckpt)
    final_ckpt = saved_best if saved_best is not None else args.pretrained
    if final_ckpt is None:
        raise FileNotFoundError(
            f"No best checkpoint found at {best_ckpt} (or {best_ckpt}.zip) "
            "and no --pretrained checkpoint was given."
        )
    final_eval_env = get_env_with_apr(eval_args, record_dir=experiment_dir / "final_eval")
    print(f"Loading final eval model from: {final_ckpt}")
    model = PPO.load(str(final_ckpt), env=vec_env)

    obs, _ = final_eval_env.reset()
    actions = []
    rewards = 0
    while True:
        action, _states = model.predict(obs, deterministic=True)
        actions.append(action)
        obs, reward, terminated, truncated, info = final_eval_env.step(action)
        rewards += reward
        if terminated or truncated:
            break
    print(f"Total reward: {rewards}")

    video_env = get_inner_env(final_eval_env, "latest_filename")
    if video_env is not None and hasattr(video_env, "latest_filename"):
        print(f"Video: {video_env.latest_filename}")
    music_env = get_inner_env(final_eval_env, "get_musical_metrics")
    if music_env is not None:
        print(f"Musical metrics: {music_env.get_musical_metrics()}")
    final_eval_env.close()

    actions = np.array(actions)
    actions_dir = Path("./trained_songs") / args.mimic_task
    actions_dir.mkdir(parents=True, exist_ok=True)
    np.save(str(actions_dir / f"actions_{args.mimic_task}_apr"), actions)

    if args.use_apr and apr_module is not None:
        apr_module.save(experiment_dir / "apr_final.pt")

    del model
    vec_env.close()
    eval_env.close()
    if use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main(tyro.cli(Args, description=__doc__))