refactored training approaches

This commit is contained in:
2026-02-19 18:23:08 +01:00
parent 5912062dc0
commit 1a9901f118
8 changed files with 947 additions and 308 deletions

View File

@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
n_states: int,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
k_actor, k_product, k_step = jax.random.split(key, 3)
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
actors = (actor_draw < alpha).astype(jnp.int32)
products = jax.random.randint(
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
)
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
def _scan_step(carry, _):
states, active, rng = carry
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
probs_a = agent_T[states]
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
next_state = jnp.where(active, next_state, int(term_idx))
next_state = jnp.where(active, next_state, term_idx_i32)
emitted = jnp.where(active, next_state, -1)
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
next_active = active & (~is_terminal)
carry_states = jnp.where(next_active, next_state, int(term_idx))
carry_states = jnp.where(next_active, next_state, term_idx_i32)
return (carry_states, next_active, rng), emitted
_, state_t = jax.lax.scan(

View File

@@ -1,5 +1,5 @@
flax>=0.8.0
optax>=0.2.0
distrax>=0.1.5
orbax-checkpoint>=0.5.0
chex>=0.1.8
flax==0.10.7
optax==0.2.7
distrax==0.1.5
orbax-checkpoint==0.11.32
chex==0.1.90

File diff suppressed because it is too large Load Diff