feat: baseline setup for RL modeling

This commit is contained in:
2026-01-22 12:52:41 +01:00
parent fa89347c4e
commit a6e6cc5d60
3 changed files with 152 additions and 37 deletions

View File

@@ -1,9 +1,8 @@
from os import kill
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any
from environment import BusinessLogicConstraints
from sim.rl.environment import BusinessLogicConstraints
"""
@@ -32,9 +31,11 @@ class BasePricingEngine(ABC):
"""
pass
@abstractmethod
def update(obs, reward, done, info):
pass
def update(self, observation: Dict[str, Any], reward: float, done: bool, info: Dict[str, Any]) -> None:
"""Default no-op update. Engines can override as needed."""
self.last_observation = observation
self.last_reward = reward
self.last_info = info
@@ -48,14 +49,14 @@ class WildPricingEngine(BasePricingEngine):
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
super().__init__(constraints, seed)
# per-product unit costs (unknown to customers; known to platform)
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catelogue_size).astype(np.float32)
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catalogue_size).astype(np.float32)
# online elasticity estimate (start moderately elastic)
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
# EWMA state for log-log regression
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
# knobs typical in production
self.lr = 0.08
self.ewma = 0.05
@@ -67,16 +68,16 @@ class WildPricingEngine(BasePricingEngine):
def reset(self):
super().reset()
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
self.step_count += 1
# extract demand signal (from env observation) as proxy for sales
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
return self._update_from_demand(current_prices, demand)
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
@@ -140,7 +141,7 @@ class SimpleDemandEngine(BasePricingEngine):
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
self.step_count += 1
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
if self.prev_demand is None:
self.prev_demand = demand.copy()
return current_prices.copy()
@@ -187,15 +188,15 @@ class ThompsonSamplingEngine(BasePricingEngine):
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
super().__init__(constraints, seed)
self.n_price_levels = 5
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
self.price_grid = None
self.last_actions = None
def reset(self):
super().reset()
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
self.price_grid = None
self.last_actions = None
@@ -206,10 +207,10 @@ class ThompsonSamplingEngine(BasePricingEngine):
lo = current_prices * 0.7
hi = current_prices * 1.3
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
# update beliefs based on last action
if self.last_actions is not None:
for i in range(self.c.product_catelogue_size):
for i in range(self.c.product_catalogue_size):
a = self.last_actions[i]
reward = demand[i]
if reward > 0.5:
@@ -217,9 +218,9 @@ class ThompsonSamplingEngine(BasePricingEngine):
else:
self.beta[i, a] += 1.0
# thompson sampling: sample from posterior, pick best
new_prices = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
actions = np.zeros(self.c.product_catelogue_size, dtype=int)
for i in range(self.c.product_catelogue_size):
new_prices = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
actions = np.zeros(self.c.product_catalogue_size, dtype=int)
for i in range(self.c.product_catalogue_size):
theta = self.rng.beta(self.alpha[i], self.beta[i]).astype(np.float32)
actions[i] = int(np.argmax(theta))
new_prices[i] = self.price_grid[i, actions[i]]