refactored training approaches

This commit is contained in:
2026-02-19 18:23:08 +01:00
parent 5912062dc0
commit 1a9901f118
8 changed files with 947 additions and 308 deletions

View File

@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3)
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint(
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
)
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
def _scan_step(carry, _):
states, active, rng = carry
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, int(term_idx))
next_state = jnp.where(active, next_state, term_idx_i32)
emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, int(term_idx))
carry_states = jnp.where(next_active, next_state, term_idx_i32)
return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan(

View File

@@ -1,5 +1,5 @@
flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8
flax==0.10.7
optax==0.2.7
distrax==0.1.5
orbax-checkpoint==0.11.32
chex==0.1.90

File diff suppressed because it is too large Load Diff

View File

@@ -3,11 +3,16 @@ from pathlib import Path
sys.path.insert(0, str(Path(__file__).parents[2]))
from sim.rl.behavior_loader.models import (
BehaviorModel,
AgentBehaviorModel,
aggregate_event_transitions,
)
try:
from sim.rl.behavior_loader.models import (
BehaviorModel,
AgentBehaviorModel,
aggregate_event_transitions,
)
except ImportError:
BehaviorModel = None
AgentBehaviorModel = None
aggregate_event_transitions = None
import pandas as pd
import numpy as np
from .demand import generate_demand_for_actor
@@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots
def _get_base_pivot(human: bool):
if (
BehaviorModel is None
or AgentBehaviorModel is None
or aggregate_event_transitions is None
):
raise ImportError("behavior loader dependencies are unavailable")
key = "human" if human else "agent"
if key not in _cache:
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
@@ -34,6 +45,13 @@ def get_transition_models():
returns:
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
"""
if (
BehaviorModel is None
or AgentBehaviorModel is None
or aggregate_event_transitions is None
):
raise ImportError("behavior loader dependencies are unavailable")
human_model = BehaviorModel(human_dir)
agent_model = AgentBehaviorModel(agent_dir)

View File

@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
if algo == "qtable":
raise ValueError("qtable is not supported in JAX backend")
try:
from .jax.train import train_jax
except Exception as exc: # pragma: no cover
@@ -409,20 +407,25 @@ def run_wandb(
init_kwargs = {"mode": mode}
if sweep_mode:
run = wandb.init(**init_kwargs)
cfg = _cfg(_wandb_cfg_dict())
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
else:
run = wandb.init(project=project, config=overrides, **init_kwargs)
try:
cfg = _cfg(_wandb_cfg_dict())
metrics = train_once(cfg)
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
wandb.finish()
return metrics
if sweep_mode:
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
metrics = train_once(cfg)
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
return metrics
finally:
if wandb.run is not None:
wandb.finish()
def run_local(overrides: dict) -> dict: