class separaiblity significance

This commit is contained in:
2026-02-28 21:38:46 +01:00
parent 8f20359c8c
commit 233ce3be34
5 changed files with 285 additions and 57 deletions

View File

@@ -11,7 +11,7 @@ from pathlib import Path
# import lib utilities for optional use - models keep their own _state_repr for backwards compat
# with the specific event structure (evt.value.payload)
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / 'lib'))
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "lib"))
try:
from lib.state import make_state_repr as lib_make_state_repr
from lib.features import transition_histogram as lib_transition_histogram
@@ -37,7 +37,8 @@ class BehaviorModel:
def _extract_sessions(self) -> List[List[str]]:
trajs = []
for evts in self.data.values():
if len(evts) < 2: continue
if len(evts) < 2:
continue
states = [self._state_repr(e) for e in sorted(evts, key=self._sort_key)]
trajs.append(states)
return trajs
@@ -59,8 +60,10 @@ class BehaviorModel:
return rwd
def _normalize_trans(self, cnts: Dict) -> Dict:
return {s: {s_n: cnt/sum(nxt.values()) for s_n, cnt in nxt.items()}
for s, nxt in cnts.items()}
return {
s: {s_n: cnt / sum(nxt.values()) for s_n, cnt in nxt.items()}
for s, nxt in cnts.items()
}
def build_MDP(self) -> Dict:
trajs = self._extract_sessions()
@@ -69,34 +72,40 @@ class BehaviorModel:
state_rwd = self._calc_rewards(trajs)
self.mdp = {
'states': sorted(states),
'num_states': len(states),
'transitions': trans_prob,
'state_values': {s: np.mean(r) for s, r in state_rwd.items()},
'state_rewards': state_rwd,
'trans_counts': trans_cnt,
"states": sorted(states),
"num_states": len(states),
"transitions": trans_prob,
"state_values": {s: np.mean(r) for s, r in state_rwd.items()},
"state_rewards": state_rwd,
"trans_counts": trans_cnt,
}
return self.mdp
def transition_prob(self, s: str, s_next: str) -> float:
if not self.mdp: raise ValueError("build MDP first")
return self.mdp['transitions'].get(s, {}).get(s_next, 0.0)
if not self.mdp:
raise ValueError("build MDP first")
return self.mdp["transitions"].get(s, {}).get(s_next, 0.0)
def state_value(self, s: str) -> float:
if not self.mdp: raise ValueError("build MDP first")
return self.mdp['state_values'].get(s, 0.0)
if not self.mdp:
raise ValueError("build MDP first")
return self.mdp["state_values"].get(s, 0.0)
def sample_traj(self, start: str, max_len: int = 50) -> List[str]:
if not self.mdp: raise ValueError("build MDP first")
if not self.mdp:
raise ValueError("build MDP first")
path, curr = [start], start
for _ in range(max_len):
nxt = self.mdp['transitions'].get(curr, {})
if not nxt: break
nxt = self.mdp["transitions"].get(curr, {})
if not nxt:
break
curr = np.random.choice(list(nxt.keys()), p=list(nxt.values()))
path.append(curr)
return path
def extract_trajectory_features(self, events: List, max_trans_dim: int = 50) -> np.ndarray:
def extract_trajectory_features(
self, events: List, max_trans_dim: int = 50
) -> np.ndarray:
"""Convert trajectory to feature vector using MDP structure for contrastive learning"""
if not self.mdp:
self.build_MDP()
@@ -108,7 +117,11 @@ class BehaviorModel:
trans_counts = defaultdict(int)
for s, s_next in zip(states, states[1:]):
trans_counts[(s, s_next)] += 1
all_trans = [(s, t) for s in self.mdp['states'] for t in self.mdp['transitions'].get(s, {}).keys()]
all_trans = [
(s, t)
for s in self.mdp["states"]
for t in self.mdp["transitions"].get(s, {}).keys()
]
trans_vec = [trans_counts.get(tr, 0) for tr in all_trans[:max_trans_dim]]
trans_vec = trans_vec + [0] * (max_trans_dim - len(trans_vec)) # pad
total_trans = sum(trans_counts.values()) or 1
@@ -116,11 +129,13 @@ class BehaviorModel:
# state coverage ratio
visited = set(states)
features.append(len(visited) / max(self.mdp['num_states'], 1))
features.append(len(visited) / max(self.mdp["num_states"], 1))
# temporal entropy of transitions
if len(states) > 1:
trans_probs = [self.transition_prob(s, s_n) for s, s_n in zip(states, states[1:])]
trans_probs = [
self.transition_prob(s, s_n) for s, s_n in zip(states, states[1:])
]
entropy = -sum(p * np.log(p + 1e-10) for p in trans_probs if p > 0)
features.append(entropy / max(len(states), 1))
else:
@@ -150,6 +165,7 @@ class AgentBehaviorModel(BehaviorModel):
def _sort_key(self, evt):
return evt.ts
class JointBehaviorModel(BehaviorModel):
def __init__(self, human_dir: str, agent_dir: str):
self.loader = JointLoader(human_dir, agent_dir)
@@ -163,73 +179,164 @@ class JointBehaviorModel(BehaviorModel):
def _sort_key(self, evt):
return evt.ts
def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
evt_trans = defaultdict(lambda: defaultdict(float))
for s, trans in mdp['transitions'].items():
src = s.split('|')[2]
for s, trans in mdp["transitions"].items():
src = s.split("|")[2]
for s_next, prob in trans.items():
dst = s_next.split('|')[2]
dst = s_next.split("|")[2]
evt_trans[src][dst] += prob
for src in evt_trans:
total = sum(evt_trans[src].values())
if total > 0:
evt_trans[src] = {dst: p/total for dst, p in evt_trans[src].items()}
evt_trans[src] = {dst: p / total for dst, p in evt_trans[src].items()}
return dict(evt_trans)
def visualize_mdp(model: BehaviorModel, threshold: float = 0.05, output: str = "mdp_graph",
fmt: str = "svg", view: bool = False, export_dot: bool = False):
if not model.mdp: raise ValueError("build MDP first")
def visualize_mdp(
model: BehaviorModel,
threshold: float = 0.05,
output: str = "mdp_graph",
fmt: str = "svg",
view: bool = False,
export_dot: bool = False,
):
if not model.mdp:
raise ValueError("build MDP first")
evt_trans = aggregate_event_transitions(model.mdp)
g = graphviz.Digraph(format=fmt)
g.attr(rankdir='LR', size='30')
g.attr('node', shape='circle', width='1', height='1')
g.attr(rankdir="LR", size="30")
g.attr("node", shape="circle", width="1", height="1")
events = set(evt_trans.keys()) | {e for trans in evt_trans.values() for e in trans.keys()}
events = set(evt_trans.keys()) | {
e for trans in evt_trans.values() for e in trans.keys()
}
for evt in events:
g.node(evt)
for src, dsts in evt_trans.items():
for dst, prob in dsts.items():
if prob > threshold:
g.edge(src, dst, label=f'{prob:.2f}')
g.edge(src, dst, label=f"{prob:.2f}")
g.render(output, view=view, cleanup=True)
print(f"Saved MDP graph to {output}.{fmt}")
if export_dot:
with open(f"{output}.dot", 'w') as f:
with open(f"{output}.dot", "w") as f:
f.write(g.source)
print(f"Exported DOT source to {output}.dot")
return g
def kl_divergence(p: Dict[str, float], q: Dict[str, float]) -> float:
eps = 1e-10
# p + log(p / q) summed over all keys in P
return sum((p[k] + eps) * np.log((p[k] + eps) / (q.get(k, 0.0) + eps)) for k in p)
def _build_subset_mdp(model: BehaviorModel, session_ids: List) -> Dict:
trajs = []
for sid in session_ids:
evts = model.data.get(sid, [])
if len(evts) < 2:
continue
states = [model._state_repr(e) for e in sorted(evts, key=model._sort_key)]
trajs.append(states)
trans_cnt, _ = model._calc_transitions(trajs)
return {"transitions": model._normalize_trans(trans_cnt)}
def _avg_event_kl(
src_evt: Dict[str, Dict[str, float]], dst_evt: Dict[str, Dict[str, float]]
) -> float:
common = set(src_evt.keys()) & set(dst_evt.keys())
if not common:
return 0.0
return float(np.mean([kl_divergence(src_evt[e], dst_evt[e]) for e in common]))
def bootstrap_intra_class_divergence(
model: BehaviorModel,
n_bootstrap: int = 100,
seed: int = 42,
) -> Dict[str, float]:
session_ids = list(model.data.keys())
n = len(session_ids)
if n < 2:
return {
"mean": 0.0,
"std": 0.0,
"q05": 0.0,
"q95": 0.0,
"n_bootstrap": 0,
"scores": [],
"available": False,
"num_sessions": int(n),
}
half = n // 2
rng = np.random.default_rng(seed)
scores = []
for _ in range(n_bootstrap):
perm = rng.permutation(session_ids)
split_a, split_b = perm[:half], perm[half:]
mdp_a = _build_subset_mdp(model, list(split_a))
mdp_b = _build_subset_mdp(model, list(split_b))
score = _avg_event_kl(
aggregate_event_transitions(mdp_a),
aggregate_event_transitions(mdp_b),
)
scores.append(score)
arr = np.array(scores, dtype=float)
return {
"mean": float(np.mean(arr)),
"std": float(np.std(arr)),
"q05": float(np.quantile(arr, 0.05)),
"q95": float(np.quantile(arr, 0.95)),
"n_bootstrap": int(n_bootstrap),
"scores": arr.tolist(),
"available": True,
"num_sessions": int(n),
}
if __name__ == "__main__":
base_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments"
human_dir, agent_dir = f"{base_dir}/collected_data/", f"{base_dir}/agents/collected_data/"
human_dir, agent_dir = (
f"{base_dir}/collected_data/",
f"{base_dir}/agents/collected_data/",
)
human_model = BehaviorModel(human_dir)
human_mdp = human_model.build_MDP()
print(f"Built MDP: {human_mdp['num_states']} states, "
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions")
if not human_mdp['states']:
print(
f"Built MDP: {human_mdp['num_states']} states, "
f"{sum(len(t) for t in human_mdp['transitions'].values())} transitions"
)
if not human_mdp["states"]:
exit("No states found")
visualize_mdp(human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True)
visualize_mdp(
human_model, threshold=0.05, output="human_mdp_viz", fmt="pdf", export_dot=True
)
agent_model = AgentBehaviorModel(agent_dir)
agent_mdp = agent_model.build_MDP()
print(f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions")
if not agent_mdp['states']:
print(
f"AGENT... Built MDP: {agent_mdp['num_states']} states, "
f"{sum(len(t) for t in agent_mdp['transitions'].values())} transitions"
)
if not agent_mdp["states"]:
exit("No states found")
visualize_mdp(agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True)
visualize_mdp(
agent_model, threshold=0.05, output="agent_mdp_viz", fmt="pdf", export_dot=True
)
human_evt = aggregate_event_transitions(human_mdp)
agent_evt = aggregate_event_transitions(agent_mdp)
@@ -239,8 +346,11 @@ if __name__ == "__main__":
if not common:
exit("No common event types for KL divergence analysis")
kl_divs = sorted([(e, kl_divergence(human_evt[e], agent_evt[e])) for e in common],
key=lambda x: x[1], reverse=True)
kl_divs = sorted(
[(e, kl_divergence(human_evt[e], agent_evt[e])) for e in common],
key=lambda x: x[1],
reverse=True,
)
print(f"Average KL divergence: {np.mean([kl for _, kl in kl_divs]):.4f}")
print("\nMost divergent event types:")
@@ -250,9 +360,55 @@ if __name__ == "__main__":
print("\n=== Joint Model (Human + Agent Combined) ===")
joint_model = JointBehaviorModel(human_dir, agent_dir)
joint_mdp = joint_model.build_MDP()
print(f"Built joint MDP: {joint_mdp['num_states']} states, "
f"{sum(len(t) for t in joint_mdp['transitions'].values())} transitions")
if joint_mdp['states']:
visualize_mdp(joint_model, threshold=0.05, output="joint_mdp_viz", fmt="pdf", export_dot=True)
print(
f"Built joint MDP: {joint_mdp['num_states']} states, "
f"{sum(len(t) for t in joint_mdp['transitions'].values())} transitions"
)
if joint_mdp["states"]:
visualize_mdp(
joint_model,
threshold=0.05,
output="joint_mdp_viz",
fmt="pdf",
export_dot=True,
)
# TODO: setup intra class divergence as baseline for evaluating and adding significance to the divergence which we observe across class
inter_class_avg = float(np.mean([kl for _, kl in kl_divs]))
human_intra = bootstrap_intra_class_divergence(
human_model, n_bootstrap=100, seed=42
)
agent_intra = bootstrap_intra_class_divergence(
agent_model, n_bootstrap=100, seed=43
)
pooled_scores = human_intra["scores"] + agent_intra["scores"]
if not pooled_scores:
pooled_scores = [0.0]
pooled_null = np.array(pooled_scores, dtype=float)
p_empirical = float(
(np.sum(pooled_null >= inter_class_avg) + 1) / (len(pooled_null) + 1)
)
print("\nIntra-class KL bootstrap baseline:")
if human_intra["available"]:
print(
f" Human split KL: {human_intra['mean']:.4f} +- {human_intra['std']:.4f} "
f"(5-95%: {human_intra['q05']:.4f}-{human_intra['q95']:.4f}, n_sessions={human_intra['num_sessions']})"
)
else:
print(
f" Human split KL: unavailable (need >=2 sessions, got {human_intra['num_sessions']})"
)
if agent_intra["available"]:
print(
f" Agent split KL: {agent_intra['mean']:.4f} +- {agent_intra['std']:.4f} "
f"(5-95%: {agent_intra['q05']:.4f}-{agent_intra['q95']:.4f}, n_sessions={agent_intra['num_sessions']})"
)
else:
print(
f" Agent split KL: unavailable (need >=2 sessions, got {agent_intra['num_sessions']})"
)
print(f" Between-class KL: {inter_class_avg:.4f}")
print(
f" Lift vs pooled intra mean: {inter_class_avg / max(float(np.mean(pooled_null)), 1e-10):.2f}x"
)
print(f" Empirical p-value (inter > intra): {p_empirical:.4f}")