mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
tpu ready remodel
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from sys import platform
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
from .lib.demand import generate_demand_for_actor, estimate_demand
|
||||
from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions
|
||||
@@ -8,9 +7,6 @@ from logging import INFO, getLogger
|
||||
logger = getLogger(__name__)
|
||||
logger.setLevel(INFO)
|
||||
|
||||
# shared pool; reused across act() calls to avoid per-call thread-spawn overhead
|
||||
_pool = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
class MarketEngine:
|
||||
"""implements separate demand distributions for humans and agents per Section 3.1.1"""
|
||||
@@ -54,16 +50,14 @@ class MarketEngine:
|
||||
agent_transitions = get_adjusted_transitions(demand_a, human=False)
|
||||
# sample N trajectories in parallel; each chain is independent so threads
|
||||
# do not share state and numpy's per-call RNG is thread-safe
|
||||
h_futs = [
|
||||
_pool.submit(sample_behavior_from_transitions, human_transitions)
|
||||
human_t = [
|
||||
sample_behavior_from_transitions(human_transitions)
|
||||
for _ in range(self.Nhumans)
|
||||
]
|
||||
a_futs = [
|
||||
_pool.submit(sample_behavior_from_transitions, agent_transitions)
|
||||
agent_t = [
|
||||
sample_behavior_from_transitions(agent_transitions)
|
||||
for _ in range(self.Nagents)
|
||||
]
|
||||
human_t = [f.result() for f in h_futs]
|
||||
agent_t = [f.result() for f in a_futs]
|
||||
# store trajectories for agent probability calculation
|
||||
self.last_trajectories = human_t + agent_t
|
||||
return estimate_demand(self.last_trajectories, self.action_weights)
|
||||
|
||||
Reference in New Issue
Block a user