mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
first meaningful runs
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .common import evaluate, make_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_qtable(
|
||||
cfg: Mapping[str, Any],
|
||||
@@ -29,7 +33,9 @@ def train_qtable(
|
||||
steps = 0
|
||||
epsilon = float(cfg["eps_start"])
|
||||
log_freq = max(1, int(cfg.get("log_freq", 100)))
|
||||
console_progress = bool(cfg.get("console_progress", False))
|
||||
obs, _ = env.reset(seed=int(cfg["seed"]))
|
||||
started_at = time.perf_counter()
|
||||
|
||||
interval_sums = {
|
||||
"reward": 0.0,
|
||||
@@ -60,17 +66,28 @@ def train_qtable(
|
||||
|
||||
if steps % log_freq == 0 and interval_count > 0:
|
||||
denom = float(interval_count)
|
||||
train_events.append(
|
||||
{
|
||||
"train/reward_mean": interval_sums["reward"] / denom,
|
||||
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
)
|
||||
event = {
|
||||
"train/reward_mean": interval_sums["reward"] / denom,
|
||||
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
train_events.append(event)
|
||||
if console_progress:
|
||||
elapsed = max(time.perf_counter() - started_at, 1e-6)
|
||||
speed = steps / elapsed
|
||||
logger.info(
|
||||
"step=%d/%d reward=%.3f revenue=%.3f eps=%.4f speed=%.1f steps/s",
|
||||
steps,
|
||||
int(cfg["total_timesteps"]),
|
||||
event["train/reward_mean"],
|
||||
event["train/revenue_mean"],
|
||||
event["train/epsilon"],
|
||||
speed,
|
||||
)
|
||||
interval_sums = {key: 0.0 for key in interval_sums}
|
||||
interval_count = 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user