mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..lib.callbacks import MetricsCallback
|
||||
from ..lib.callbacks import EvalMetricsCallback, MetricsCallback
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
@@ -117,7 +119,6 @@ def build_model(cfg: Mapping[str, Any], env: Any):
|
||||
|
||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
try:
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models") from exc
|
||||
@@ -144,20 +145,20 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
metrics_callback = MetricsCallback(
|
||||
log_histograms=False,
|
||||
log_histograms=True,
|
||||
log_freq=int(cfg["log_freq"]),
|
||||
hist_freq=int(cfg.get("hist_freq", 500)),
|
||||
step_offset=int(cfg.get("wandb_step_offset", 0)),
|
||||
)
|
||||
callbacks = [metrics_callback]
|
||||
callbacks.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
eval_freq=int(cfg["eval_freq"]),
|
||||
n_eval_episodes=int(cfg["eval_episodes"]),
|
||||
deterministic=True,
|
||||
verbose=0,
|
||||
)
|
||||
eval_callback = EvalMetricsCallback(
|
||||
eval_env,
|
||||
eval_freq=int(cfg["eval_freq"]),
|
||||
n_eval_episodes=int(cfg["eval_episodes"]),
|
||||
step_offset=int(cfg.get("wandb_step_offset", 0)),
|
||||
deterministic=True,
|
||||
verbose=0,
|
||||
)
|
||||
callbacks = [metrics_callback, eval_callback]
|
||||
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
@@ -173,6 +174,29 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
model_path = model_dir / f"phantom_{cfg['algo']}"
|
||||
model.save(str(model_path))
|
||||
|
||||
artifact_name = checkpoint_artifact_name(
|
||||
cfg,
|
||||
backend="sb3",
|
||||
sweep_id=os.getenv("WANDB_SWEEP_ID"),
|
||||
)
|
||||
artifact_logged = False
|
||||
try:
|
||||
artifact_logged = bool(
|
||||
log_checkpoint_file(
|
||||
artifact_name,
|
||||
file_path=model_path.with_suffix(".zip"),
|
||||
artifact_file_name="model.zip",
|
||||
metadata={
|
||||
"algo": str(cfg.get("algo", "ppo")),
|
||||
"backend": "sb3",
|
||||
"seed": int(cfg.get("seed", 0)),
|
||||
"step": int(getattr(model, "num_timesteps", 0)),
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
artifact_logged = False
|
||||
|
||||
metrics: dict[str, Any] = evaluate(
|
||||
model,
|
||||
eval_env,
|
||||
@@ -181,7 +205,12 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
)
|
||||
metrics["train/global_step"] = int(model.num_timesteps)
|
||||
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
||||
metrics["_train_events"] = list(metrics_callback.events)
|
||||
metrics["model/artifact_name"] = str(artifact_name)
|
||||
metrics["model/artifact_logged"] = float(artifact_logged)
|
||||
metrics["_train_events"] = sorted(
|
||||
[*metrics_callback.events, *eval_callback.events],
|
||||
key=lambda event: int(event.get("train/global_step", 0)),
|
||||
)
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
|
||||
Reference in New Issue
Block a user