{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd351344",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the baselines repo (the training cell imports gym_dssat_pdi_baselines.sb3_wrapper from it).\n",
    "# Skip the clone when the directory already exists so Restart & Run All does not fail here.\n",
    "![ -d gym_dssat_pdi_baselines ] || git clone https://gitlab.inria.fr/rgautron/gym_dssat_pdi_baselines.git\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e62124bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "###########################\n",
      "## MODE: irrigation ##\n",
      "###########################\n",
      "Training PPO agent...\n",
      "Eval num_timesteps=1000, episode_reward=-48498.33 +/- 1135.83\n",
      "Episode length: 157.00 +/- 2.72\n",
      "New best mean reward!\n",
      "Eval num_timesteps=2000, episode_reward=-49417.25 +/- 1198.08\n",
      "Episode length: 159.20 +/- 2.93\n",
      "Training done\n"
     ]
    }
   ],
   "source": [
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.monitor import Monitor\n",
    "from stable_baselines3.common.callbacks import EvalCallback\n",
    "from gym_dssat_pdi_baselines.sb3_wrapper import GymDssatWrapper\n",
    "from gym_dssat_pdi.envs.utils import utils as dssat_utils\n",
    "import gym\n",
    "\n",
    "# Fully-qualified env id: the 'gym_dssat_pdi:' prefix makes gym import the package\n",
    "# (registering the env) before lookup, so both gym.make calls below resolve on their own.\n",
    "ENV_ID = 'gym_dssat_pdi:GymDssatPdi-v0'\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    # Train a PPO agent on gym-dssat-pdi; EvalCallback saves the best model seen during evaluation.\n",
    "    # Pre-bind both envs so the cleanup in `finally` cannot raise NameError if setup fails early.\n",
    "    env = None\n",
    "    eval_env = None\n",
    "    try:\n",
    "        for folder in ['./output', 'logs']:\n",
    "            dssat_utils.make_folder(folder)\n",
    "\n",
    "        # Create the training environment\n",
    "        env_args = {\n",
    "            'run_dssat_location': '/home/jovyan/gym_dssat_pdi/run_dssat',\n",
    "            'log_saving_path': './logs/dssat_pdi.log',\n",
    "            'mode': 'irrigation',  # alternative: 'fertilization'\n",
    "            'seed': 123,\n",
    "            'random_weather': True,\n",
    "        }\n",
    "\n",
    "        print(f'###########################\\n## MODE: {env_args[\"mode\"]} ##\\n###########################')\n",
    "\n",
    "        env = Monitor(GymDssatWrapper(gym.make(ENV_ID, **env_args)))\n",
    "\n",
    "        # Training arguments for the PPO agent\n",
    "        ppo_args = {\n",
    "            'seed': 123,  # seed training for reproducibility\n",
    "            'gamma': 1,  # no discounting of future rewards\n",
    "        }\n",
    "\n",
    "        # Create the agent\n",
    "        ppo_agent = PPO('MlpPolicy', env, **ppo_args)\n",
    "\n",
    "        # Directory where EvalCallback saves the best model found\n",
    "        path = f'./output/{env_args[\"mode\"]}'\n",
    "\n",
    "        # Evaluation env: same settings as training, but a different seed\n",
    "        eval_freq = 1000\n",
    "        eval_env_args = {**env_args, 'seed': 345}\n",
    "        eval_env = Monitor(GymDssatWrapper(gym.make(ENV_ID, **eval_env_args)))\n",
    "        eval_callback = EvalCallback(eval_env,\n",
    "                                     eval_freq=eval_freq,\n",
    "                                     best_model_save_path=path,\n",
    "                                     deterministic=True,\n",
    "                                     n_eval_episodes=10)\n",
    "\n",
    "        # Train. 500 is a smoke-test budget; use e.g. 500_000 or 1_000_000 for a real run.\n",
    "        # NOTE: PPO collects whole rollouts (n_steps, 2048 by default in SB3),\n",
    "        # so learn() runs past a budget this small.\n",
    "        total_timesteps = 500\n",
    "        print('Training PPO agent...')\n",
    "        ppo_agent.learn(total_timesteps=total_timesteps, callback=eval_callback)\n",
    "        # Optionally also keep the last policy: ppo_agent.save(f'{path}/final_model')\n",
    "        print('Training done')\n",
    "    finally:\n",
    "        # Close every env that was actually created (training and evaluation).\n",
    "        for opened_env in (env, eval_env):\n",
    "            if opened_env is not None:\n",
    "                opened_env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53cedb57",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}