Files
PHANTOM/lib/separability.py
2026-02-27 12:45:46 +01:00

129 lines
4.3 KiB
Python

"""Utilities for loading separability artifacts and scoring interaction sessions."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence
import joblib
import numpy as np
from experiments.ml.arch import featurize_trajectory
# Default directory holding the separability artifacts; used by
# ``load_artifacts`` when the caller does not pass a directory. The path is
# relative, so it resolves against the process's current working directory.
DEFAULT_ARTIFACT_DIR = Path("data/separability")
@dataclass
class SeparabilityArtifacts:
    """Bundle of the objects and metadata needed to score a session.

    Instances are normally produced by ``load_artifacts``, which fills the
    fields from ``scaler.joblib``, ``classifier.joblib`` and ``metadata.json``.
    """

    # Feature scaler; scoring calls ``transform`` on a single-row feature array.
    scaler: object
    # Classifier; scoring calls ``predict_proba`` and reads column 1 of row 0.
    classifier: object
    # Taken from the metadata key "reference_states"; forwarded to
    # ``featurize_trajectory`` as the "states" entry of the reference MDP.
    states: List[str]
    # Reference transition distributions keyed by label. Scoring looks up the
    # "human" and "agent" labels; each value maps
    # source event name -> destination event name -> probability.
    event_transitions: Dict[str, Dict[str, Dict[str, float]]]
    # Taken from the metadata key "feature_dim"; forwarded to
    # ``featurize_trajectory`` as ``input_dim``.
    feature_dim: int
def _normalize_events(raw_events: Sequence[object]) -> List[object]:
    """Unwrap raw events and return them ordered by their ``ts`` attribute.

    An item that exposes ``.value.payload`` is replaced by that payload; any
    other item is kept unchanged. The input sequence itself is not modified.

    Events whose ``ts`` is missing or ``None`` are placed ahead of every
    timestamped event and keep their relative input order. Timestamped events
    are ordered by ``ts``; the sort is stable, so ties keep input order.

    Args:
        raw_events: Events, optionally wrapped in objects carrying
            ``.value.payload``.

    Returns:
        A new list of unwrapped events sorted by timestamp.

    Raises:
        TypeError: If two present ``ts`` values cannot be compared with each
            other (for example a string and a number).
    """
    events: List[object] = [
        evt.value.payload
        if hasattr(evt, "value") and hasattr(evt.value, "payload")
        else evt
        for evt in raw_events
    ]

    def _sort_key(event: object) -> tuple:
        ts = getattr(event, "ts", None)
        # The leading rank decides the comparison whenever exactly one side
        # lacks a timestamp, so a placeholder is never compared against a
        # real (possibly numeric) ``ts`` value.
        return (0, "") if ts is None else (1, ts)

    events.sort(key=_sort_key)
    return events
def _event_transition_distribution(events: Sequence[object]) -> Dict[str, Dict[str, float]]:
    """Estimate first-order transition probabilities between event names.

    Every adjacent pair of events contributes one observation of
    ``eventName -> eventName``; an event without an ``eventName`` attribute is
    counted under ``"unknown"``.

    Args:
        events: Ordered events; must support slicing.

    Returns:
        Mapping of source event name to a mapping of destination event name
        to its relative frequency among transitions leaving that source.
        Empty when fewer than two events are given.
    """
    tallies: Dict[str, Dict[str, int]] = {}
    for current, following in zip(events, events[1:]):
        origin = getattr(current, "eventName", "unknown")
        target = getattr(following, "eventName", "unknown")
        row = tallies.setdefault(origin, {})
        row[target] = row.get(target, 0) + 1

    probabilities: Dict[str, Dict[str, float]] = {}
    for origin, row in tallies.items():
        # A row is only created right before a count is added to it, so its
        # total is always at least one and the division is safe.
        row_total = float(sum(row.values()))
        probabilities[origin] = {
            target: count / row_total for target, count in row.items()
        }
    return probabilities
def _kl_divergence(p: Dict[str, Dict[str, float]], q: Dict[str, Dict[str, float]]) -> float:
    """Return a smoothed KL-style divergence of ``p`` from ``q``.

    For every ``(source, destination)`` entry in ``p`` the term
    ``(p + eps) * log((p + eps) / (q + eps))`` is accumulated, with
    ``eps = 1e-10``. A transition absent from ``q`` is treated as having
    probability 0 there; transitions present only in ``q`` are ignored.

    Args:
        p: Observed transition distribution (source -> destination -> prob).
        q: Reference transition distribution with the same layout.

    Returns:
        The accumulated sum as a Python ``float``; ``0.0`` when ``p`` is empty.
    """
    smoothing = 1e-10
    divergence = 0.0
    for source, row in p.items():
        reference_row = q.get(source, {})
        for destination, prob in row.items():
            observed = prob + smoothing
            expected = reference_row.get(destination, 0.0) + smoothing
            divergence += observed * np.log(observed / expected)
    return float(divergence)
def load_artifacts(artifact_dir: Path | str = DEFAULT_ARTIFACT_DIR) -> SeparabilityArtifacts:
    """Load the scaler, classifier and metadata stored in ``artifact_dir``.

    The directory must contain ``scaler.joblib``, ``classifier.joblib`` and
    ``metadata.json``. The metadata must provide the keys
    ``reference_states``, ``event_transitions`` and ``feature_dim``.

    Args:
        artifact_dir: Directory holding the three artifact files.

    Returns:
        A populated ``SeparabilityArtifacts`` instance.

    Raises:
        FileNotFoundError: If any of the three files is missing.
        KeyError: If the metadata lacks one of the required keys.
    """
    root = Path(artifact_dir)
    scaler_file = root / "scaler.joblib"
    classifier_file = root / "classifier.joblib"
    metadata_file = root / "metadata.json"

    required = (scaler_file, classifier_file, metadata_file)
    if not all(path.exists() for path in required):
        raise FileNotFoundError(
            f"Separability artifacts not found in {root}. Run sim.strong_learner.train first."
        )

    # NOTE(review): joblib files are pickle-based, so only point this at
    # artifact directories that come from a trusted source.
    scaler = joblib.load(scaler_file)
    classifier = joblib.load(classifier_file)
    metadata = json.loads(metadata_file.read_text(encoding="utf-8"))

    return SeparabilityArtifacts(
        scaler=scaler,
        classifier=classifier,
        states=list(metadata["reference_states"]),
        event_transitions=metadata["event_transitions"],
        feature_dim=int(metadata["feature_dim"]),
    )
def score_session(
    raw_events: Sequence[object],
    artifacts: SeparabilityArtifacts,
) -> dict:
    """Score one interaction session against the loaded artifacts.

    The events are normalized, featurized against the reference states,
    scaled and classified. Separately, the session's event-transition
    distribution is compared with the stored "human" and "agent" references.

    Args:
        raw_events: Events of a single session (wrapped or unwrapped).
        artifacts: Artifacts as returned by ``load_artifacts``.

    Returns:
        A dict with the keys ``prob_agent`` (column 1 of the classifier's
        ``predict_proba`` output), ``delta_h`` and ``delta_a`` (divergence of
        the session from the "human" and "agent" references respectively).
        All three are ``0.0`` when the session has no events.
    """
    events = _normalize_events(raw_events)
    if not events:
        return {"prob_agent": 0.0, "delta_h": 0.0, "delta_a": 0.0}

    features = featurize_trajectory(
        events,
        mdp={"states": artifacts.states},
        input_dim=artifacts.feature_dim,
    )
    # The scaler is fed a single-row 2-D array, hence the reshape.
    scaled_row = artifacts.scaler.transform(features.reshape(1, -1))
    # NOTE(review): assumes class index 1 is the "agent" label — confirm
    # against the training code.
    prob_agent = float(artifacts.classifier.predict_proba(scaled_row)[0, 1])

    observed = _event_transition_distribution(events)
    references = artifacts.event_transitions
    divergences = {
        key: _kl_divergence(observed, references.get(label, {}))
        for key, label in (("delta_h", "human"), ("delta_a", "agent"))
    }
    return {"prob_agent": prob_agent, **divergences}
def estimate_alpha(prob_agent: float, delta_h: float, delta_a: float, temperature: float = 1.0) -> float:
    """Blend the classifier probability with the divergence ratio.

    When ``delta_h + delta_a`` is at most ``1e-8`` the divergences carry no
    usable signal and ``prob_agent`` is returned unchanged. Otherwise the
    ratio ``delta_a / (delta_h + delta_a)`` is averaged with ``prob_agent``.
    For a positive ``temperature`` the average is passed through a logistic
    centred on 0.5 with slope ``temperature``; for ``temperature <= 0`` the
    average is used directly. The result is clipped to ``[0, 1]``.

    Args:
        prob_agent: Classifier probability for the session.
        delta_h: Divergence of the session from the human reference.
        delta_a: Divergence of the session from the agent reference.
        temperature: Logistic slope; non-positive values disable the logistic.

    Returns:
        The blended score as a Python ``float``.
    """
    total_divergence = delta_h + delta_a
    if total_divergence <= 1e-8:
        return float(prob_agent)

    # NOTE(review): the ratio grows with delta_a, i.e. with distance from the
    # agent reference — confirm this is the intended direction.
    agent_share = delta_a / total_divergence
    mixed = 0.5 * prob_agent + 0.5 * agent_share

    squashed = (
        mixed
        if temperature <= 0
        else 1.0 / (1.0 + np.exp(-temperature * (mixed - 0.5)))
    )
    return float(np.clip(squashed, 0.0, 1.0))
def score_sessions(raw_sessions: Iterable[Sequence[object]], artifacts: SeparabilityArtifacts) -> List[dict]:
    """Score every session in ``raw_sessions`` with ``score_session``.

    Args:
        raw_sessions: Iterable of sessions, each a sequence of raw events.
        artifacts: Artifacts shared by all sessions.

    Returns:
        One result dict per session, in input order.
    """
    results: List[dict] = []
    for session_events in raw_sessions:
        results.append(score_session(session_events, artifacts))
    return results