mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-06-01 00:53:36 +00:00
refactored training approaches
This commit is contained in:
@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
|
||||
n_states: int,
|
||||
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
||||
k_actor, k_product, k_step = jax.random.split(key, 3)
|
||||
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
|
||||
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
|
||||
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
||||
actors = (actor_draw < alpha).astype(jnp.int32)
|
||||
products = jax.random.randint(
|
||||
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
|
||||
)
|
||||
|
||||
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
||||
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
|
||||
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
|
||||
|
||||
def _scan_step(carry, _):
|
||||
states, active, rng = carry
|
||||
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
|
||||
probs_a = agent_T[states]
|
||||
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
||||
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
||||
next_state = jnp.where(active, next_state, int(term_idx))
|
||||
next_state = jnp.where(active, next_state, term_idx_i32)
|
||||
emitted = jnp.where(active, next_state, -1)
|
||||
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
||||
next_active = active & (~is_terminal)
|
||||
carry_states = jnp.where(next_active, next_state, int(term_idx))
|
||||
carry_states = jnp.where(next_active, next_state, term_idx_i32)
|
||||
return (carry_states, next_active, rng), emitted
|
||||
|
||||
_, state_t = jax.lax.scan(
|
||||
|
||||
Reference in New Issue
Block a user