Path: blob/master/finrl/agents/stablebaselines3/hyperparams_opt.py
732 views
from __future__ import annotations12from typing import Any3from typing import Dict45import numpy as np6import optuna7from stable_baselines3.common.noise import NormalActionNoise8from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise9from torch import nn as nn10from utils import linear_schedule111213def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:14"""15Sampler for PPO hyperparams.1617:param trial:18:return:19"""20batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128, 256, 512])21n_steps = trial.suggest_categorical(22"n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]23)24gamma = trial.suggest_categorical(25"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]26)27learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)28lr_schedule = "constant"29# Uncomment to enable learning rate schedule30# lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant'])31ent_coef = trial.suggest_loguniform("ent_coef", 0.00000001, 0.1)32clip_range = trial.suggest_categorical("clip_range", [0.1, 0.2, 0.3, 0.4])33n_epochs = trial.suggest_categorical("n_epochs", [1, 5, 10, 20])34gae_lambda = trial.suggest_categorical(35"gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]36)37max_grad_norm = trial.suggest_categorical(38"max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5]39)40vf_coef = trial.suggest_uniform("vf_coef", 0, 1)41net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])42# Uncomment for gSDE (continuous actions)43# log_std_init = trial.suggest_uniform("log_std_init", -4, 1)44# Uncomment for gSDE (continuous action)45# sde_sample_freq = trial.suggest_categorical("sde_sample_freq", [-1, 8, 16, 32, 64, 128, 256])46# Orthogonal initialization47ortho_init = False48# ortho_init = trial.suggest_categorical('ortho_init', [False, True])49# activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu'])50activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])5152# TODO: account when using multiple envs53if batch_size > n_steps:54batch_size = n_steps5556if lr_schedule == "linear":57learning_rate = linear_schedule(learning_rate)5859# Independent networks usually work best60# when not working with images61net_arch = {62"small": [dict(pi=[64, 64], vf=[64, 64])],63"medium": [dict(pi=[256, 256], vf=[256, 256])],64}[net_arch]6566activation_fn = {67"tanh": nn.Tanh,68"relu": nn.ReLU,69"elu": nn.ELU,70"leaky_relu": nn.LeakyReLU,71}[activation_fn]7273return {74"n_steps": n_steps,75"batch_size": batch_size,76"gamma": gamma,77"learning_rate": learning_rate,78"ent_coef": ent_coef,79"clip_range": clip_range,80"n_epochs": n_epochs,81"gae_lambda": gae_lambda,82"max_grad_norm": max_grad_norm,83"vf_coef": vf_coef,84# "sde_sample_freq": sde_sample_freq,85"policy_kwargs": dict(86# log_std_init=log_std_init,87net_arch=net_arch,88activation_fn=activation_fn,89ortho_init=ortho_init,90),91}929394def sample_trpo_params(trial: optuna.Trial) -> dict[str, Any]:95"""96Sampler for TRPO hyperparams.9798:param trial:99:return:100"""101batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64, 128, 256, 512])102n_steps = trial.suggest_categorical(103"n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]104)105gamma = trial.suggest_categorical(106"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]107)108learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)109lr_schedule = "constant"110# Uncomment to enable learning rate schedule111# lr_schedule = trial.suggest_categorical('lr_schedule', ['linear', 'constant'])112# line_search_shrinking_factor = trial.suggest_categorical("line_search_shrinking_factor", [0.6, 0.7, 0.8, 0.9])113n_critic_updates = trial.suggest_categorical(114"n_critic_updates", [5, 10, 20, 25, 30]115)116cg_max_steps = trial.suggest_categorical("cg_max_steps", [5, 10, 20, 25, 30])117# cg_damping = trial.suggest_categorical("cg_damping", [0.5, 0.2, 0.1, 0.05, 0.01])118target_kl = trial.suggest_categorical(119"target_kl", [0.1, 0.05, 0.03, 0.02, 0.01, 0.005, 0.001]120)121gae_lambda = trial.suggest_categorical(122"gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]123)124net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])125# Uncomment for gSDE (continuous actions)126# log_std_init = trial.suggest_uniform("log_std_init", -4, 1)127# Uncomment for gSDE (continuous action)128# sde_sample_freq = trial.suggest_categorical("sde_sample_freq", [-1, 8, 16, 32, 64, 128, 256])129# Orthogonal initialization130ortho_init = False131# ortho_init = trial.suggest_categorical('ortho_init', [False, True])132# activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu'])133activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])134135# TODO: account when using multiple envs136if batch_size > n_steps:137batch_size = n_steps138139if lr_schedule == "linear":140learning_rate = linear_schedule(learning_rate)141142# Independent networks usually work best143# when not working with images144net_arch = {145"small": [dict(pi=[64, 64], vf=[64, 64])],146"medium": [dict(pi=[256, 256], vf=[256, 256])],147}[net_arch]148149activation_fn = {150"tanh": nn.Tanh,151"relu": nn.ReLU,152"elu": nn.ELU,153"leaky_relu": nn.LeakyReLU,154}[activation_fn]155156return {157"n_steps": n_steps,158"batch_size": batch_size,159"gamma": gamma,160# "cg_damping": cg_damping,161"cg_max_steps": cg_max_steps,162# "line_search_shrinking_factor": line_search_shrinking_factor,163"n_critic_updates": n_critic_updates,164"target_kl": target_kl,165"learning_rate": learning_rate,166"gae_lambda": gae_lambda,167# "sde_sample_freq": sde_sample_freq,168"policy_kwargs": dict(169# log_std_init=log_std_init,170net_arch=net_arch,171activation_fn=activation_fn,172ortho_init=ortho_init,173),174}175176177def sample_a2c_params(trial: optuna.Trial) -> dict[str, Any]:178"""179Sampler for A2C hyperparams.180181:param trial:182:return:183"""184gamma = trial.suggest_categorical(185"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]186)187normalize_advantage = trial.suggest_categorical(188"normalize_advantage", [False, True]189)190max_grad_norm = trial.suggest_categorical(191"max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5]192)193# Toggle PyTorch RMS Prop (different from TF one, cf doc)194use_rms_prop = trial.suggest_categorical("use_rms_prop", [False, True])195gae_lambda = trial.suggest_categorical(196"gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]197)198n_steps = trial.suggest_categorical(199"n_steps", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]200)201lr_schedule = trial.suggest_categorical("lr_schedule", ["linear", "constant"])202learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)203ent_coef = trial.suggest_loguniform("ent_coef", 0.00000001, 0.1)204vf_coef = trial.suggest_uniform("vf_coef", 0, 1)205# Uncomment for gSDE (continuous actions)206# log_std_init = trial.suggest_uniform("log_std_init", -4, 1)207ortho_init = trial.suggest_categorical("ortho_init", [False, True])208net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])209# sde_net_arch = trial.suggest_categorical("sde_net_arch", [None, "tiny", "small"])210# full_std = trial.suggest_categorical("full_std", [False, True])211# activation_fn = trial.suggest_categorical('activation_fn', ['tanh', 'relu', 'elu', 'leaky_relu'])212activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])213214if lr_schedule == "linear":215learning_rate = linear_schedule(learning_rate)216217net_arch = {218"small": [dict(pi=[64, 64], vf=[64, 64])],219"medium": [dict(pi=[256, 256], vf=[256, 256])],220}[net_arch]221222# sde_net_arch = {223# None: None,224# "tiny": [64],225# "small": [64, 64],226# }[sde_net_arch]227228activation_fn = {229"tanh": nn.Tanh,230"relu": nn.ReLU,231"elu": nn.ELU,232"leaky_relu": nn.LeakyReLU,233}[activation_fn]234235return {236"n_steps": n_steps,237"gamma": gamma,238"gae_lambda": gae_lambda,239"learning_rate": learning_rate,240"ent_coef": ent_coef,241"normalize_advantage": normalize_advantage,242"max_grad_norm": max_grad_norm,243"use_rms_prop": use_rms_prop,244"vf_coef": vf_coef,245"policy_kwargs": dict(246# log_std_init=log_std_init,247net_arch=net_arch,248# full_std=full_std,249activation_fn=activation_fn,250# sde_net_arch=sde_net_arch,251ortho_init=ortho_init,252),253}254255256def sample_sac_params(trial: optuna.Trial) -> dict[str, Any]:257"""258Sampler for SAC hyperparams.259260:param trial:261:return:262"""263gamma = trial.suggest_categorical(264"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]265)266learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)267batch_size = trial.suggest_categorical(268"batch_size", [16, 32, 64, 128, 256, 512, 1024, 2048]269)270buffer_size = trial.suggest_categorical(271"buffer_size", [int(1e4), int(1e5), int(1e6)]272)273learning_starts = trial.suggest_categorical(274"learning_starts", [0, 1000, 10000, 20000]275)276# train_freq = trial.suggest_categorical('train_freq', [1, 10, 100, 300])277train_freq = trial.suggest_categorical(278"train_freq", [1, 4, 8, 16, 32, 64, 128, 256, 512]279)280# Polyak coeff281tau = trial.suggest_categorical("tau", [0.001, 0.005, 0.01, 0.02, 0.05, 0.08])282# gradient_steps takes too much time283# gradient_steps = trial.suggest_categorical('gradient_steps', [1, 100, 300])284gradient_steps = train_freq285# ent_coef = trial.suggest_categorical('ent_coef', ['auto', 0.5, 0.1, 0.05, 0.01, 0.0001])286ent_coef = "auto"287# You can comment that out when not using gSDE288log_std_init = trial.suggest_uniform("log_std_init", -4, 1)289# NOTE: Add "verybig" to net_arch when tuning HER290net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"])291# activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU])292293net_arch = {294"small": [64, 64],295"medium": [256, 256],296"big": [400, 300],297# Uncomment for tuning HER298# "large": [256, 256, 256],299# "verybig": [512, 512, 512],300}[net_arch]301302target_entropy = "auto"303# if ent_coef == 'auto':304# # target_entropy = trial.suggest_categorical('target_entropy', ['auto', 5, 1, 0, -1, -5, -10, -20, -50])305# target_entropy = trial.suggest_uniform('target_entropy', -10, 10)306307hyperparams = {308"gamma": gamma,309"learning_rate": learning_rate,310"batch_size": batch_size,311"buffer_size": buffer_size,312"learning_starts": learning_starts,313"train_freq": train_freq,314"gradient_steps": gradient_steps,315"ent_coef": ent_coef,316"tau": tau,317"target_entropy": target_entropy,318"policy_kwargs": dict(log_std_init=log_std_init, net_arch=net_arch),319}320321if trial.using_her_replay_buffer:322hyperparams = sample_her_params(trial, hyperparams)323324return hyperparams325326327def sample_td3_params(trial: optuna.Trial) -> dict[str, Any]:328"""329Sampler for TD3 hyperparams.330331:param trial:332:return:333"""334gamma = trial.suggest_categorical(335"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]336)337learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)338batch_size = trial.suggest_categorical(339"batch_size", [16, 32, 64, 100, 128, 256, 512, 1024, 2048]340)341buffer_size = trial.suggest_categorical(342"buffer_size", [int(1e4), int(1e5), int(1e6)]343)344# Polyak coeff345tau = trial.suggest_categorical("tau", [0.001, 0.005, 0.01, 0.02, 0.05, 0.08])346347train_freq = trial.suggest_categorical(348"train_freq", [1, 4, 8, 16, 32, 64, 128, 256, 512]349)350gradient_steps = train_freq351352noise_type = trial.suggest_categorical(353"noise_type", ["ornstein-uhlenbeck", "normal", None]354)355noise_std = trial.suggest_uniform("noise_std", 0, 1)356357# NOTE: Add "verybig" to net_arch when tuning HER358net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"])359# activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU])360361net_arch = {362"small": [64, 64],363"medium": [256, 256],364"big": [400, 300],365# Uncomment for tuning HER366# "verybig": [256, 256, 256],367}[net_arch]368369hyperparams = {370"gamma": gamma,371"learning_rate": learning_rate,372"batch_size": batch_size,373"buffer_size": buffer_size,374"train_freq": train_freq,375"gradient_steps": gradient_steps,376"policy_kwargs": dict(net_arch=net_arch),377"tau": tau,378}379380if noise_type == "normal":381hyperparams["action_noise"] = NormalActionNoise(382mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions)383)384elif noise_type == "ornstein-uhlenbeck":385hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise(386mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions)387)388389if trial.using_her_replay_buffer:390hyperparams = sample_her_params(trial, hyperparams)391392return hyperparams393394395def sample_ddpg_params(trial: optuna.Trial) -> dict[str, Any]:396"""397Sampler for DDPG hyperparams.398399:param trial:400:return:401"""402gamma = trial.suggest_categorical(403"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]404)405learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)406batch_size = trial.suggest_categorical(407"batch_size", [16, 32, 64, 100, 128, 256, 512, 1024, 2048]408)409buffer_size = trial.suggest_categorical(410"buffer_size", [int(1e4), int(1e5), int(1e6)]411)412# Polyak coeff413tau = trial.suggest_categorical("tau", [0.001, 0.005, 0.01, 0.02, 0.05, 0.08])414415train_freq = trial.suggest_categorical(416"train_freq", [1, 4, 8, 16, 32, 64, 128, 256, 512]417)418gradient_steps = train_freq419420noise_type = trial.suggest_categorical(421"noise_type", ["ornstein-uhlenbeck", "normal", None]422)423noise_std = trial.suggest_uniform("noise_std", 0, 1)424425# NOTE: Add "verybig" to net_arch when tuning HER (see TD3)426net_arch = trial.suggest_categorical("net_arch", ["small", "medium", "big"])427# activation_fn = trial.suggest_categorical('activation_fn', [nn.Tanh, nn.ReLU, nn.ELU, nn.LeakyReLU])428429net_arch = {"small": [64, 64], "medium": [256, 256], "big": [400, 300]}[net_arch]430431hyperparams = {432"gamma": gamma,433"tau": tau,434"learning_rate": learning_rate,435"batch_size": batch_size,436"buffer_size": buffer_size,437"train_freq": train_freq,438"gradient_steps": gradient_steps,439"policy_kwargs": dict(net_arch=net_arch),440}441442if noise_type == "normal":443hyperparams["action_noise"] = NormalActionNoise(444mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions)445)446elif noise_type == "ornstein-uhlenbeck":447hyperparams["action_noise"] = OrnsteinUhlenbeckActionNoise(448mean=np.zeros(trial.n_actions), sigma=noise_std * np.ones(trial.n_actions)449)450451if trial.using_her_replay_buffer:452hyperparams = sample_her_params(trial, hyperparams)453454return hyperparams455456457def sample_dqn_params(trial: optuna.Trial) -> dict[str, Any]:458"""459Sampler for DQN hyperparams.460461:param trial:462:return:463"""464gamma = trial.suggest_categorical(465"gamma", [0.9, 0.95, 0.98, 0.99, 0.995, 0.999, 0.9999]466)467learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)468batch_size = trial.suggest_categorical(469"batch_size", [16, 32, 64, 100, 128, 256, 512]470)471buffer_size = trial.suggest_categorical(472"buffer_size", [int(1e4), int(5e4), int(1e5), int(1e6)]473)474exploration_final_eps = trial.suggest_uniform("exploration_final_eps", 0, 0.2)475exploration_fraction = trial.suggest_uniform("exploration_fraction", 0, 0.5)476target_update_interval = trial.suggest_categorical(477"target_update_interval", [1, 1000, 5000, 10000, 15000, 20000]478)479learning_starts = trial.suggest_categorical(480"learning_starts", [0, 1000, 5000, 10000, 20000]481)482483train_freq = trial.suggest_categorical("train_freq", [1, 4, 8, 16, 128, 256, 1000])484subsample_steps = trial.suggest_categorical("subsample_steps", [1, 2, 4, 8])485gradient_steps = max(train_freq // subsample_steps, 1)486487net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])488489net_arch = {"tiny": [64], "small": [64, 64], "medium": [256, 256]}[net_arch]490491hyperparams = {492"gamma": gamma,493"learning_rate": learning_rate,494"batch_size": batch_size,495"buffer_size": buffer_size,496"train_freq": train_freq,497"gradient_steps": gradient_steps,498"exploration_fraction": exploration_fraction,499"exploration_final_eps": exploration_final_eps,500"target_update_interval": target_update_interval,501"learning_starts": learning_starts,502"policy_kwargs": dict(net_arch=net_arch),503}504505if trial.using_her_replay_buffer:506hyperparams = sample_her_params(trial, hyperparams)507508return hyperparams509510511def sample_her_params(512trial: optuna.Trial, hyperparams: dict[str, Any]513) -> dict[str, Any]:514"""515Sampler for HerReplayBuffer hyperparams.516517:param trial:518:parma hyperparams:519:return:520"""521her_kwargs = trial.her_kwargs.copy()522her_kwargs["n_sampled_goal"] = trial.suggest_int("n_sampled_goal", 1, 5)523her_kwargs["goal_selection_strategy"] = trial.suggest_categorical(524"goal_selection_strategy", ["final", "episode", "future"]525)526her_kwargs["online_sampling"] = trial.suggest_categorical(527"online_sampling", [True, False]528)529hyperparams["replay_buffer_kwargs"] = her_kwargs530return hyperparams531532533def sample_tqc_params(trial: optuna.Trial) -> dict[str, Any]:534"""535Sampler for TQC hyperparams.536537:param trial:538:return:539"""540# TQC is SAC + Distributional RL541hyperparams = sample_sac_params(trial)542543n_quantiles = trial.suggest_int("n_quantiles", 5, 50)544top_quantiles_to_drop_per_net = trial.suggest_int(545"top_quantiles_to_drop_per_net", 0, n_quantiles - 1546)547548hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})549hyperparams["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net550551return hyperparams552553554def sample_qrdqn_params(trial: optuna.Trial) -> dict[str, Any]:555"""556Sampler for QR-DQN hyperparams.557558:param trial:559:return:560"""561# TQC is DQN + Distributional RL562hyperparams = sample_dqn_params(trial)563564n_quantiles = trial.suggest_int("n_quantiles", 5, 200)565hyperparams["policy_kwargs"].update({"n_quantiles": n_quantiles})566567return hyperparams568569570def sample_ars_params(trial: optuna.Trial) -> dict[str, Any]:571"""572Sampler for ARS hyperparams.573:param trial:574:return:575"""576# n_eval_episodes = trial.suggest_categorical("n_eval_episodes", [1, 2])577n_delta = trial.suggest_categorical("n_delta", [4, 8, 6, 32, 64])578# learning_rate = trial.suggest_categorical("learning_rate", [0.01, 0.02, 0.025, 0.03])579learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1)580delta_std = trial.suggest_categorical(581"delta_std", [0.01, 0.02, 0.025, 0.03, 0.05, 0.1, 0.2, 0.3]582)583top_frac_size = trial.suggest_categorical(584"top_frac_size", [0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]585)586zero_policy = trial.suggest_categorical("zero_policy", [True, False])587n_top = max(int(top_frac_size * n_delta), 1)588589# net_arch = trial.suggest_categorical("net_arch", ["linear", "tiny", "small"])590591# Note: remove bias to be as the original linear policy592# and do not squash output593# Comment out when doing hyperparams search with linear policy only594# net_arch = {595# "linear": [],596# "tiny": [16],597# "small": [32],598# }[net_arch]599600# TODO: optimize the alive_bonus_offset too601602return {603# "n_eval_episodes": n_eval_episodes,604"n_delta": n_delta,605"learning_rate": learning_rate,606"delta_std": delta_std,607"n_top": n_top,608"zero_policy": zero_policy,609# "policy_kwargs": dict(net_arch=net_arch),610}611612613HYPERPARAMS_SAMPLER = {614"a2c": sample_a2c_params,615"ars": sample_ars_params,616"ddpg": sample_ddpg_params,617"dqn": sample_dqn_params,618"qrdqn": sample_qrdqn_params,619"sac": sample_sac_params,620"tqc": sample_tqc_params,621"ppo": sample_ppo_params,622"td3": sample_td3_params,623"trpo": sample_trpo_params,624}625626627