{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd351344",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://gitlab.inria.fr/rgautron/gym_dssat_pdi_baselines.git\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e62124bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "###########################\n",
      "## MODE: irrigation ##\n",
      "###########################\n",
      "Training PPO agent...\n",
      "Eval num_timesteps=1000, episode_reward=-48498.33 +/- 1135.83\n",
      "Episode length: 157.00 +/- 2.72\n",
      "New best mean reward!\n",
      "Eval num_timesteps=2000, episode_reward=-49417.25 +/- 1198.08\n",
      "Episode length: 159.20 +/- 2.93\n",
      "Training done\n"
     ]
    }
   ],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.callbacks import EvalCallback\n",
    "from gym_dssat_pdi_baselines.sb3_wrapper import GymDssatWrapper\n",
    "from gym_dssat_pdi.envs.utils import utils as dssat_utils\n",
    "import gym\n",
    "\n",
    "# Train a PPO agent on the gym-DSSAT environment and periodically evaluate it\n",
    "# on a separately seeded copy of the same environment.\n",
    "if __name__ == '__main__':\n",
    "    # Fully qualified id (module:EnvId) used for both environments, so neither\n",
    "    # gym.make call relies on the other having registered the environment first.\n",
    "    env_id = 'gym_dssat_pdi:GymDssatPdi-v0'\n",
    "\n",
    "    # Pre-declare the environments so the cleanup in `finally` never hits an\n",
    "    # unbound name (and hides the real error) when creation fails part-way.\n",
    "    env = None\n",
    "    eval_env = None\n",
    "    try:\n",
    "        for folder in ['./output', 'logs']:\n",
    "            dssat_utils.make_folder(folder)\n",
    "\n",
    "        # Create environment\n",
    "        env_args = {\n",
    "            # absolute path; adjust to your installation\n",
    "            'run_dssat_location': '/home/jovyan/gym_dssat_pdi/run_dssat',\n",
    "            'log_saving_path': './logs/dssat_pdi.log',\n",
    "            # 'mode': 'fertilization',\n",
    "            'mode': 'irrigation',\n",
    "            'seed': 123,\n",
    "            'random_weather': True,\n",
    "        }\n",
    "\n",
    "        print(f'###########################\\n## MODE: {env_args[\"mode\"]} ##\\n###########################')\n",
    "\n",
    "        env = Monitor(GymDssatWrapper(gym.make(env_id, **env_args)))\n",
    "\n",
    "        # Training arguments for PPO agent\n",
    "        ppo_args = {\n",
    "            'seed': 123,  # seed training for reproducibility\n",
    "            'gamma': 1,  # discount factor of 1, i.e. rewards are not discounted\n",
    "        }\n",
    "\n",
    "        # Create the agent\n",
    "        ppo_agent = PPO('MlpPolicy', env, **ppo_args)\n",
    "\n",
    "        # path to save best model found\n",
    "        path = f'./output/{env_args[\"mode\"]}'\n",
    "\n",
    "        # eval callback: same settings as the training environment except for\n",
    "        # the seed, so evaluation episodes use a different random stream\n",
    "        eval_freq = 1000\n",
    "        eval_env_args = {**env_args, 'seed': 345}\n",
    "        eval_env = Monitor(GymDssatWrapper(gym.make(env_id, **eval_env_args)))\n",
    "        eval_callback = EvalCallback(eval_env,\n",
    "                                     eval_freq=eval_freq,\n",
    "                                     best_model_save_path=path,\n",
    "                                     deterministic=True,\n",
    "                                     n_eval_episodes=10)\n",
    "\n",
    "        # Train\n",
    "        # Very small budget for a quick run; larger budgets are kept below for reference.\n",
    "        # total_timesteps = 500_000\n",
    "        # total_timesteps = 1_000_000\n",
    "        # NOTE: PPO collects full rollouts (n_steps, 2048 by default in SB3), so the\n",
    "        # run overshoots this budget; the saved output accordingly shows evaluations\n",
    "        # at 1000 and 2000 timesteps.\n",
    "        total_timesteps = 500\n",
    "        print('Training PPO agent...')\n",
    "        ppo_agent.learn(total_timesteps=total_timesteps, callback=eval_callback)\n",
    "        # ppo_agent.save(f'{path}/final_model')\n",
    "        print('Training done')\n",
    "    finally:\n",
    "        # Release every environment that was actually created; the evaluation\n",
    "        # environment is closed as well as the training one.\n",
    "        for created_env in (env, eval_env):\n",
    "            if created_env is not None:\n",
    "                created_env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53cedb57",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}