mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning manim and improving rtraining setup
This commit is contained in:
9
Makefile
9
Makefile
@@ -35,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
|
|||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch pdf.arxiv | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.simple | benchmark.agent | train.agent | train.bootstrap | stats.lines"
|
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch pdf.arxiv | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.simple | benchmark.agent | train.agent | train.bootstrap | stats.lines | manim.render manim.render.all"
|
||||||
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
|
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
|
||||||
@echo "data.pull data.push | study.margin-erosion study.margin-erosion.quick study.margin-erosion.plot"
|
@echo "data.pull data.push | study.margin-erosion study.margin-erosion.quick study.margin-erosion.plot"
|
||||||
@echo ""
|
@echo ""
|
||||||
@@ -201,3 +201,10 @@ count-lines:
|
|||||||
|
|
||||||
all:
|
all:
|
||||||
@$(NX) run paper:build
|
@$(NX) run paper:build
|
||||||
|
|
||||||
|
.PHONY: manim.render manim.render.all
|
||||||
|
manim.render:
|
||||||
|
@$(NX) run manim:render
|
||||||
|
|
||||||
|
manim.render.all:
|
||||||
|
@$(NX) run manim:render-all
|
||||||
|
|||||||
@@ -1,12 +1,32 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
# clear stale TPU locks on startup
|
||||||
|
if os.path.exists("/dev/accel0"):
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["rm", "-f", "/tmp/.libtpu_lockfile", "/tmp/libtpu_lockfile"],
|
||||||
|
stderr=subprocess.DEVNULL,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
|
||||||
|
jax.config.update("jax_threefry_partitionable", True)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_JAX_OK = False
|
_JAX_OK = False
|
||||||
|
|
||||||
|
_JAX_RUNTIME_OK = True
|
||||||
|
|
||||||
|
|
||||||
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
|
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
|
||||||
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
|
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
|
||||||
@@ -104,7 +106,9 @@ def select_adversarial_alpha_jax(
|
|||||||
falls back to a pure-numpy sequential loop when JAX is unavailable so the
|
falls back to a pure-numpy sequential loop when JAX is unavailable so the
|
||||||
wrapper can call this function unconditionally.
|
wrapper can call this function unconditionally.
|
||||||
"""
|
"""
|
||||||
if not _JAX_OK:
|
global _JAX_RUNTIME_OK
|
||||||
|
|
||||||
|
if not _JAX_OK or not _JAX_RUNTIME_OK:
|
||||||
return _fallback(
|
return _fallback(
|
||||||
candidates,
|
candidates,
|
||||||
prices,
|
prices,
|
||||||
@@ -117,28 +121,45 @@ def select_adversarial_alpha_jax(
|
|||||||
reward_profit_weight,
|
reward_profit_weight,
|
||||||
)
|
)
|
||||||
|
|
||||||
k = len(candidates)
|
try:
|
||||||
key = jax.random.PRNGKey(rng_seed)
|
k = len(candidates)
|
||||||
keys = jax.random.split(key, k)
|
key = jax.random.PRNGKey(rng_seed)
|
||||||
|
keys = jax.random.split(key, k)
|
||||||
|
|
||||||
rewards = np.asarray(
|
rewards = np.asarray(
|
||||||
_reward_batched(
|
_reward_batched(
|
||||||
jnp.asarray(candidates, dtype=jnp.float32),
|
jnp.asarray(candidates, dtype=jnp.float32),
|
||||||
jnp.asarray(prices, dtype=jnp.float32),
|
jnp.asarray(prices, dtype=jnp.float32),
|
||||||
float(human_params[0]),
|
float(human_params[0]),
|
||||||
float(human_params[1]),
|
float(human_params[1]),
|
||||||
float(agent_params[0]),
|
float(agent_params[0]),
|
||||||
float(agent_params[1]),
|
float(agent_params[1]),
|
||||||
float(noise_std),
|
float(noise_std),
|
||||||
jnp.asarray(baseline_prices, dtype=jnp.float32),
|
jnp.asarray(baseline_prices, dtype=jnp.float32),
|
||||||
float(lambda_coi),
|
float(lambda_coi),
|
||||||
float(info_value),
|
float(info_value),
|
||||||
float(reward_profit_weight),
|
float(reward_profit_weight),
|
||||||
keys,
|
keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
best_idx = int(np.argmin(rewards))
|
||||||
|
return float(candidates[best_idx]), rewards
|
||||||
|
except Exception as exc:
|
||||||
|
# TPU contention / backend init failures can happen in distributed schedulers.
|
||||||
|
# Degrade to numpy path for the remainder of the process.
|
||||||
|
_JAX_RUNTIME_OK = False
|
||||||
|
print(f"PHANTOM_JAX_FALLBACK: {exc}")
|
||||||
|
return _fallback(
|
||||||
|
candidates,
|
||||||
|
prices,
|
||||||
|
human_params,
|
||||||
|
agent_params,
|
||||||
|
noise_std,
|
||||||
|
baseline_prices,
|
||||||
|
lambda_coi,
|
||||||
|
info_value,
|
||||||
|
reward_profit_weight,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
best_idx = int(np.argmin(rewards))
|
|
||||||
return float(candidates[best_idx]), rewards
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback(
|
def _fallback(
|
||||||
|
|||||||
@@ -179,8 +179,29 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
def main(argv: list[str] | None = None) -> None:
|
def main(argv: list[str] | None = None) -> None:
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
# Ensure data is downloaded
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
project_root = Path(__file__).parents[1]
|
||||||
|
data_dir = project_root / "experiments" / "collected_data"
|
||||||
|
needs_pull = (not data_dir.exists()) or (not any(data_dir.iterdir()))
|
||||||
|
if needs_pull:
|
||||||
|
try:
|
||||||
|
subprocess.run(["make", "data.pull"], cwd=str(project_root), check=True)
|
||||||
|
except (subprocess.SubprocessError, OSError) as exc:
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
try:
|
||||||
|
from scripts.hf_data import pull
|
||||||
|
|
||||||
|
pull()
|
||||||
|
except (ImportError, OSError, RuntimeError, ValueError) as fallback_exc:
|
||||||
|
print(
|
||||||
|
f"Warning: data.pull failed ({exc}); fallback pull failed ({fallback_exc})"
|
||||||
|
)
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
raw_args = list(sys.argv[1:] if argv is None else argv)
|
raw_args = list(sys.argv[1:] if argv is None else argv)
|
||||||
run_kind = _probe_run_kind(raw_args)
|
run_kind = _probe_run_kind(raw_args)
|
||||||
|
|||||||
@@ -7,6 +7,8 @@
|
|||||||
],
|
],
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"nx": "nx",
|
"nx": "nx",
|
||||||
|
"manim:render": "nx run manim:render",
|
||||||
|
"manim:render-all": "nx run manim:render-all",
|
||||||
"projects": "nx show projects",
|
"projects": "nx show projects",
|
||||||
"graph": "nx graph",
|
"graph": "nx graph",
|
||||||
"web:dev": "nx run web:dev",
|
"web:dev": "nx run web:dev",
|
||||||
|
|||||||
@@ -1,84 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from scenes import SCENE_ORDER
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser(description="Render thesis-defense Manim scenes")
|
|
||||||
parser.add_argument(
|
|
||||||
"--quality",
|
|
||||||
default="qm",
|
|
||||||
choices=["ql", "qm", "qh", "qk"],
|
|
||||||
help="Manim quality preset",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--scene",
|
|
||||||
action="append",
|
|
||||||
dest="scenes",
|
|
||||||
help="Scene name; repeat flag to render many",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--preview", action="store_true", help="Open video after each render"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--list", action="store_true", help="List available scenes and exit"
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def validate_requested(requested: list[str]) -> list[str]:
|
|
||||||
missing = [name for name in requested if name not in SCENE_ORDER]
|
|
||||||
if missing:
|
|
||||||
choices = ", ".join(SCENE_ORDER)
|
|
||||||
raise ValueError(f"Unknown scenes: {', '.join(missing)}. Choices: {choices}")
|
|
||||||
return requested
|
|
||||||
|
|
||||||
|
|
||||||
def run_manim(scene_file: Path, scene_name: str, quality: str, preview: bool) -> None:
|
|
||||||
cmd = [sys.executable, "-m", "manim"]
|
|
||||||
if preview:
|
|
||||||
cmd.append("-p")
|
|
||||||
cmd.extend([f"-{quality}", str(scene_file), scene_name])
|
|
||||||
subprocess.run(cmd, cwd=scene_file.parent, check=True)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
args = parse_args()
|
|
||||||
if args.list:
|
|
||||||
for scene in SCENE_ORDER:
|
|
||||||
print(scene)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
scenes = validate_requested(args.scenes) if args.scenes else list(SCENE_ORDER)
|
|
||||||
scene_file = Path(__file__).resolve().parent / "scenes.py"
|
|
||||||
|
|
||||||
try:
|
|
||||||
for scene_name in scenes:
|
|
||||||
run_manim(
|
|
||||||
scene_file=scene_file,
|
|
||||||
scene_name=scene_name,
|
|
||||||
quality=args.quality,
|
|
||||||
preview=args.preview,
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(
|
|
||||||
"manim executable not found. Install Manim in your Python environment.",
|
|
||||||
file=sys.stderr,
|
|
||||||
)
|
|
||||||
return 2
|
|
||||||
except ValueError as exc:
|
|
||||||
print(str(exc), file=sys.stderr)
|
|
||||||
return 2
|
|
||||||
except subprocess.CalledProcessError as exc:
|
|
||||||
return exc.returncode
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
raise SystemExit(main())
|
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user