mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactored training approaches
This commit is contained in:
@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
|
||||
n_states: int,
|
||||
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
||||
k_actor, k_product, k_step = jax.random.split(key, 3)
|
||||
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
|
||||
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
|
||||
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
||||
actors = (actor_draw < alpha).astype(jnp.int32)
|
||||
products = jax.random.randint(
|
||||
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
|
||||
)
|
||||
|
||||
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
||||
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
|
||||
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
|
||||
|
||||
def _scan_step(carry, _):
|
||||
states, active, rng = carry
|
||||
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
|
||||
probs_a = agent_T[states]
|
||||
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
||||
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
||||
next_state = jnp.where(active, next_state, int(term_idx))
|
||||
next_state = jnp.where(active, next_state, term_idx_i32)
|
||||
emitted = jnp.where(active, next_state, -1)
|
||||
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
||||
next_active = active & (~is_terminal)
|
||||
carry_states = jnp.where(next_active, next_state, int(term_idx))
|
||||
carry_states = jnp.where(next_active, next_state, term_idx_i32)
|
||||
return (carry_states, next_active, rng), emitted
|
||||
|
||||
_, state_t = jax.lax.scan(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
flax>=0.8.0
|
||||
optax>=0.2.0
|
||||
distrax>=0.1.5
|
||||
orbax-checkpoint>=0.5.0
|
||||
chex>=0.1.8
|
||||
flax==0.10.7
|
||||
optax==0.2.7
|
||||
distrax==0.1.5
|
||||
orbax-checkpoint==0.11.32
|
||||
chex==0.1.90
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,11 +3,16 @@ from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parents[2]))
|
||||
|
||||
from sim.rl.behavior_loader.models import (
|
||||
BehaviorModel,
|
||||
AgentBehaviorModel,
|
||||
aggregate_event_transitions,
|
||||
)
|
||||
try:
|
||||
from sim.rl.behavior_loader.models import (
|
||||
BehaviorModel,
|
||||
AgentBehaviorModel,
|
||||
aggregate_event_transitions,
|
||||
)
|
||||
except ImportError:
|
||||
BehaviorModel = None
|
||||
AgentBehaviorModel = None
|
||||
aggregate_event_transitions = None
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .demand import generate_demand_for_actor
|
||||
@@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots
|
||||
|
||||
|
||||
def _get_base_pivot(human: bool):
|
||||
if (
|
||||
BehaviorModel is None
|
||||
or AgentBehaviorModel is None
|
||||
or aggregate_event_transitions is None
|
||||
):
|
||||
raise ImportError("behavior loader dependencies are unavailable")
|
||||
key = "human" if human else "agent"
|
||||
if key not in _cache:
|
||||
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
||||
@@ -34,6 +45,13 @@ def get_transition_models():
|
||||
returns:
|
||||
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
|
||||
"""
|
||||
if (
|
||||
BehaviorModel is None
|
||||
or AgentBehaviorModel is None
|
||||
or aggregate_event_transitions is None
|
||||
):
|
||||
raise ImportError("behavior loader dependencies are unavailable")
|
||||
|
||||
human_model = BehaviorModel(human_dir)
|
||||
agent_model = AgentBehaviorModel(agent_dir)
|
||||
|
||||
|
||||
@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
if algo == "qtable":
|
||||
raise ValueError("qtable is not supported in JAX backend")
|
||||
try:
|
||||
from .jax.train import train_jax
|
||||
except Exception as exc: # pragma: no cover
|
||||
@@ -409,20 +407,25 @@ def run_wandb(
|
||||
init_kwargs = {"mode": mode}
|
||||
if sweep_mode:
|
||||
run = wandb.init(**init_kwargs)
|
||||
cfg = _cfg(_wandb_cfg_dict())
|
||||
for k, v in overrides.items():
|
||||
if k not in wandb.config:
|
||||
cfg[k] = v
|
||||
else:
|
||||
run = wandb.init(project=project, config=overrides, **init_kwargs)
|
||||
|
||||
try:
|
||||
cfg = _cfg(_wandb_cfg_dict())
|
||||
metrics = train_once(cfg)
|
||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||
wandb.log(metrics, step=step)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
wandb.finish()
|
||||
return metrics
|
||||
if sweep_mode:
|
||||
for k, v in overrides.items():
|
||||
if k not in wandb.config:
|
||||
cfg[k] = v
|
||||
|
||||
metrics = train_once(cfg)
|
||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||
wandb.log(metrics, step=step)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
return metrics
|
||||
finally:
|
||||
if wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
|
||||
def run_local(overrides: dict) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user