mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
fix: COI better defined and aligned; SAC hyperparameters improved
This commit is contained in:
@@ -28,7 +28,6 @@ except ImportError:
|
||||
HAS_TB = False
|
||||
|
||||
from .simplified_env import PricingEnv, EnvConfig, make_env, adaptive_policy, fixed_price_policy, random_policy
|
||||
from .coi import coi_erosion
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -214,7 +213,7 @@ def train(cfg: ExperimentConfig) -> Dict[str, Any]:
|
||||
common = dict(verbose=1, seed=cfg.seed, tensorboard_log=str(log_path), device="auto")
|
||||
model = {
|
||||
"ppo": lambda: PPO("MlpPolicy", train_env, learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, ent_coef=0.01, **common),
|
||||
"sac": lambda: SAC("MlpPolicy", train_env, learning_rate=3e-4, buffer_size=100_000, batch_size=256, tau=0.005, gamma=0.99, **common),
|
||||
"sac": lambda: SAC("MlpPolicy", train_env, learning_rate=1e-4, buffer_size=50_000, batch_size=512, tau=0.02, gamma=0.99, learning_starts=1000, ent_coef="auto_0.1", train_freq=4, **common),
|
||||
"a2c": lambda: A2C("MlpPolicy", train_env, learning_rate=7e-4, n_steps=5, gamma=0.99, **common),
|
||||
}[cfg.algo.lower()]()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user