mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: baseline setup for RL modeling
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
from os import kill
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
from environment import BusinessLogicConstraints
|
||||
from sim.rl.environment import BusinessLogicConstraints
|
||||
|
||||
|
||||
"""
|
||||
@@ -32,9 +31,11 @@ class BasePricingEngine(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(obs, reward, done, info):
|
||||
pass
|
||||
def update(self, observation: Dict[str, Any], reward: float, done: bool, info: Dict[str, Any]) -> None:
|
||||
"""Default no-op update. Engines can override as needed."""
|
||||
self.last_observation = observation
|
||||
self.last_reward = reward
|
||||
self.last_info = info
|
||||
|
||||
|
||||
|
||||
@@ -48,14 +49,14 @@ class WildPricingEngine(BasePricingEngine):
|
||||
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
||||
super().__init__(constraints, seed)
|
||||
# per-product unit costs (unknown to customers; known to platform)
|
||||
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catelogue_size).astype(np.float32)
|
||||
self.unit_cost = self.rng.uniform(8.0, 40.0, size=self.c.product_catalogue_size).astype(np.float32)
|
||||
# online elasticity estimate (start moderately elastic)
|
||||
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
|
||||
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
|
||||
# EWMA state for log-log regression
|
||||
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
|
||||
# knobs typical in production
|
||||
self.lr = 0.08
|
||||
self.ewma = 0.05
|
||||
@@ -67,16 +68,16 @@ class WildPricingEngine(BasePricingEngine):
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.e_hat = np.full((self.c.product_catelogue_size,), -1.3, dtype=np.float32)
|
||||
self.mu_logp = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.mu_logq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.cov_pq = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.var_p = np.ones(self.c.product_catelogue_size, dtype=np.float32)
|
||||
self.e_hat = np.full((self.c.product_catalogue_size,), -1.3, dtype=np.float32)
|
||||
self.mu_logp = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.mu_logq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.cov_pq = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
self.var_p = np.ones(self.c.product_catalogue_size, dtype=np.float32)
|
||||
|
||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||
self.step_count += 1
|
||||
# extract demand signal (from env observation) as proxy for sales
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
return self._update_from_demand(current_prices, demand)
|
||||
|
||||
def _update_from_demand(self, prices: np.ndarray, sold: np.ndarray) -> np.ndarray:
|
||||
@@ -140,7 +141,7 @@ class SimpleDemandEngine(BasePricingEngine):
|
||||
|
||||
def compute_prices(self, current_prices: np.ndarray, observation: Dict[str, Any]) -> np.ndarray:
|
||||
self.step_count += 1
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
if self.prev_demand is None:
|
||||
self.prev_demand = demand.copy()
|
||||
return current_prices.copy()
|
||||
@@ -187,15 +188,15 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
||||
def __init__(self, constraints: BusinessLogicConstraints, seed: int = 0):
|
||||
super().__init__(constraints, seed)
|
||||
self.n_price_levels = 5
|
||||
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.price_grid = None
|
||||
self.last_actions = None
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
self.alpha = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.beta = np.ones((self.c.product_catelogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.alpha = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.beta = np.ones((self.c.product_catalogue_size, self.n_price_levels), dtype=np.float32)
|
||||
self.price_grid = None
|
||||
self.last_actions = None
|
||||
|
||||
@@ -206,10 +207,10 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
||||
lo = current_prices * 0.7
|
||||
hi = current_prices * 1.3
|
||||
self.price_grid = np.linspace(lo, hi, self.n_price_levels).T
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catelogue_size, dtype=np.float32))
|
||||
demand = observation.get('demand', np.zeros(self.c.product_catalogue_size, dtype=np.float32))
|
||||
# update beliefs based on last action
|
||||
if self.last_actions is not None:
|
||||
for i in range(self.c.product_catelogue_size):
|
||||
for i in range(self.c.product_catalogue_size):
|
||||
a = self.last_actions[i]
|
||||
reward = demand[i]
|
||||
if reward > 0.5:
|
||||
@@ -217,9 +218,9 @@ class ThompsonSamplingEngine(BasePricingEngine):
|
||||
else:
|
||||
self.beta[i, a] += 1.0
|
||||
# thompson sampling: sample from posterior, pick best
|
||||
new_prices = np.zeros(self.c.product_catelogue_size, dtype=np.float32)
|
||||
actions = np.zeros(self.c.product_catelogue_size, dtype=int)
|
||||
for i in range(self.c.product_catelogue_size):
|
||||
new_prices = np.zeros(self.c.product_catalogue_size, dtype=np.float32)
|
||||
actions = np.zeros(self.c.product_catalogue_size, dtype=int)
|
||||
for i in range(self.c.product_catalogue_size):
|
||||
theta = self.rng.beta(self.alpha[i], self.beta[i]).astype(np.float32)
|
||||
actions[i] = int(np.argmax(theta))
|
||||
new_prices[i] = self.price_grid[i, actions[i]]
|
||||
|
||||
Reference in New Issue
Block a user