mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: introduction of agentinc MDPs and KL divergence of > 2
This commit is contained in:
@@ -56,7 +56,27 @@ class Loader:
|
||||
def get_entries(self) -> tuple[list[str], int]:
|
||||
return self.entries, len(self.entries)
|
||||
|
||||
class AgentLoader(Loader):
|
||||
"""Loader for agent interaction data with simplified schema (direct PayloadModel format)"""
|
||||
|
||||
def _is_admin_page_simple(self, interaction: PayloadModel) -> bool:
|
||||
return interaction.page and interaction.page.startswith("/admin/")
|
||||
|
||||
def _load_sessions(self) -> dict:
|
||||
sessions = {}
|
||||
for entry in self.entries:
|
||||
int_path = f"{self.src_dir}/{entry}/int.json"
|
||||
raw = json.load(open(int_path))
|
||||
ints = [PayloadModel(**i) for i in raw]
|
||||
sessions[entry] = [i for i in ints if not self._is_admin_page_simple(i)]
|
||||
return sessions
|
||||
|
||||
if __name__ == "__main__":
|
||||
DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/agents/collected_data/"
|
||||
loader = AgentLoader(DIR)
|
||||
_, n = loader.get_entries()
|
||||
print(f"Loaded {n} sessions from {DIR}")
|
||||
|
||||
DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/collected_data/"
|
||||
loader = Loader(DIR)
|
||||
_, n = loader.get_entries()
|
||||
|
||||
Reference in New Issue
Block a user