cleaning manim and improving rtraining setup

This commit is contained in:
2026-03-12 00:22:46 +01:00
parent d748733231
commit 22e50aac4a
7 changed files with 94 additions and 1688 deletions

View File

@@ -35,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
.PHONY: help .PHONY: help
help: help:
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch pdf.arxiv | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.simple | benchmark.agent | train.agent | train.bootstrap | stats.lines" @echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch pdf.arxiv | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.simple | benchmark.agent | train.agent | train.bootstrap | stats.lines | manim.render manim.render.all"
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish" @echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
@echo "data.pull data.push | study.margin-erosion study.margin-erosion.quick study.margin-erosion.plot" @echo "data.pull data.push | study.margin-erosion study.margin-erosion.quick study.margin-erosion.plot"
@echo "" @echo ""
@@ -201,3 +201,10 @@ count-lines:
all: all:
@$(NX) run paper:build @$(NX) run paper:build
.PHONY: manim.render manim.render.all
manim.render:
@$(NX) run manim:render
manim.render.all:
@$(NX) run manim:render-all

View File

@@ -1,12 +1,32 @@
from __future__ import annotations from __future__ import annotations
import os
import subprocess
import sys
import argparse import argparse
import json import json
import logging import logging
import os
from datetime import datetime, UTC from datetime import datetime, UTC
from pathlib import Path from pathlib import Path
# clear stale TPU locks on startup
if os.path.exists("/dev/accel0"):
try:
subprocess.run(
["rm", "-f", "/tmp/.libtpu_lockfile", "/tmp/libtpu_lockfile"],
stderr=subprocess.DEVNULL,
)
except:
pass
try:
import jax
jax.config.update("jax_threefry_partitionable", True)
except ImportError:
pass
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd

View File

@@ -28,6 +28,8 @@ try:
except ImportError: except ImportError:
_JAX_OK = False _JAX_OK = False
_JAX_RUNTIME_OK = True
def _demand_for_actor_jax(prices, mean, std, noise_std, key): def _demand_for_actor_jax(prices, mean, std, noise_std, key):
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100.""" """d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
@@ -104,7 +106,9 @@ def select_adversarial_alpha_jax(
falls back to a pure-numpy sequential loop when JAX is unavailable so the falls back to a pure-numpy sequential loop when JAX is unavailable so the
wrapper can call this function unconditionally. wrapper can call this function unconditionally.
""" """
if not _JAX_OK: global _JAX_RUNTIME_OK
if not _JAX_OK or not _JAX_RUNTIME_OK:
return _fallback( return _fallback(
candidates, candidates,
prices, prices,
@@ -117,28 +121,45 @@ def select_adversarial_alpha_jax(
reward_profit_weight, reward_profit_weight,
) )
k = len(candidates) try:
key = jax.random.PRNGKey(rng_seed) k = len(candidates)
keys = jax.random.split(key, k) key = jax.random.PRNGKey(rng_seed)
keys = jax.random.split(key, k)
rewards = np.asarray( rewards = np.asarray(
_reward_batched( _reward_batched(
jnp.asarray(candidates, dtype=jnp.float32), jnp.asarray(candidates, dtype=jnp.float32),
jnp.asarray(prices, dtype=jnp.float32), jnp.asarray(prices, dtype=jnp.float32),
float(human_params[0]), float(human_params[0]),
float(human_params[1]), float(human_params[1]),
float(agent_params[0]), float(agent_params[0]),
float(agent_params[1]), float(agent_params[1]),
float(noise_std), float(noise_std),
jnp.asarray(baseline_prices, dtype=jnp.float32), jnp.asarray(baseline_prices, dtype=jnp.float32),
float(lambda_coi), float(lambda_coi),
float(info_value), float(info_value),
float(reward_profit_weight), float(reward_profit_weight),
keys, keys,
)
)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards
except Exception as exc:
# TPU contention / backend init failures can happen in distributed schedulers.
# Degrade to numpy path for the remainder of the process.
_JAX_RUNTIME_OK = False
print(f"PHANTOM_JAX_FALLBACK: {exc}")
return _fallback(
candidates,
prices,
human_params,
agent_params,
noise_std,
baseline_prices,
lambda_coi,
info_value,
reward_profit_weight,
) )
)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards
def _fallback( def _fallback(

View File

@@ -179,8 +179,29 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
def main(argv: list[str] | None = None) -> None: def main(argv: list[str] | None = None) -> None:
import subprocess
import sys import sys
# Ensure data is downloaded
from pathlib import Path
project_root = Path(__file__).parents[1]
data_dir = project_root / "experiments" / "collected_data"
needs_pull = (not data_dir.exists()) or (not any(data_dir.iterdir()))
if needs_pull:
try:
subprocess.run(["make", "data.pull"], cwd=str(project_root), check=True)
except (subprocess.SubprocessError, OSError) as exc:
sys.path.insert(0, str(project_root))
try:
from scripts.hf_data import pull
pull()
except (ImportError, OSError, RuntimeError, ValueError) as fallback_exc:
print(
f"Warning: data.pull failed ({exc}); fallback pull failed ({fallback_exc})"
)
configure_logging() configure_logging()
raw_args = list(sys.argv[1:] if argv is None else argv) raw_args = list(sys.argv[1:] if argv is None else argv)
run_kind = _probe_run_kind(raw_args) run_kind = _probe_run_kind(raw_args)

View File

@@ -7,6 +7,8 @@
], ],
"scripts": { "scripts": {
"nx": "nx", "nx": "nx",
"manim:render": "nx run manim:render",
"manim:render-all": "nx run manim:render-all",
"projects": "nx show projects", "projects": "nx show projects",
"graph": "nx graph", "graph": "nx graph",
"web:dev": "nx run web:dev", "web:dev": "nx run web:dev",

View File

@@ -1,84 +0,0 @@
from __future__ import annotations
import argparse
import subprocess
import sys
from pathlib import Path
from scenes import SCENE_ORDER
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Render thesis-defense Manim scenes")
parser.add_argument(
"--quality",
default="qm",
choices=["ql", "qm", "qh", "qk"],
help="Manim quality preset",
)
parser.add_argument(
"--scene",
action="append",
dest="scenes",
help="Scene name; repeat flag to render many",
)
parser.add_argument(
"--preview", action="store_true", help="Open video after each render"
)
parser.add_argument(
"--list", action="store_true", help="List available scenes and exit"
)
return parser.parse_args()
def validate_requested(requested: list[str]) -> list[str]:
missing = [name for name in requested if name not in SCENE_ORDER]
if missing:
choices = ", ".join(SCENE_ORDER)
raise ValueError(f"Unknown scenes: {', '.join(missing)}. Choices: {choices}")
return requested
def run_manim(scene_file: Path, scene_name: str, quality: str, preview: bool) -> None:
cmd = [sys.executable, "-m", "manim"]
if preview:
cmd.append("-p")
cmd.extend([f"-{quality}", str(scene_file), scene_name])
subprocess.run(cmd, cwd=scene_file.parent, check=True)
def main() -> int:
args = parse_args()
if args.list:
for scene in SCENE_ORDER:
print(scene)
return 0
scenes = validate_requested(args.scenes) if args.scenes else list(SCENE_ORDER)
scene_file = Path(__file__).resolve().parent / "scenes.py"
try:
for scene_name in scenes:
run_manim(
scene_file=scene_file,
scene_name=scene_name,
quality=args.quality,
preview=args.preview,
)
except FileNotFoundError:
print(
"manim executable not found. Install Manim in your Python environment.",
file=sys.stderr,
)
return 2
except ValueError as exc:
print(str(exc), file=sys.stderr)
return 2
except subprocess.CalledProcessError as exc:
return exc.returncode
return 0
if __name__ == "__main__":
raise SystemExit(main())

File diff suppressed because it is too large Load Diff