chore: adding simulation logging with wandb

This commit is contained in:
2026-01-31 16:21:10 +01:00
parent 33cb0d7e95
commit 4abef97bf7
5 changed files with 81 additions and 31 deletions

View File

@@ -1,21 +1,16 @@
import wandb
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
from stable_baselines3.common.callbacks import EvalCallback
from .wrapper import PHANTOM
from .lib import EconomicMetricsWrapper, MetricsCallback
wandb.init(
project="phantom-pricing",
config={"alpha": 0.3, "n_products": 10, "total_timesteps": 50000}
)
class RenderCallback(BaseCallback):
"""Renders environment on every step for live visualization."""
def __init__(self, env: PHANTOM):
super().__init__()
self.env = env
def _on_step(self) -> bool:
self.env.render()
return True
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
eval_env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
model = SAC(
"MultiInputPolicy",
@@ -28,11 +23,12 @@ model = SAC(
gamma=0.99,
)
render_cb = RenderCallback(env)
metrics_cb = MetricsCallback(log_histograms=True, log_freq=100)
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
model.learn(total_timesteps=50000, callback=[render_cb, eval_cb])
model.learn(total_timesteps=50000, callback=[metrics_cb, eval_cb])
model.save("phantom_sac")
wandb.finish()
# test trained policy
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")