mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: cleaning the code
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""full factorial design - all factor combinations"""
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "..")
|
||||
import logging
|
||||
from itertools import product
|
||||
@@ -12,6 +14,7 @@ from .factors import FACTORS, DEMAND_FUNCTIONS, SEEDS_PER_CONFIG
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_configs():
|
||||
"""generate all factor combinations with seeds"""
|
||||
all_levels = [f.levels for f in FACTORS]
|
||||
@@ -22,10 +25,13 @@ def generate_configs():
|
||||
base = {names[i]: combo[i] for i in range(len(names))}
|
||||
for seed in range(SEEDS_PER_CONFIG):
|
||||
cfg = {**base, "seed": seed}
|
||||
cfg["id"] = hashlib.md5(json.dumps(cfg, sort_keys=True).encode()).hexdigest()[:8]
|
||||
cfg["id"] = hashlib.md5(
|
||||
json.dumps(cfg, sort_keys=True).encode()
|
||||
).hexdigest()[:8]
|
||||
configs.append(cfg)
|
||||
return configs
|
||||
|
||||
|
||||
def run_single(cfg: dict) -> dict:
|
||||
"""execute one experiment config, return metrics"""
|
||||
from engine.wrapper import PHANTOM
|
||||
@@ -49,7 +55,8 @@ def run_single(cfg: dict) -> dict:
|
||||
obs, reward, term, trunc, _ = env.step(action)
|
||||
total_reward += reward
|
||||
steps += 1
|
||||
if term: break
|
||||
if term:
|
||||
break
|
||||
|
||||
env.close()
|
||||
return {
|
||||
@@ -60,22 +67,28 @@ def run_single(cfg: dict) -> dict:
|
||||
"steps": steps,
|
||||
}
|
||||
|
||||
|
||||
def run_study(max_workers: int = None, output: str = "results_full.jsonl"):
|
||||
configs = generate_configs()
|
||||
log.info(f"full factorial: {len(configs)} configs ({len(configs)//SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)")
|
||||
log.info(
|
||||
f"full factorial: {len(configs)} configs ({len(configs) // SEEDS_PER_CONFIG} unique × {SEEDS_PER_CONFIG} seeds)"
|
||||
)
|
||||
|
||||
results = []
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as ex:
|
||||
for i, result in enumerate(ex.map(run_single, configs)):
|
||||
results.append(result)
|
||||
if (i+1) % 100 == 0: log.info(f"progress: {i+1}/{len(configs)}")
|
||||
if (i + 1) % 100 == 0:
|
||||
log.info(f"progress: {i + 1}/{len(configs)}")
|
||||
|
||||
Path(output).write_text("\n".join(json.dumps(r) for r in results))
|
||||
log.info(f"wrote {len(results)} results to {output}")
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--workers", type=int, default=None)
|
||||
p.add_argument("--output", default="results_full.jsonl")
|
||||
@@ -83,7 +96,9 @@ if __name__ == "__main__":
|
||||
args = p.parse_args()
|
||||
|
||||
configs = generate_configs()
|
||||
log.info(f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}")
|
||||
log.info(
|
||||
f"design: {len(configs)} runs | factors: {[f.name for f in FACTORS]} | levels: {[len(f.levels) for f in FACTORS]}"
|
||||
)
|
||||
|
||||
if not args.dry_run:
|
||||
run_study(args.workers, args.output)
|
||||
|
||||
Reference in New Issue
Block a user