mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: adding simulation logging with wandb
This commit is contained in:
@@ -1,21 +1,16 @@
|
||||
import wandb
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from .wrapper import PHANTOM
|
||||
from .lib import EconomicMetricsWrapper, MetricsCallback
|
||||
|
||||
wandb.init(
|
||||
project="phantom-pricing",
|
||||
config={"alpha": 0.3, "n_products": 10, "total_timesteps": 50000}
|
||||
)
|
||||
|
||||
class RenderCallback(BaseCallback):
|
||||
"""Renders environment on every step for live visualization."""
|
||||
def __init__(self, env: PHANTOM):
|
||||
super().__init__()
|
||||
self.env = env
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
self.env.render()
|
||||
return True
|
||||
|
||||
|
||||
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
|
||||
eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
|
||||
env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||
eval_env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||
|
||||
model = SAC(
|
||||
"MultiInputPolicy",
|
||||
@@ -28,11 +23,12 @@ model = SAC(
|
||||
gamma=0.99,
|
||||
)
|
||||
|
||||
render_cb = RenderCallback(env)
|
||||
metrics_cb = MetricsCallback(log_histograms=True, log_freq=100)
|
||||
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
|
||||
|
||||
model.learn(total_timesteps=50000, callback=[render_cb, eval_cb])
|
||||
model.learn(total_timesteps=50000, callback=[metrics_cb, eval_cb])
|
||||
model.save("phantom_sac")
|
||||
wandb.finish()
|
||||
|
||||
# test trained policy
|
||||
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
|
||||
|
||||
Reference in New Issue
Block a user