mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: adding simulation logging with wandb
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
from .demand import estimate_demand, generate_demand_for_actor
|
from .demand import estimate_demand, generate_demand_for_actor
|
||||||
from .behavior import sample_behavior
|
from .behavior import sample_behavior
|
||||||
from .render import DashboardRenderer, style_axis
|
from .render import DashboardRenderer, style_axis
|
||||||
|
from .wrappers import EconomicMetricsWrapper
|
||||||
|
from .callbacks import MetricsCallback, EvalMetricsCallback
|
||||||
|
from .providers import ProviderBenchmark, ProviderResult, BenchmarkConfig
|
||||||
|
|||||||
@@ -1,27 +1,39 @@
|
|||||||
from sim.rl.behavior_loader.models import BehaviorModel, AgentBehaviorModel, aggregate_event_transitions
|
from sim.rl.behavior_loader.models import (
|
||||||
|
BehaviorModel,
|
||||||
|
AgentBehaviorModel,
|
||||||
|
aggregate_event_transitions,
|
||||||
|
)
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .demand import generate_demand_for_actor
|
from .demand import generate_demand_for_actor
|
||||||
|
|
||||||
base_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments"
|
base_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments"
|
||||||
human_dir, agent_dir = f"{base_dir}/collected_data/", f"{base_dir}/agents/collected_data/"
|
human_dir, agent_dir = (
|
||||||
|
f"{base_dir}/collected_data/",
|
||||||
|
f"{base_dir}/agents/collected_data/",
|
||||||
|
)
|
||||||
|
|
||||||
_cache = {} # lazy cache for models and base pivots
|
_cache = {} # lazy cache for models and base pivots
|
||||||
|
|
||||||
|
|
||||||
def _get_base_pivot(human: bool):
|
def _get_base_pivot(human: bool):
|
||||||
key = 'human' if human else 'agent'
|
key = "human" if human else "agent"
|
||||||
if key not in _cache:
|
if key not in _cache:
|
||||||
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
||||||
mdp = model.build_MDP()
|
mdp = model.build_MDP()
|
||||||
_cache[key] = pd.DataFrame(aggregate_event_transitions(mdp)).fillna(0.0)
|
_cache[key] = pd.DataFrame(aggregate_event_transitions(mdp)).fillna(0.0)
|
||||||
return _cache[key]
|
return _cache[key]
|
||||||
|
|
||||||
|
|
||||||
def adjust_behavior_to_condition(condition, transition_matrix):
|
def adjust_behavior_to_condition(condition, transition_matrix):
|
||||||
# expand NxN transition matrix to (N*P)x(N*P) weighted by demand condition
|
# expand NxN transition matrix to (N*P)x(N*P) weighted by demand condition
|
||||||
cond_norm = condition / np.sum(condition)
|
cond_norm = condition / np.sum(condition)
|
||||||
n_products = len(condition)
|
n_products = len(condition)
|
||||||
base_vals = transition_matrix.values
|
base_vals = transition_matrix.values
|
||||||
base_cols, base_rows = transition_matrix.columns.tolist(), transition_matrix.index.tolist()
|
base_cols, base_rows = (
|
||||||
|
transition_matrix.columns.tolist(),
|
||||||
|
transition_matrix.index.tolist(),
|
||||||
|
)
|
||||||
|
|
||||||
# expand via kronecker-like tiling: each cell becomes a P*P block weighted by outer product of cond_norm
|
# expand via kronecker-like tiling: each cell becomes a P*P block weighted by outer product of cond_norm
|
||||||
expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm))
|
expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm))
|
||||||
@@ -29,19 +41,24 @@ def adjust_behavior_to_condition(condition, transition_matrix):
|
|||||||
new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)]
|
new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)]
|
||||||
return pd.DataFrame(expanded, index=new_rows, columns=new_cols)
|
return pd.DataFrame(expanded, index=new_rows, columns=new_cols)
|
||||||
|
|
||||||
|
|
||||||
def sample_behavior(condition, human=True, max_len=40):
|
def sample_behavior(condition, human=True, max_len=40):
|
||||||
base_pivot = _get_base_pivot(human)
|
base_pivot = _get_base_pivot(human)
|
||||||
adjusted_transitions = adjust_behavior_to_condition(condition, base_pivot)
|
adjusted_transitions = adjust_behavior_to_condition(condition, base_pivot)
|
||||||
|
|
||||||
trajectory = [np.random.choice(adjusted_transitions.index)]
|
trajectory = [np.random.choice(adjusted_transitions.index)]
|
||||||
while len(trajectory) < max_len or 'checkout' in trajectory[-1]:
|
while len(trajectory) < max_len and "checkout" not in trajectory[-1]:
|
||||||
probs = adjusted_transitions.loc[trajectory[-1]].values
|
probs = adjusted_transitions.loc[trajectory[-1]].values
|
||||||
sample = np.random.choice(adjusted_transitions.columns, p=probs/np.sum(probs) if np.sum(probs) > 0 else None)
|
sample = np.random.choice(
|
||||||
|
adjusted_transitions.columns,
|
||||||
|
p=probs / np.sum(probs) if np.sum(probs) > 0 else None,
|
||||||
|
)
|
||||||
trajectory.append(sample)
|
trajectory.append(sample)
|
||||||
return trajectory
|
return trajectory
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
t=sample_behavior(generate_demand_for_actor(np.array([10,20,30])), human=True)
|
t = sample_behavior(generate_demand_for_actor(np.array([10, 20, 30])), human=True)
|
||||||
print(t)
|
print(t)
|
||||||
t=sample_behavior(generate_demand_for_actor(np.array([10,20,30])), human=False)
|
t = sample_behavior(generate_demand_for_actor(np.array([10, 20, 30])), human=False)
|
||||||
print(t)
|
print(t)
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
|
import wandb
|
||||||
from stable_baselines3 import SAC
|
from stable_baselines3 import SAC
|
||||||
from stable_baselines3.common.callbacks import EvalCallback, BaseCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from .wrapper import PHANTOM
|
from .wrapper import PHANTOM
|
||||||
|
from .lib import EconomicMetricsWrapper, MetricsCallback
|
||||||
|
|
||||||
|
wandb.init(
|
||||||
|
project="phantom-pricing",
|
||||||
|
config={"alpha": 0.3, "n_products": 10, "total_timesteps": 50000}
|
||||||
|
)
|
||||||
|
|
||||||
class RenderCallback(BaseCallback):
|
env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||||
"""Renders environment on every step for live visualization."""
|
eval_env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||||
def __init__(self, env: PHANTOM):
|
|
||||||
super().__init__()
|
|
||||||
self.env = env
|
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
|
||||||
self.env.render()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
|
|
||||||
eval_env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
|
|
||||||
|
|
||||||
model = SAC(
|
model = SAC(
|
||||||
"MultiInputPolicy",
|
"MultiInputPolicy",
|
||||||
@@ -28,11 +23,12 @@ model = SAC(
|
|||||||
gamma=0.99,
|
gamma=0.99,
|
||||||
)
|
)
|
||||||
|
|
||||||
render_cb = RenderCallback(env)
|
metrics_cb = MetricsCallback(log_histograms=True, log_freq=100)
|
||||||
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
|
eval_cb = EvalCallback(eval_env, eval_freq=1000, n_eval_episodes=5, verbose=1)
|
||||||
|
|
||||||
model.learn(total_timesteps=50000, callback=[render_cb, eval_cb])
|
model.learn(total_timesteps=50000, callback=[metrics_cb, eval_cb])
|
||||||
model.save("phantom_sac")
|
model.save("phantom_sac")
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
# test trained policy
|
# test trained policy
|
||||||
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
|
env = PHANTOM(n_products=10, alpha=0.3, render_mode="human")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
from .engine import Limbo, MarketEngine, PricingEngine
|
from .engine import Limbo, MarketEngine, PricingEngine
|
||||||
from .lib.render import DashboardRenderer
|
from .lib.render import DashboardRenderer
|
||||||
from .lib.coi import compute_coi_proxy
|
from .lib.coi import compute_coi_proxy
|
||||||
|
from .lib.wrappers import EconomicMetricsWrapper
|
||||||
|
|
||||||
|
|
||||||
class PHANTOM(gym.Env):
|
class PHANTOM(gym.Env):
|
||||||
@@ -134,11 +135,43 @@ class PHANTOM(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
env = PHANTOM(n_products=15, alpha=0.3, N=100, render_mode="human")
|
import wandb
|
||||||
obs, _ = env.reset()
|
from .lib import MetricsCallback
|
||||||
for step in range(100):
|
|
||||||
action = env.action_space.sample()
|
class RandomPolicy:
|
||||||
obs, reward, term, trunc, info = env.step(action)
|
"""Minimal SB3-compatible random policy for baseline testing."""
|
||||||
env.render()
|
def __init__(self, env):
|
||||||
if term: break
|
self.env = env
|
||||||
|
self.num_timesteps = 0
|
||||||
|
|
||||||
|
def learn(self, total_timesteps, callback=None):
|
||||||
|
callback.model = self
|
||||||
|
callback.num_timesteps = 0
|
||||||
|
callback.locals = {}
|
||||||
|
callback.on_training_start({}, {})
|
||||||
|
|
||||||
|
obs, _ = self.env.reset()
|
||||||
|
for step in range(total_timesteps):
|
||||||
|
action = self.env.action_space.sample()
|
||||||
|
obs, reward, term, trunc, info = self.env.step(action)
|
||||||
|
self.num_timesteps = step + 1
|
||||||
|
callback.num_timesteps = self.num_timesteps
|
||||||
|
callback.locals = {"infos": [info]}
|
||||||
|
callback.on_step()
|
||||||
|
if term or trunc:
|
||||||
|
callback.on_rollout_end()
|
||||||
|
obs, _ = self.env.reset()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, obs, **kwargs):
|
||||||
|
return self.env.action_space.sample(), None
|
||||||
|
|
||||||
|
wandb.init(project="phantom-pricing", config={"policy": "random", "alpha": 0.3})
|
||||||
|
env = EconomicMetricsWrapper(PHANTOM(n_products=15, alpha=0.3, render_mode=None))
|
||||||
|
|
||||||
|
model = RandomPolicy(env)
|
||||||
|
model.learn(total_timesteps=1000, callback=MetricsCallback())
|
||||||
|
|
||||||
|
print(f"Episode revenue: {env.episode_revenue:.1f}")
|
||||||
|
wandb.finish()
|
||||||
env.close()
|
env.close()
|
||||||
|
|||||||
@@ -12,3 +12,4 @@ uv
|
|||||||
scikit-learn
|
scikit-learn
|
||||||
supabase
|
supabase
|
||||||
pymc
|
pymc
|
||||||
|
wandb
|
||||||
|
|||||||
Reference in New Issue
Block a user