mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
strong dataset gathering
This commit is contained in:
99
sim/strong_learner/data.py
Normal file
99
sim/strong_learner/data.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import os, requests, py7zr
|
||||
import pandas as pd
|
||||
from typing import Generator
|
||||
try:
|
||||
from sim.rl.behavior_loader.loader import PayloadModel, ValueModel, InteractionModel, Loader
|
||||
except ImportError:
|
||||
from loader import PayloadModel, ValueModel, InteractionModel, Loader
|
||||
|
||||
class YooChooseLoader(Loader):
    """Load the YooChoose click and buy logs as per-session ``InteractionModel`` lists.

    On construction the loader makes sure both ``.dat`` files are present under
    ``root_dir`` (downloading and extracting the archive at ``URL`` if not),
    then reads them in chunks and groups the resulting interactions by
    session id.

    Attributes set by ``__init__``:
        root, chunk_size, max_sessions: the constructor arguments.
        click_path, buy_path: paths of the two data files under ``root``.
        data: ``{session_id: [InteractionModel, ...]}``, each list sorted by
            the payload ``ts``.
        entries: list of the session ids in ``data``.
    """

    URL = "https://s3-eu-west-1.amazonaws.com/yc-rdata/yoochoose-data.7z"
    CLICK_COLS = ['session_id', 'ts', 'item_id', 'category']
    BUY_COLS = ['session_id', 'ts', 'item_id', 'price', 'quantity']

    def __init__(self, root_dir: str = "data/yoochoose", chunk_size: int = 500_000, max_sessions: int = 1000):
        """Prepare the data files if needed and load up to ``max_sessions`` sessions.

        Args:
            root_dir: Directory that holds (or will receive) the data files.
            chunk_size: Number of CSV rows pandas reads per chunk.
            max_sessions: Maximum number of distinct sessions to keep; a falsy
                value (``0`` / ``None``) keeps every session.
        """
        self.root = root_dir
        self.chunk_size = chunk_size
        self.max_sessions = max_sessions
        self.click_path = f"{root_dir}/yoochoose-clicks.dat"
        self.buy_path = f"{root_dir}/yoochoose-buys.dat"
        # Both files are read by stream(); fetch the archive if either is missing
        # rather than failing later with FileNotFoundError in stream_buys().
        if not (os.path.exists(self.click_path) and os.path.exists(self.buy_path)):
            self._setup()
        self.data = self._load_sessions(max_sessions)
        self.entries = list(self.data.keys())

    def _setup(self):
        """Download the archive at ``URL`` and extract it into ``self.root``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        os.makedirs(self.root, exist_ok=True)
        zip_path = f"{self.root}/temp.7z"
        try:
            # timeout bounds connect / per-read waits so a dead connection
            # cannot hang the constructor forever.
            with requests.get(self.URL, stream=True, timeout=60) as r:
                # Fail fast instead of saving an HTML/XML error body as "temp.7z".
                r.raise_for_status()
                with open(zip_path, 'wb') as f:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
            with py7zr.SevenZipFile(zip_path, 'r') as z:
                z.extractall(self.root)
        finally:
            # Remove the temporary archive on success and on failure alike.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def _make_interaction(self, sid: str, ts: str, item_id: str, event: str, page: str, meta: dict) -> InteractionModel:
        """Wrap one dataset row in an ``InteractionModel``.

        Only the payload carries row data; the surrounding envelope fields are
        filled with fixed placeholder values.

        Args:
            sid: Session id.
            ts: Row timestamp, passed through unchanged.
            item_id: Item id, stored as ``productId``.
            event: Event name stored as ``eventName``.
            page: Value stored as ``page``.
            meta: Extra key/value pairs stored as ``metadata``.
        """
        payload = PayloadModel(
            sessionId=sid, experimentId=None, eventName=event,
            page=page, productId=item_id, metadata=meta,
            storeMode="yoochoose", userAgent="dataset", ts=ts
        )
        return InteractionModel(
            partitionID=0, offset=0, timestamp=0, compression="",
            isTransactional=False, headers=[], key={},
            value=ValueModel(payload=payload, encoding="json", isPayloadNull=False, schemaId=1, size=0)
        )

    def _parse_category(self, cat) -> str:
        """Map a raw category value to a page label.

        Missing values and ``0`` give ``"unknown"``, ``"S"`` gives
        ``"special_offer"``, integers 1-12 give ``"category_<n>"``, any other
        integer gives ``"brand_<n>"``, and anything else is returned as
        ``str(cat)``.

        The value may arrive as ``str``, ``int`` or ``float`` depending on the
        dtype pandas picked for the column, so it is normalised to text before
        being compared.
        """
        if pd.isna(cat):
            return "unknown"
        # A float column (e.g. one containing NaN) turns 0 into 0.0.
        if isinstance(cat, float) and cat.is_integer():
            cat = int(cat)
        text = str(cat).strip()
        if text == "0":
            return "unknown"
        if text == "S":
            return "special_offer"
        try:
            n = int(text)
        except ValueError:
            return str(cat)
        return f"category_{n}" if 1 <= n <= 12 else f"brand_{n}"

    def _iter_rows(self, path: str, columns: list, dtype: dict | None = None):
        """Yield the rows of a header-less CSV file as named tuples, reading in chunks."""
        with pd.read_csv(path, names=columns, chunksize=self.chunk_size, header=None, dtype=dtype) as reader:
            for chunk in reader:
                yield from chunk.itertuples(index=False)

    def stream_clicks(self) -> Generator[InteractionModel, None, None]:
        """Yield one ``view_item_page`` interaction per row of the clicks file."""
        # Read category as text so the same label is not parsed as an int in
        # one chunk and as a str in another.
        for r in self._iter_rows(self.click_path, self.CLICK_COLS, dtype={"category": str}):
            yield self._make_interaction(
                str(r.session_id), r.ts, str(r.item_id),
                "view_item_page", self._parse_category(r.category), {}
            )

    def stream_buys(self) -> Generator[InteractionModel, None, None]:
        """Yield one ``purchase_complete`` interaction per row of the buys file."""
        for r in self._iter_rows(self.buy_path, self.BUY_COLS):
            yield self._make_interaction(
                str(r.session_id), r.ts, str(r.item_id),
                "purchase_complete", "/checkout", {"price": r.price, "quantity": r.quantity}
            )

    def stream(self) -> Generator[InteractionModel, None, None]:
        """Yield every click interaction, then every buy interaction."""
        yield from self.stream_clicks()
        yield from self.stream_buys()

    def _load_sessions(self, max_sessions: int | None = None) -> dict:
        """Group streamed interactions by session id.

        The first ``max_sessions`` distinct session ids encountered are kept;
        interactions of any later-seen session are dropped. A falsy
        ``max_sessions`` disables the limit.

        Returns:
            ``{session_id: [InteractionModel, ...]}`` with each list sorted by
            the payload ``ts``.
        """
        sessions = {}
        for interaction in self.stream():
            sid = interaction.value.payload.sessionId
            if sid not in sessions:
                # Keep scanning after the cap is reached: later rows can still
                # belong to sessions that were already admitted.
                if max_sessions and len(sessions) >= max_sessions:
                    continue
                sessions[sid] = []
            sessions[sid].append(interaction)
        for events in sessions.values():
            events.sort(key=lambda x: x.value.payload.ts)
        return sessions

    def get_data(self) -> dict:
        """Return the ``{session_id: [InteractionModel, ...]}`` mapping."""
        return self.data

    def get_entries(self) -> tuple[list[str], int]:
        """Return the list of loaded session ids together with its length."""
        return self.entries, len(self.entries)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke run: load a small sample and report how many events of each kind it holds.
    loader = YooChooseLoader(max_sessions=100)
    event_names = [
        evt.value.payload.eventName
        for session_events in loader.get_data().values()
        for evt in session_events
    ]
    views = event_names.count("view_item_page")
    purchases = event_names.count("purchase_complete")
    print(f"Loaded {len(loader.entries)} sessions: {views} view_item_page, {purchases} purchase_complete")
|
||||
Reference in New Issue
Block a user