refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -1,38 +1,39 @@
from .demand import estimate_demand, estimate_weighted_demand, generate_demand_for_actor
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
from .render import DashboardRenderer, style_axis
from .wrappers import EconomicMetricsWrapper
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
from .providers import (
ProviderBenchmark,
ProviderResult,
BenchmarkConfig,
RandomBaseline,
SurgeBaseline,
)
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
from .discrete import EventQTable
from __future__ import annotations
__all__ = [
"estimate_demand",
"estimate_weighted_demand",
"generate_demand_for_actor",
"sample_behavior",
"get_transition_models",
"trajectory_to_events",
"DashboardRenderer",
"style_axis",
"EconomicMetricsWrapper",
"MetricsCallback",
"EvalMetricsCallback",
"CheckpointArtifactCallback",
"ProviderBenchmark",
"ProviderResult",
"BenchmarkConfig",
"RandomBaseline",
"SurgeBaseline",
"compute_uplift_coi",
"extract_purchases",
"compute_agent_probability",
"EventQTable",
]
from importlib import import_module
_EXPORTS: dict[str, tuple[str, str]] = {
"estimate_demand": (".demand", "estimate_demand"),
"estimate_weighted_demand": (".demand", "estimate_weighted_demand"),
"generate_demand_for_actor": (".demand", "generate_demand_for_actor"),
"sample_behavior": (".behavior", "sample_behavior"),
"get_transition_models": (".behavior", "get_transition_models"),
"trajectory_to_events": (".behavior", "trajectory_to_events"),
"DashboardRenderer": (".render", "DashboardRenderer"),
"style_axis": (".render", "style_axis"),
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
"MetricsCallback": (".callbacks", "MetricsCallback"),
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
"CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
"ProviderBenchmark": (".providers", "ProviderBenchmark"),
"ProviderResult": (".providers", "ProviderResult"),
"BenchmarkConfig": (".providers", "BenchmarkConfig"),
"RandomBaseline": (".providers", "RandomBaseline"),
"SurgeBaseline": (".providers", "SurgeBaseline"),
"compute_uplift_coi": (".coi", "compute_uplift_coi"),
"extract_purchases": (".coi", "extract_purchases"),
"compute_agent_probability": (".coi", "compute_agent_probability"),
"EventQTable": (".discrete", "EventQTable"),
}
__all__ = sorted(_EXPORTS)
def __getattr__(name: str):
if name not in _EXPORTS:
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
module_name, attr_name = _EXPORTS[name]
module = import_module(module_name, package=__name__)
value = getattr(module, attr_name)
globals()[name] = value
return value