mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: adding simulation logging with wandb
This commit is contained in:
@@ -4,6 +4,7 @@ import numpy as np
|
||||
from .engine import Limbo, MarketEngine, PricingEngine
|
||||
from .lib.render import DashboardRenderer
|
||||
from .lib.coi import compute_coi_proxy
|
||||
from .lib.wrappers import EconomicMetricsWrapper
|
||||
|
||||
|
||||
class PHANTOM(gym.Env):
|
||||
@@ -134,11 +135,43 @@ class PHANTOM(gym.Env):
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: drive the environment with uniformly sampled actions and
    # render after every step, for at most 100 steps.
    env = PHANTOM(n_products=15, alpha=0.3, N=100, render_mode="human")
    obs, _ = env.reset()
    for step in range(100):
        action = env.action_space.sample()
        obs, reward, term, trunc, info = env.step(action)
        env.render()
        # Stop on either end-of-episode signal. Stepping an environment after
        # it reported truncation (without a reset) is outside the Gymnasium
        # step contract, so `trunc` must end the loop just like `term`.
        if term or trunc:
            break
|
||||
import wandb
|
||||
from .lib import MetricsCallback
|
||||
|
||||
class RandomPolicy:
    """Minimal SB3-compatible random policy for baseline testing.

    Exposes the small part of the model interface this script relies on
    (``learn``, ``predict`` and ``num_timesteps``) while choosing every
    action with ``env.action_space.sample()``.
    """

    def __init__(self, env):
        """Store the environment to act in and reset the step counter."""
        self.env = env
        self.num_timesteps = 0

    def learn(self, total_timesteps, callback=None):
        """Step the environment with random actions for ``total_timesteps``.

        Args:
            total_timesteps: Number of environment steps to take.
            callback: Optional callback object. When given, it receives
                ``model``, ``num_timesteps`` and ``locals`` attributes and its
                ``on_training_start``, ``on_step`` and ``on_rollout_end``
                hooks are invoked. When ``None`` the rollout runs without any
                callback bookkeeping.

        Returns:
            This policy instance, so calls can be chained.
        """
        has_callback = callback is not None

        if has_callback:
            callback.model = self
            callback.num_timesteps = 0
            callback.locals = {}
            callback.on_training_start({}, {})

        self.env.reset()
        for step in range(total_timesteps):
            action = self.env.action_space.sample()
            _, _, term, trunc, info = self.env.step(action)
            self.num_timesteps = step + 1

            if has_callback:
                callback.num_timesteps = self.num_timesteps
                # Only the step's info dict is exposed, wrapped in a
                # one-element list under the "infos" key.
                callback.locals = {"infos": [info]}
                callback.on_step()

            # An episode boundary is either signal; start a fresh episode so
            # the remaining step budget can still be consumed.
            if term or trunc:
                if has_callback:
                    callback.on_rollout_end()
                self.env.reset()
        return self

    def predict(self, obs, **kwargs):
        """Return a randomly sampled action and ``None`` as the state.

        ``obs`` and any keyword arguments are accepted for interface
        compatibility and ignored.
        """
        return self.env.action_space.sample(), None
|
||||
|
||||
# Baseline run: log a random policy's metrics to Weights & Biases.
run_config = {"policy": "random", "alpha": 0.3}
wandb.init(project="phantom-pricing", config=run_config)

# Build the raw environment first, then wrap it for metric tracking.
base_env = PHANTOM(n_products=15, alpha=0.3, render_mode=None)
env = EconomicMetricsWrapper(base_env)

model = RandomPolicy(env)
model.learn(total_timesteps=1000, callback=MetricsCallback())

# Report the revenue counter exposed by the wrapper, then shut down the
# logging run and the environment in that order.
print(f"Episode revenue: {env.episode_revenue:.1f}")
wandb.finish()
env.close()
|
||||
|
||||
Reference in New Issue
Block a user