Files
PHANTOM/engine/train.py

42 lines
1.2 KiB
Python

import wandb
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback
from .wrapper import PHANTOM
from .lib import EconomicMetricsWrapper, MetricsCallback
# Training entry point for the PHANTOM pricing environment.
# The module uses package-relative imports (``from .wrapper import ...``), so it
# presumably has to be launched as a module (e.g. ``python -m PHANTOM.engine.train``)
# rather than by file path. Every statement below runs at import time.

# Single source of truth for hyperparameters: the W&B run config, both training
# environments, the SAC model and the demo rollout all read from these, so the
# logged config cannot drift from what was actually trained.
ENV_KWARGS = {"n_products": 10, "alpha": 0.3}
TOTAL_TIMESTEPS = 50000
SAC_KWARGS = {
    "learning_rate": 3e-4,
    "buffer_size": 50000,
    "batch_size": 256,
    "tau": 0.005,
    "gamma": 0.99,
}

# Using the run as a context manager finishes it on success and marks it as
# failed (non-zero exit code) if training raises, instead of leaving it open.
with wandb.init(
    project="phantom-pricing",
    config={**ENV_KWARGS, "total_timesteps": TOTAL_TIMESTEPS, **SAC_KWARGS},
):
    env = EconomicMetricsWrapper(PHANTOM(**ENV_KWARGS, render_mode=None))
    eval_env = EconomicMetricsWrapper(PHANTOM(**ENV_KWARGS, render_mode=None))
    try:
        model = SAC("MultiInputPolicy", env, verbose=1, **SAC_KWARGS)
        metrics_cb = MetricsCallback(log_histograms=True, log_freq=100)
        eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
        model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=[metrics_cb, eval_cb])
        model.save("phantom_sac")
    finally:
        # Release environment resources whether or not training succeeded.
        env.close()
        eval_env.close()

# Visual sanity check: roll out the trained policy deterministically for at
# most 100 steps in a human-rendered environment.
# NOTE(review): unlike training, this env is not wrapped in
# EconomicMetricsWrapper; confirm the wrapper does not change the observation
# space, otherwise model.predict() would receive mismatched observations.
demo_env = PHANTOM(**ENV_KWARGS, render_mode="human")
try:
    obs, _ = demo_env.reset()
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = demo_env.step(action)
        # NOTE(review): under the Gymnasium convention, render_mode="human"
        # already renders inside step(); confirm PHANTOM does not draw twice.
        demo_env.render()
        if terminated or truncated:
            break
finally:
    demo_env.close()