chore: better test consistency before agnet

This commit is contained in:
2026-01-12 22:33:47 +01:00
parent e89cb263d4
commit 3c141a4b6c
4 changed files with 89 additions and 3 deletions

View File

@@ -112,11 +112,14 @@ services:
depends_on:
- postgres
environment:
- AIRFLOW__CORE__EXECUTOR=SequentialExecutor
- AIRFLOW__CORE__EXECUTOR=LocalExecutor
- AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@postgres/airflow
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY}
- AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__CORE__PARALLELISM=16
- AIRFLOW__CORE__DAG_CONCURRENCY=8
- AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG=4
- _AIRFLOW_DB_MIGRATE=true
- _AIRFLOW_WWW_USER_CREATE=true
- _AIRFLOW_WWW_USER_USERNAME=admin
@@ -136,12 +139,17 @@ services:
- airflow-init
- redis
environment:
- AIRFLOW__CORE__EXECUTOR=SequentialExecutor
- AIRFLOW__CORE__EXECUTOR=LocalExecutor
- AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@postgres/airflow
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY}
- AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true
- AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__CORE__PARALLELISM=16
- AIRFLOW__CORE__DAG_CONCURRENCY=8
- AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG=4
- AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL=30
- AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL=60
- AIRFLOW__WEBSERVER__EXPOSE_CONFIG=true
- AIRFLOW__WEBSERVER__SECRET_KEY=${AIRFLOW_SECRET_KEY}
- AIRFLOW__API__AUTH_BACKENDS=airflow.api.auth.backend.basic_auth
@@ -174,12 +182,18 @@ services:
redis:
condition: service_started
environment:
- AIRFLOW__CORE__EXECUTOR=SequentialExecutor
- AIRFLOW__CORE__EXECUTOR=LocalExecutor
- AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@postgres/airflow
- AIRFLOW__CORE__FERNET_KEY=${AIRFLOW_FERNET_KEY}
- AIRFLOW__CORE__DAGS_ARE_PAUSED_AT_CREATION=true
- AIRFLOW__CORE__LOAD_EXAMPLES=false
- AIRFLOW__CORE__ENABLE_XCOM_PICKLING=true
- AIRFLOW__CORE__PARALLELISM=16
- AIRFLOW__CORE__DAG_CONCURRENCY=8
- AIRFLOW__CORE__MAX_ACTIVE_RUNS_PER_DAG=4
- AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL=30
- AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL=60
- AIRFLOW__SCHEDULER__PARSING_PROCESSES=2
- AIRFLOW__WEBSERVER__SECRET_KEY=${AIRFLOW_SECRET_KEY}
- AIRFLOW__API__AUTH_BACKENDS=airflow.api.auth.backend.basic_auth
- KAFKA_HOST=kafka

View File

@@ -57,3 +57,13 @@ class ElasticityBasedPricer(PricingFunction):
# enforce bounds
prices = np.clip(prices, self.price_floor, self.price_ceil)
return prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract elasticity, demand, and demand deviation for each product"""
if state_space is None or self.elasticity is None:
n = len(self.elasticity) if self.elasticity is not None else 0
return np.zeros((n, 3))
demand = np.asarray(state_space.demand)
demand_dev = (demand - self.mean_demand) / (self.mean_demand + 1e-6)
return np.column_stack([self.elasticity, demand, demand_dev])

View File

@@ -107,6 +107,36 @@ class SessionAwarePricer(PricingFunction):
return prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract elasticity, demand, and session features"""
if state_space is None or self.elasticity is None:
n = len(self.elasticity) if self.elasticity is not None else 0
return np.zeros((n, 5))
demand = np.asarray(state_space.demand)
n_products = len(demand)
# extract session features
velocity = 0.0
view_depth = 0.0
cart_to_view = 0.0
if not state_space.session_features.empty:
sf = state_space.session_features.iloc[0]
velocity = sf.get('interaction_velocity', 0.0)
view_depth = sf.get('product_view_depth', 0.0)
cart_to_view = sf.get('cart_to_view_ratio', 0.0)
# broadcast session features to all products
features = np.column_stack([
self.elasticity,
demand,
np.full(n_products, velocity),
np.full(n_products, view_depth),
np.full(n_products, cart_to_view)
])
return features
class ProductSpecificSessionPricer(PricingFunction):
"""
@@ -170,3 +200,12 @@ class ProductSpecificSessionPricer(PricingFunction):
prices = np.clip(base_prices, self.price_floor, self.price_ceil)
return prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract elasticity and demand features for product-specific pricing"""
if state_space is None or self.elasticity is None:
n = len(self.elasticity) if self.elasticity is not None else 0
return np.zeros((n, 2))
demand = np.asarray(state_space.demand)
return np.column_stack([self.elasticity, demand])

View File

@@ -65,6 +65,11 @@ class StaticPricer(PricingFunction):
raise ValueError("Must call fit() or provide base_prices in constructor")
return self.base_prices.copy()
def _get_features(self, state_space=None) -> np.ndarray:
"""Static pricer uses no features, returns empty array"""
n = len(self.base_prices) if self.base_prices is not None else 0
return np.zeros((n, 0))
class RandomPricer(PricingFunction):
"""Random pricing within bounds (for baseline comparison)"""
@@ -87,6 +92,11 @@ class RandomPricer(PricingFunction):
self.n_products = len(state_space.demand)
return self.rng.uniform(self.price_min, self.price_max, size=self.n_products)
def _get_features(self, state_space=None) -> np.ndarray:
"""Random pricer uses no features"""
n = self.n_products if self.n_products else 0
return np.zeros((n, 0))
class SimpleSurgePricer(PricingFunction):
"""
@@ -133,3 +143,16 @@ class SimpleSurgePricer(PricingFunction):
new_prices[low_mask] *= self.discount_multiplier
return new_prices
def _get_features(self, state_space=None) -> np.ndarray:
"""Extract demand and base price features for each product"""
if state_space is None:
n = len(self.base_prices) if self.base_prices is not None else 0
return np.zeros((n, 2))
demand = np.asarray(state_space.demand) if hasattr(state_space, 'demand') else np.array([0])
base = np.asarray(state_space.prices) if hasattr(state_space, 'prices') else self.base_prices
if base is None:
base = np.ones(len(demand)) * 99.99
return np.column_stack([demand, base])