mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactored training approaches
This commit is contained in:
@@ -3,11 +3,16 @@ from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parents[2]))
|
||||
|
||||
from sim.rl.behavior_loader.models import (
|
||||
BehaviorModel,
|
||||
AgentBehaviorModel,
|
||||
aggregate_event_transitions,
|
||||
)
|
||||
try:
|
||||
from sim.rl.behavior_loader.models import (
|
||||
BehaviorModel,
|
||||
AgentBehaviorModel,
|
||||
aggregate_event_transitions,
|
||||
)
|
||||
except ImportError:
|
||||
BehaviorModel = None
|
||||
AgentBehaviorModel = None
|
||||
aggregate_event_transitions = None
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .demand import generate_demand_for_actor
|
||||
@@ -20,6 +25,12 @@ _cache = {} # lazy cache for models and base pivots
|
||||
|
||||
|
||||
def _get_base_pivot(human: bool):
|
||||
if (
|
||||
BehaviorModel is None
|
||||
or AgentBehaviorModel is None
|
||||
or aggregate_event_transitions is None
|
||||
):
|
||||
raise ImportError("behavior loader dependencies are unavailable")
|
||||
key = "human" if human else "agent"
|
||||
if key not in _cache:
|
||||
model = BehaviorModel(human_dir) if human else AgentBehaviorModel(agent_dir)
|
||||
@@ -34,6 +45,13 @@ def get_transition_models():
|
||||
returns:
|
||||
tuple: (human_transitions, agent_transitions) as dicts of event->event->prob
|
||||
"""
|
||||
if (
|
||||
BehaviorModel is None
|
||||
or AgentBehaviorModel is None
|
||||
or aggregate_event_transitions is None
|
||||
):
|
||||
raise ImportError("behavior loader dependencies are unavailable")
|
||||
|
||||
human_model = BehaviorModel(human_dir)
|
||||
agent_model = AgentBehaviorModel(agent_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user