mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactored training approaches
This commit is contained in:
@@ -308,6 +308,8 @@ if JAX_AVAILABLE:
|
||||
n_states: int,
|
||||
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
||||
k_actor, k_product, k_step = jax.random.split(key, 3)
|
||||
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
|
||||
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
|
||||
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
||||
actors = (actor_draw < alpha).astype(jnp.int32)
|
||||
products = jax.random.randint(
|
||||
@@ -315,7 +317,7 @@ if JAX_AVAILABLE:
|
||||
)
|
||||
|
||||
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
||||
state_init = jnp.full((n_sessions,), int(start_idx), dtype=jnp.int32)
|
||||
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
|
||||
|
||||
def _scan_step(carry, _):
|
||||
states, active, rng = carry
|
||||
@@ -324,11 +326,11 @@ if JAX_AVAILABLE:
|
||||
probs_a = agent_T[states]
|
||||
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
||||
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
||||
next_state = jnp.where(active, next_state, int(term_idx))
|
||||
next_state = jnp.where(active, next_state, term_idx_i32)
|
||||
emitted = jnp.where(active, next_state, -1)
|
||||
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
||||
next_active = active & (~is_terminal)
|
||||
carry_states = jnp.where(next_active, next_state, int(term_idx))
|
||||
carry_states = jnp.where(next_active, next_state, term_idx_i32)
|
||||
return (carry_states, next_active, rng), emitted
|
||||
|
||||
_, state_t = jax.lax.scan(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
flax>=0.8.0
|
||||
optax>=0.2.0
|
||||
distrax>=0.1.5
|
||||
orbax-checkpoint>=0.5.0
|
||||
chex>=0.1.8
|
||||
flax==0.10.7
|
||||
optax==0.2.7
|
||||
distrax==0.1.5
|
||||
orbax-checkpoint==0.11.32
|
||||
chex==0.1.90
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user