mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
migrating weak learning
This commit is contained in:
@@ -1,30 +0,0 @@
|
|||||||
from sim.rl.behavior_loader.loader import AgentLoader, Loader, JointLoader
|
|
||||||
from sim.rl.behavior_loader.loader import PayloadModel
|
|
||||||
from arch import WeakClassifier
|
|
||||||
|
|
||||||
import os

# Repository root. Defaults to the original developer checkout so behaviour is
# unchanged when PHANTOM_ROOT is unset or empty; set PHANTOM_ROOT to run the
# script from any other checkout location.
_PHANTOM_ROOT = (
    os.environ.get("PHANTOM_ROOT") or "/home/velocitatem/Documents/Projects/PHANTOM"
).rstrip("/")

# Directory handed to JointLoader as the agent-session source (trailing slash kept).
agent_dir = f"{_PHANTOM_ROOT}/experiments/agents/collected_data/"
# Directory handed to JointLoader as the human-session source (trailing slash kept).
human_dir = f"{_PHANTOM_ROOT}/experiments/collected_data/"
|
|
||||||
|
|
||||||
def augment_trajectory(trajectory : list[PayloadModel], augmentation_rate: float = 0.1) -> list[PayloadModel]:
    """Placeholder augmentation hook: hands ``trajectory`` back untouched.

    ``augmentation_rate`` is accepted but not used yet, and the very same list
    object that was passed in is returned.

    Planned augmentations (none implemented so far):
      * return a sub-trajectory window of the original trajectory
      * insert random noise events
      * shuffle a few events (pick a few indices and swap each with its
        ``i + 1`` neighbour)
      * adjust metadata
    """
    unchanged = trajectory
    return unchanged
|
|
||||||
|
|
||||||
|
|
||||||
def train():
    """Training entry-point stub: performs no work and returns ``None``."""
    return None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Load the combined human + agent sessions from the two data directories.
    loader = JointLoader(human_dir, agent_dir)
    data = loader.get_data()

    # Only the count is reported; the entries themselves are not used below.
    _entries, entry_count = loader.get_entries()
    print(f"Loaded {entry_count} entries")

    # TODO: augment

    # Fit the weak classifier directly on the loaded data.
    classifier = WeakClassifier()
    classifier.fit(data)
|
|
||||||
246
experiments/ml/weak_train.py
Normal file
246
experiments/ml/weak_train.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
import sys
|
||||||
|
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/sim/rl/behavior_loader")
|
||||||
|
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/experiments/ml")
|
||||||
|
|
||||||
|
from sim.rl.behavior_loader.loader import AgentLoader, Loader, JointLoader, PayloadModel
|
||||||
|
from sim.rl.behavior_loader.models import JointBehaviorModel
|
||||||
|
from arch import ContrastiveWeakClassifier, contrastive_loss, featurize_trajectory
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from copy import deepcopy
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch.optim import Adam
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
import os

# Repository root. Defaults to the original developer checkout so behaviour is
# unchanged when PHANTOM_ROOT is unset or empty; set PHANTOM_ROOT to run the
# script from any other checkout location.
_PHANTOM_ROOT = (
    os.environ.get("PHANTOM_ROOT") or "/home/velocitatem/Documents/Projects/PHANTOM"
).rstrip("/")

# Parent directory for TensorBoard logs; train() writes under "<RUNS_DIR>/train/"
# and evaluate_loocv() under "<RUNS_DIR>/eval/".
RUNS_DIR = f"{_PHANTOM_ROOT}/experiments/ml/runs"
# Directory handed to JointLoader / JointBehaviorModel as the agent-session source.
agent_dir = f"{_PHANTOM_ROOT}/experiments/agents/collected_data/"
# Directory handed to JointLoader / JointBehaviorModel as the human-session source.
human_dir = f"{_PHANTOM_ROOT}/experiments/collected_data/"
|
||||||
|
|
||||||
|
|
||||||
|
def _perturb_ts(evt: "PayloadModel", jitter_ms: int = 500) -> "PayloadModel":
    """Return a deep copy of ``evt`` whose timestamp is shifted by random jitter.

    The input event is never modified. ``evt.ts`` is read as an ISO-8601 string
    (a trailing ``Z`` is rewritten to ``+00:00`` before parsing), shifted by a
    random whole number of milliseconds drawn from ``[-jitter_ms, jitter_ms]``,
    and stored on the copy in ``datetime.isoformat()`` form.

    Args:
        evt: Event exposing a ``ts`` attribute.
        jitter_ms: Maximum absolute shift in milliseconds.

    Returns:
        The copied event. This is best-effort: when the timestamp is missing,
        not a string, unparsable, or cannot be shifted or assigned, the copy is
        returned with its timestamp untouched.
    """
    new_evt = deepcopy(evt)
    try:
        ts = datetime.fromisoformat(evt.ts.replace('Z', '+00:00'))
        delta = timedelta(milliseconds=random.randint(-jitter_ms, jitter_ms))
        new_evt.ts = (ts + delta).isoformat()
    except (AttributeError, TypeError, ValueError, OverflowError):
        # Tolerate malformed or immutable timestamps only. A bare ``except``
        # here would also swallow KeyboardInterrupt / SystemExit and hide
        # unrelated bugs, so the handler is limited to data-shape errors.
        pass
    return new_evt
|
||||||
|
|
||||||
|
|
||||||
|
def augment_trajectory(trajectory: List[PayloadModel], rate: float = 0.1) -> List[PayloadModel]:
    """Produce one randomly chosen augmented view of ``trajectory``.

    Trajectories with fewer than two events are returned as-is. Otherwise one
    strategy is picked uniformly at random:

    * ``window``  - a contiguous slice covering 70-100% of the events
      (never fewer than two).
    * ``shuffle`` - a copy in which each adjacent pair is swapped with
      probability ``rate`` during a single left-to-right pass.
    * ``drop``    - each event is removed with probability ``rate``; if fewer
      than two events survive, the first two original events are used instead.
    * ``noise``   - every event is replaced by ``_perturb_ts(event, jitter_ms=500)``.
    """
    n_events = len(trajectory)
    if n_events < 2:
        return trajectory

    strategy = random.choice(['window', 'shuffle', 'noise', 'drop'])

    if strategy == 'window':
        shortest = max(2, int(n_events * 0.7))
        window_len = random.randint(shortest, n_events)
        offset = random.randint(0, n_events - window_len)
        return trajectory[offset:offset + window_len]

    if strategy == 'shuffle':
        shuffled = list(trajectory)
        for left in range(n_events - 1):
            if random.random() < rate:
                right = left + 1
                shuffled[left], shuffled[right] = shuffled[right], shuffled[left]
        return shuffled

    if strategy == 'drop':
        kept = []
        for event in trajectory:
            if random.random() > rate:
                kept.append(event)
        if len(kept) >= 2:
            return kept
        return trajectory[:2]

    if strategy == 'noise':
        return [_perturb_ts(event, jitter_ms=500) for event in trajectory]

    return trajectory
|
||||||
|
|
||||||
|
|
||||||
|
class TripletDataset(Dataset):
    """Serve (anchor, positive, negative, label) samples built lazily per index.

    For each index the anchor is a stored session, the positive is an augmented
    view of that same session, and the negative is a randomly chosen session
    from the opposite class. Class membership is decided purely by the session
    id prefix (``human_`` / ``agent_``). The dataset reports ``multiplier``
    times as many items as there are sessions, so each session is visited
    several times per epoch with fresh augmentations and negatives.
    """

    def __init__(self, data: Dict[str, List[PayloadModel]], mdp: Optional[Dict], augment_fn, input_dim: int = 64, multiplier: int = 10):
        self.sessions = list(data.items())

        # Positions (into self.sessions) of each class, keyed off the id prefix.
        self.human_ids = []
        self.agent_ids = []
        for position, (session_id, _events) in enumerate(self.sessions):
            if session_id.startswith('human_'):
                self.human_ids.append(position)
            if session_id.startswith('agent_'):
                self.agent_ids.append(position)

        self.mdp = mdp
        self.augment = augment_fn
        self.input_dim = input_dim
        self.multiplier = multiplier

        # A negative can only be drawn when both classes are represented.
        both_present = bool(self.human_ids) and bool(self.agent_ids)
        if not both_present:
            raise ValueError(f"Need both human ({len(self.human_ids)}) and agent ({len(self.agent_ids)}) sessions")

    def __len__(self) -> int:
        n_sessions = len(self.sessions)
        return n_sessions * self.multiplier

    def __getitem__(self, idx: int):
        # Indices wrap around the session list so idx may exceed len(sessions).
        session_id, events = self.sessions[idx % len(self.sessions)]
        anchor_is_human = session_id.startswith('human_')

        anchor_feat = featurize_trajectory(events, self.mdp, self.input_dim)
        augmented_events = self.augment(events)
        positive_feat = featurize_trajectory(augmented_events, self.mdp, self.input_dim)

        # The negative always comes from the other class.
        if anchor_is_human:
            opposite_pool = self.agent_ids
        else:
            opposite_pool = self.human_ids
        _negative_sid, negative_events = self.sessions[random.choice(opposite_pool)]
        negative_feat = featurize_trajectory(negative_events, self.mdp, self.input_dim)

        # 0=human, 1=agent (anything without the 'human_' prefix counts as agent)
        class_label = 0 if anchor_is_human else 1

        feature_tensors = [
            torch.tensor(features, dtype=torch.float32)
            for features in (anchor_feat, positive_feat, negative_feat)
        ]
        return (*feature_tensors, torch.tensor(class_label, dtype=torch.long))
|
||||||
|
|
||||||
|
|
||||||
|
def train(epochs: int = 100, lr: float = 1e-3, batch_size: int = 4, input_dim: int = 64,
          embed_dim: int = 32, margin: float = 0.3, verbose: bool = True, run_name: Optional[str] = None):
    """Train the contrastive weak classifier on human/agent trajectories.

    Sessions are loaded from ``human_dir`` / ``agent_dir``, a reference MDP is
    built from the same directories, and the encoder plus classifier head are
    optimised with Adam on ``triplet_loss + 0.5 * cross_entropy``. The mean
    loss of every epoch is written to TensorBoard under
    ``{RUNS_DIR}/train/{run_name}``.

    Args:
        epochs: Number of passes over the triplet dataset.
        lr: Adam learning rate.
        batch_size: Triplets per batch; incomplete final batches are dropped.
        input_dim: Feature size passed to the featurizer and the model.
        embed_dim: Embedding size passed to the model.
        margin: Margin passed to the model and used by the triplet loss.
        verbose: Print progress every 10 epochs and a final summary.
        run_name: TensorBoard run name; a timestamped name is generated when
            omitted.

    Returns:
        ``(model, ref_mdp)`` - the trained model and the reference MDP used
        for featurization.

    Raises:
        ValueError: If the dataset is too small to form a single full batch
            (in addition to anything raised by ``TripletDataset``).
    """
    joint = JointLoader(human_dir, agent_dir)
    data = joint.get_data()
    if verbose:
        print(f"Loaded {len(data)} sessions")

    joint_model = JointBehaviorModel(human_dir, agent_dir)
    ref_mdp = joint_model.build_MDP()

    dataset = TripletDataset(data, ref_mdp, augment_trajectory, input_dim=input_dim)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    if len(loader) == 0:
        # With drop_last=True an oversized batch_size yields zero batches; the
        # loop would then "train" nothing and report a best loss of 0.0.
        raise ValueError(
            f"No full batch available: {len(dataset)} samples with batch_size={batch_size}"
        )

    model = ContrastiveWeakClassifier(input_dim=input_dim, embed_dim=embed_dim, margin=margin)
    model.to_device()

    run_name = run_name or f"d{input_dim}_e{embed_dim}_lr{lr}_m{margin}_{datetime.now():%Y%m%d_%H%M%S}"

    optimizer = Adam(list(model.encoder.parameters()) + list(model.classifier.parameters()), lr=lr)
    ce_loss_fn = torch.nn.CrossEntropyLoss()

    best_loss = float('inf')
    writer = SummaryWriter(f"{RUNS_DIR}/train/{run_name}")
    try:
        for epoch in range(epochs):
            model.encoder.train()
            model.classifier.train()
            total_loss, n_batches = 0.0, 0

            for anchor, positive, negative, labels in loader:
                anchor, positive, negative, labels = [
                    t.to(model.device) for t in [anchor, positive, negative, labels]
                ]
                # unsqueeze(1) inserts an extra axis after the batch axis before
                # encoding. NOTE(review): presumably a channel/sequence axis the
                # encoder expects - confirm against arch.py.
                z_a, z_p, z_n = [model.encoder(t.unsqueeze(1)) for t in [anchor, positive, negative]]

                trip_loss = contrastive_loss(z_a, z_p, z_n, margin=model.margin)
                ce = ce_loss_fn(model.classifier(z_a), labels)
                loss = trip_loss + 0.5 * ce

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                n_batches += 1

            avg_loss = total_loss / max(n_batches, 1)
            writer.add_scalar('loss', avg_loss, epoch)

            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{epochs}: loss={avg_loss:.4f}")
            if avg_loss < best_loss:
                best_loss = avg_loss
    finally:
        # Flush and release the event file even when training is interrupted.
        writer.close()

    if verbose:
        print(f"Done. Best={best_loss:.4f} TB:{RUNS_DIR}/train/{run_name}")

    return model, ref_mdp
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_loocv(input_dim: int = 64, embed_dim: int = 32, epochs_per_fold: int = 50,
                   lr: float = 1e-3, margin: float = 0.3, run_name: Optional[str] = None):
    """Leave-one-out cross-validation for the contrastive weak classifier.

    Each session is held out in turn; a fresh model is trained on the remaining
    sessions with the same objective as ``train()`` (triplet loss plus
    ``0.5 *`` cross-entropy on the classifier head) and then asked to predict
    the held-out session. Folds whose training split would lack one of the two
    classes are skipped, and a fold that raises is reported and left out of the
    metrics. Accuracy, F1, precision and recall (agent = positive class) are
    written to TensorBoard under ``{RUNS_DIR}/eval/{run_name}``.

    Args:
        input_dim: Feature size passed to the featurizer and the model.
        embed_dim: Embedding size passed to the model.
        epochs_per_fold: Training epochs for every fold.
        lr: Adam learning rate.
        margin: Margin passed to the model and used by the triplet loss.
        run_name: TensorBoard run name; a timestamped name is generated when
            omitted.

    Returns:
        ``(accuracy, predictions, actuals)``; ``(0.0, [], [])`` when no fold
        produced a prediction.
    """
    joint = JointLoader(human_dir, agent_dir)
    data = joint.get_data()
    session_ids = list(data.keys())

    joint_model = JointBehaviorModel(human_dir, agent_dir)
    # NOTE(review): the reference MDP is built from both data directories, so
    # it presumably includes the held-out session of every fold - confirm
    # whether this leaks test information into the features.
    ref_mdp = joint_model.build_MDP()

    run_name = run_name or f"loocv_d{input_dim}_e{embed_dim}_m{margin}_{datetime.now():%Y%m%d_%H%M%S}"

    ce_loss_fn = torch.nn.CrossEntropyLoss()
    predictions, actuals = [], []

    writer = SummaryWriter(f"{RUNS_DIR}/eval/{run_name}")
    try:
        for test_sid in session_ids:
            train_data = {k: v for k, v in data.items() if k != test_sid}
            test_events = data[test_sid]
            test_label = 0 if test_sid.startswith('human_') else 1

            # TripletDataset needs both classes to be able to draw negatives.
            n_human = sum(1 for k in train_data if k.startswith('human_'))
            n_agent = sum(1 for k in train_data if k.startswith('agent_'))
            if n_human == 0 or n_agent == 0:
                continue

            try:
                dataset = TripletDataset(train_data, ref_mdp, augment_trajectory, input_dim=input_dim, multiplier=5)
                loader = DataLoader(dataset, batch_size=2, shuffle=True, drop_last=True)

                model = ContrastiveWeakClassifier(input_dim=input_dim, embed_dim=embed_dim, margin=margin)
                model.to_device()
                optimizer = Adam(list(model.encoder.parameters()) + list(model.classifier.parameters()), lr=lr)

                model.encoder.train()
                model.classifier.train()
                for _ in range(epochs_per_fold):
                    for anchor, positive, negative, labels in loader:
                        labels = labels.to(model.device)
                        z_a, z_p, z_n = [model.encoder(t.unsqueeze(1).to(model.device)) for t in [anchor, positive, negative]]

                        # Same objective as train(): without the cross-entropy
                        # term the classifier head is in the optimizer but
                        # never receives a gradient.
                        trip_loss = contrastive_loss(z_a, z_p, z_n, margin=margin)
                        ce = ce_loss_fn(model.classifier(z_a), labels)
                        loss = trip_loss + 0.5 * ce

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Leave training mode before scoring the held-out session.
                model.encoder.eval()
                model.classifier.eval()

                test_feat = featurize_trajectory(test_events, ref_mdp, input_dim)
                pred = model.predict(test_feat.reshape(1, -1))[0]
                predictions.append(pred)
                actuals.append(test_label)
                print(f"  {test_sid[:12]}...: pred={pred}, actual={test_label}, {'OK' if pred == test_label else 'MISS'}")

            except Exception as e:
                # Best-effort: one failing fold must not abort the whole sweep.
                print(f"Error: {e}")

        if not predictions:
            return 0.0, [], []

        pairs = list(zip(predictions, actuals))
        acc = sum(p == a for p, a in pairs) / len(predictions)
        tp = sum(1 for p, a in pairs if p == 1 and a == 1)
        fp = sum(1 for p, a in pairs if p == 1 and a == 0)
        fn = sum(1 for p, a in pairs if p == 0 and a == 1)
        # max(..., 1) / max(..., 1e-10) guard the empty-denominator cases.
        prec, rec = tp / max(tp + fp, 1), tp / max(tp + fn, 1)
        f1 = 2 * prec * rec / max(prec + rec, 1e-10)
        writer.add_scalar('accuracy', acc, 0)
        writer.add_scalar('f1', f1, 0)
        writer.add_scalar('precision', prec, 0)
        writer.add_scalar('recall', rec, 0)
        print(f"\nAccuracy: {acc:.2%}  F1: {f1:.3f}  TB:{RUNS_DIR}/eval/{run_name}")
        return acc, predictions, actuals
    finally:
        # Always release the event file, including on unexpected errors.
        writer.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse

    # Command-line front end: --mode selects full training or LOOCV evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--mode', choices=['train', 'eval'], default='train')
    for flag, value_type, default_value in (
        ('--epochs', int, 100),
        ('--lr', float, 1e-3),
        ('--margin', float, 0.3),
        ('--input-dim', int, 64),
        ('--embed-dim', int, 32),
        ('--run-name', str, None),
    ):
        cli.add_argument(flag, type=value_type, default=default_value)
    args = cli.parse_args()

    # Hyper-parameters that both entry points accept under the same names.
    shared_kwargs = dict(
        lr=args.lr,
        input_dim=args.input_dim,
        embed_dim=args.embed_dim,
        margin=args.margin,
        run_name=args.run_name,
    )

    if args.mode == 'train':
        model, mdp = train(epochs=args.epochs, **shared_kwargs)
    else:
        # In eval mode --epochs is interpreted as epochs per fold.
        evaluate_loocv(epochs_per_fold=args.epochs, **shared_kwargs)
|
||||||
Reference in New Issue
Block a user