cleaning manim and improving rtraining setup

This commit is contained in:
2026-03-12 00:22:46 +01:00
parent d748733231
commit 22e50aac4a
7 changed files with 94 additions and 1688 deletions

View File

@@ -28,6 +28,8 @@ try:
except ImportError:
_JAX_OK = False
_JAX_RUNTIME_OK = True
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
@@ -104,7 +106,9 @@ def select_adversarial_alpha_jax(
falls back to a pure-numpy sequential loop when JAX is unavailable so the
wrapper can call this function unconditionally.
"""
if not _JAX_OK:
global _JAX_RUNTIME_OK
if not _JAX_OK or not _JAX_RUNTIME_OK:
return _fallback(
candidates,
prices,
@@ -117,28 +121,45 @@ def select_adversarial_alpha_jax(
reward_profit_weight,
)
k = len(candidates)
key = jax.random.PRNGKey(rng_seed)
keys = jax.random.split(key, k)
try:
k = len(candidates)
key = jax.random.PRNGKey(rng_seed)
keys = jax.random.split(key, k)
rewards = np.asarray(
_reward_batched(
jnp.asarray(candidates, dtype=jnp.float32),
jnp.asarray(prices, dtype=jnp.float32),
float(human_params[0]),
float(human_params[1]),
float(agent_params[0]),
float(agent_params[1]),
float(noise_std),
jnp.asarray(baseline_prices, dtype=jnp.float32),
float(lambda_coi),
float(info_value),
float(reward_profit_weight),
keys,
rewards = np.asarray(
_reward_batched(
jnp.asarray(candidates, dtype=jnp.float32),
jnp.asarray(prices, dtype=jnp.float32),
float(human_params[0]),
float(human_params[1]),
float(agent_params[0]),
float(agent_params[1]),
float(noise_std),
jnp.asarray(baseline_prices, dtype=jnp.float32),
float(lambda_coi),
float(info_value),
float(reward_profit_weight),
keys,
)
)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards
except Exception as exc:
# TPU contention / backend init failures can happen in distributed schedulers.
# Degrade to numpy path for the remainder of the process.
_JAX_RUNTIME_OK = False
print(f"PHANTOM_JAX_FALLBACK: {exc}")
return _fallback(
candidates,
prices,
human_params,
agent_params,
noise_std,
baseline_prices,
lambda_coi,
info_value,
reward_profit_weight,
)
)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards
def _fallback(