mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning manim and improving rtraining setup
This commit is contained in:
@@ -28,6 +28,8 @@ try:
|
||||
except ImportError:
|
||||
_JAX_OK = False
|
||||
|
||||
_JAX_RUNTIME_OK = True
|
||||
|
||||
|
||||
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
|
||||
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
|
||||
@@ -104,7 +106,9 @@ def select_adversarial_alpha_jax(
|
||||
falls back to a pure-numpy sequential loop when JAX is unavailable so the
|
||||
wrapper can call this function unconditionally.
|
||||
"""
|
||||
if not _JAX_OK:
|
||||
global _JAX_RUNTIME_OK
|
||||
|
||||
if not _JAX_OK or not _JAX_RUNTIME_OK:
|
||||
return _fallback(
|
||||
candidates,
|
||||
prices,
|
||||
@@ -117,28 +121,45 @@ def select_adversarial_alpha_jax(
|
||||
reward_profit_weight,
|
||||
)
|
||||
|
||||
k = len(candidates)
|
||||
key = jax.random.PRNGKey(rng_seed)
|
||||
keys = jax.random.split(key, k)
|
||||
try:
|
||||
k = len(candidates)
|
||||
key = jax.random.PRNGKey(rng_seed)
|
||||
keys = jax.random.split(key, k)
|
||||
|
||||
rewards = np.asarray(
|
||||
_reward_batched(
|
||||
jnp.asarray(candidates, dtype=jnp.float32),
|
||||
jnp.asarray(prices, dtype=jnp.float32),
|
||||
float(human_params[0]),
|
||||
float(human_params[1]),
|
||||
float(agent_params[0]),
|
||||
float(agent_params[1]),
|
||||
float(noise_std),
|
||||
jnp.asarray(baseline_prices, dtype=jnp.float32),
|
||||
float(lambda_coi),
|
||||
float(info_value),
|
||||
float(reward_profit_weight),
|
||||
keys,
|
||||
rewards = np.asarray(
|
||||
_reward_batched(
|
||||
jnp.asarray(candidates, dtype=jnp.float32),
|
||||
jnp.asarray(prices, dtype=jnp.float32),
|
||||
float(human_params[0]),
|
||||
float(human_params[1]),
|
||||
float(agent_params[0]),
|
||||
float(agent_params[1]),
|
||||
float(noise_std),
|
||||
jnp.asarray(baseline_prices, dtype=jnp.float32),
|
||||
float(lambda_coi),
|
||||
float(info_value),
|
||||
float(reward_profit_weight),
|
||||
keys,
|
||||
)
|
||||
)
|
||||
best_idx = int(np.argmin(rewards))
|
||||
return float(candidates[best_idx]), rewards
|
||||
except Exception as exc:
|
||||
# TPU contention / backend init failures can happen in distributed schedulers.
|
||||
# Degrade to numpy path for the remainder of the process.
|
||||
_JAX_RUNTIME_OK = False
|
||||
print(f"PHANTOM_JAX_FALLBACK: {exc}")
|
||||
return _fallback(
|
||||
candidates,
|
||||
prices,
|
||||
human_params,
|
||||
agent_params,
|
||||
noise_std,
|
||||
baseline_prices,
|
||||
lambda_coi,
|
||||
info_value,
|
||||
reward_profit_weight,
|
||||
)
|
||||
)
|
||||
best_idx = int(np.argmin(rewards))
|
||||
return float(candidates[best_idx]), rewards
|
||||
|
||||
|
||||
def _fallback(
|
||||
|
||||
Reference in New Issue
Block a user