diff --git a/sim/rl/behavior_loader/loader.py b/sim/rl/behavior_loader/loader.py index 620576c..3336956 100644 --- a/sim/rl/behavior_loader/loader.py +++ b/sim/rl/behavior_loader/loader.py @@ -1,6 +1,6 @@ import os -from pydantic import BaseModel as Base import json +from pydantic import BaseModel as Base class PayloadModel(Base): sessionId: str @@ -30,6 +30,9 @@ class InteractionModel(Base): key: dict value: ValueModel +def _is_admin(page: str | None) -> bool: + return page is not None and page.startswith("/admin/") + class Loader: def __init__(self, src_dir: str): self.src_dir = src_dir @@ -37,17 +40,13 @@ class Loader: if not self.entries: raise ValueError("empty directory") self.data = self._load_sessions() - def _is_admin_page(self, interaction: InteractionModel) -> bool: - page = interaction.value.payload.page - return page and page.startswith("/admin/") - def _load_sessions(self) -> dict: sessions = {} for entry in self.entries: - int_path = f"{self.src_dir}/{entry}/int.json" - raw = json.load(open(int_path)) + with open(f"{self.src_dir}/{entry}/int.json") as f: + raw = json.load(f) ints = [InteractionModel(**i) for i in raw] - sessions[entry] = [i for i in ints if not self._is_admin_page(i)] + sessions[entry] = [i for i in ints if not _is_admin(i.value.payload.page)] return sessions def get_data(self) -> dict: @@ -57,40 +56,29 @@ class Loader: return self.entries, len(self.entries) class AgentLoader(Loader): - """Loader for agent interaction data with simplified schema (direct PayloadModel format)""" - - def _is_admin_page_simple(self, interaction: PayloadModel) -> bool: - return interaction.page and interaction.page.startswith("/admin/") - def _load_sessions(self) -> dict: sessions = {} for entry in self.entries: - int_path = f"{self.src_dir}/{entry}/int.json" - raw = json.load(open(int_path)) + with open(f"{self.src_dir}/{entry}/int.json") as f: + raw = json.load(f) ints = [PayloadModel(**i) for i in raw] - sessions[entry] = [i for i in ints if not self._is_admin_page_simple(i)] + sessions[entry] = [i for i in ints if not _is_admin(i.page)] return sessions class JointLoader: - """Loader for combined human (Kafka) and agent (direct) data without discrimination""" - def __init__(self, human_dir: str, agent_dir: str): - self.human_dir = human_dir - self.agent_dir = agent_dir self.human_loader = Loader(human_dir) self.agent_loader = AgentLoader(agent_dir) - self.data = self._load_joint_sessions() + self.data = self._merge() self.entries = list(self.data.keys()) - def _load_joint_sessions(self) -> dict: - sessions = {} - # load human sessions (unwrap from Kafka format to PayloadModel) - for sid, evts in self.human_loader.get_data().items(): - sessions[f"human_{sid}"] = [evt.value.payload for evt in evts] - # load agent sessions (already PayloadModel) - for sid, evts in self.agent_loader.get_data().items(): - sessions[f"agent_{sid}"] = evts - return sessions + def _merge(self) -> dict: + return { + **{f"human_{sid}": [e.value.payload for e in evts] + for sid, evts in self.human_loader.get_data().items()}, + **{f"agent_{sid}": evts + for sid, evts in self.agent_loader.get_data().items()} + } def get_data(self) -> dict: return self.data @@ -99,16 +87,11 @@ class JointLoader: return self.entries, len(self.entries) if __name__ == "__main__": - AGENT_DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/agents/collected_data/" - loader = AgentLoader(AGENT_DIR) - _, n = loader.get_entries() - print(f"Loaded {n} agent sessions from {AGENT_DIR}") + agent_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/agents/collected_data/" + human_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/collected_data/" - HUMAN_DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/collected_data/" - loader = Loader(HUMAN_DIR) - _, n = loader.get_entries() - print(f"Loaded {n} human sessions from {HUMAN_DIR}") - - joint_loader = JointLoader(HUMAN_DIR, AGENT_DIR) - _, n = joint_loader.get_entries() - print(f"Loaded {n} total sessions (combined) from joint loader") + for name, cls, path in [("agent", AgentLoader, agent_dir), + ("human", Loader, human_dir), + ("joint", lambda d: JointLoader(human_dir, d), agent_dir)]: + ldr = cls(path) if name != "joint" else cls(agent_dir) + print(f"Loaded {len(ldr.get_entries()[0])} {name} sessions")