feat: joint loader

This commit is contained in:
2026-01-13 16:42:50 +01:00
parent af23d2f736
commit 87a35fad2c
2 changed files with 70 additions and 9 deletions

View File

@@ -1,5 +1,5 @@
from experiments.agents.base import Agent
from loader import Loader, AgentLoader
from loader import Loader, AgentLoader, JointLoader
from collections import defaultdict
from typing import Dict, List, Tuple, Set
import numpy as np
@@ -109,6 +109,28 @@ class AgentBehaviorModel(BehaviorModel):
trajectories.append(states)
return trajectories
class JointBehaviorModel(BehaviorModel):
    """Behavior model over the merged human + agent event data.

    Data is pulled from a JointLoader built from the two directories.
    Events are read as flat objects exposing ``page``, ``productId``,
    ``eventName`` and ``ts`` attributes (per the original author's note,
    JointLoader unwraps records to PayloadModel).
    """

    def __init__(self, human_dir: str = DIR, agent_dir: str = AGENT_DIR):
        """Load joint data and entries; the MDP starts out unset (None)."""
        joint_loader = JointLoader(human_dir, agent_dir)
        self.loader = joint_loader
        self.data = joint_loader.get_data()
        loaded_entries, entry_count = joint_loader.get_entries()
        self.entries = loaded_entries
        self.num_entries = entry_count
        self.mdp = None

    def _state_repr(self, evt) -> str:
        """Return the ``page|productId|eventName`` state label for *evt*.

        A falsy page becomes 'unk' and a falsy productId becomes 'none'.
        """
        page_part = evt.page or 'unk'
        product_part = evt.productId or 'none'
        return f"{page_part}|{product_part}|{evt.eventName}"

    def _extract_sessions(self):
        """Return one ordered list of state labels per usable session.

        Sessions with fewer than two events are dropped. The remaining
        events are ordered by their ``ts`` attribute before being mapped
        to state labels.
        """
        def by_timestamp(event):
            # NOTE(review): ts is presumably an ISO-format string, which
            # orders chronologically when compared as text - confirm that
            # all sources share one format/timezone.
            return event.ts

        return [
            [self._state_repr(event)
             for event in sorted(session_events, key=by_timestamp)]
            for _session_id, session_events in self.data.items()
            if len(session_events) >= 2
        ]
def aggregate_event_transitions(mdp: Dict) -> Dict[str, Dict[str, float]]:
"""aggregate state transitions by event type and normalize"""
evt_trans = defaultdict(lambda: defaultdict(float))
@@ -209,3 +231,11 @@ if __name__ == "__main__":
print(f"\nMost divergent event types:")
for evt, kl in kl_divs:
print(f" {evt}: {kl:.4f}")
# build joint model (combined distribution)
print("\n=== Joint Model (Human + Agent Combined) ===")
joint_model = JointBehaviorModel()
joint_mdp = joint_model.build_MDP()
print(f"Built joint MDP: {joint_mdp['num_states']} states, {sum(len(t) for t in joint_mdp['transitions'].values())} transitions")
if joint_mdp['states']:
visualize_mdp(joint_model, threshold=0.05, output="joint_mdp_viz", fmt="pdf", export_dot=True)