tpu ready remodel

This commit is contained in:
2026-03-11 20:49:28 +01:00
parent fa2dde8307
commit d3a4febfde
13 changed files with 63 additions and 156 deletions

View File

@@ -1,5 +1,4 @@
from sys import platform
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from .lib.demand import generate_demand_for_actor, estimate_demand
from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions
@@ -8,9 +7,6 @@ from logging import INFO, getLogger
logger = getLogger(__name__)
logger.setLevel(INFO)
# shared pool; reused across act() calls to avoid per-call thread-spawn overhead
_pool = ThreadPoolExecutor(max_workers=4)
class MarketEngine:
"""implements separate demand distributions for humans and agents per Section 3.1.1"""
@@ -54,16 +50,14 @@ class MarketEngine:
agent_transitions = get_adjusted_transitions(demand_a, human=False)
# sample N trajectories in parallel; each chain is independent so threads
# do not share state and numpy's per-call RNG is thread-safe
h_futs = [
_pool.submit(sample_behavior_from_transitions, human_transitions)
human_t = [
sample_behavior_from_transitions(human_transitions)
for _ in range(self.Nhumans)
]
a_futs = [
_pool.submit(sample_behavior_from_transitions, agent_transitions)
agent_t = [
sample_behavior_from_transitions(agent_transitions)
for _ in range(self.Nagents)
]
human_t = [f.result() for f in h_futs]
agent_t = [f.result() for f in a_futs]
# store trajectories for agent probability calculation
self.last_trajectories = human_t + agent_t
return estimate_demand(self.last_trajectories, self.action_weights)

View File

@@ -143,6 +143,11 @@ def get_adjusted_transitions(condition, human=True) -> _TransitionTable:
cache_key = (human, tuple(np.round(condition, 4).tolist()))
if cache_key in _transition_cache:
return _transition_cache[cache_key]
# prevent OOM by capping cache size
if len(_transition_cache) > 100:
_transition_cache.clear()
base_pivot = _get_base_pivot(human)
df = adjust_behavior_to_condition(condition, base_pivot)
table = _TransitionTable(df)