hotfix: update pricing provider to read stored JSON through StringIO in pd.read_json (plus quote/formatting normalization)

This commit is contained in:
2026-02-06 12:01:12 +01:00
parent e7cb48e9cd
commit 29a13340b9

View File

@@ -2,11 +2,14 @@ import redis
import pickle
import json
import pandas as pd
from io import StringIO
from typing import Optional, Dict, Any
import os
import logging
log = logging.getLogger(__name__)
class ModelRegistry:
"""
Lightweight model registry using Redis for storing pricing models and elasticity data.
@@ -14,24 +17,23 @@ class ModelRegistry:
"""
def __init__(self, redis_host: Optional[str] = None, redis_port: Optional[int] = None):
    """
    Create the Redis client and define the key prefixes used by the registry.

    Args:
        redis_host: Redis hostname. When falsy, the REDIS_HOST environment
            variable is used, falling back to "localhost".
        redis_port: Redis port. When falsy, the REDIS_PORT environment
            variable is used, falling back to 6378.
    """
    host = redis_host or os.getenv("REDIS_HOST", "localhost")
    # NOTE(review): the fallback port is 6378, not Redis's stock 6379 —
    # confirm this matches the deployment before changing it.
    port = redis_port or int(os.getenv("REDIS_PORT", "6378"))

    # decode_responses=False: the readers in this class decode bytes themselves.
    self.redis_client = redis.Redis(
        host=host, port=port, db=0, decode_responses=False
    )

    # Key namespaces for each kind of stored artifact.
    self.metadata_prefix = "model:meta:"
    self.data_prefix = "model:data:"
    self.elasticity_prefix = "elasticity:"
    self.prices_prefix = "prices:"
def publish_elasticity(self,
def publish_elasticity(
self,
elasticity_df: pd.DataFrame,
model_name: str = 'latest',
metadata: Optional[Dict[str, Any]] = None):
model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None,
):
"""
Store elasticity estimates in registry.
@@ -43,25 +45,29 @@ class ModelRegistry:
key = f"{self.elasticity_prefix}{model_name}"
# serialize dataframe as JSON
data_json = elasticity_df.to_json(orient='records')
data_json = elasticity_df.to_json(orient="records")
# store data
self.redis_client.set(key, data_json)
# store metadata
meta = metadata or {}
meta.update({
'n_products': len(elasticity_df),
'mean_elasticity': float(elasticity_df['elasticity'].mean()),
'model_type': 'elasticity_snapshot'
})
meta.update(
{
"n_products": len(elasticity_df),
"mean_elasticity": float(elasticity_df["elasticity"].mean()),
"model_type": "elasticity_snapshot",
}
)
meta_key = f"{self.metadata_prefix}{model_name}"
self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published elasticity model '{model_name}' with {len(elasticity_df)} products")
log.info(
f"Published elasticity model '{model_name}' with {len(elasticity_df)} products"
)
def get_elasticity(self, model_name: str = 'latest') -> Optional[pd.DataFrame]:
def get_elasticity(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
"""Retrieve elasticity estimates from registry."""
key = f"{self.elasticity_prefix}{model_name}"
data_json = self.redis_client.get(key)
@@ -71,14 +77,16 @@ class ModelRegistry:
# decode bytes to string if needed
if isinstance(data_json, bytes):
data_json = data_json.decode('utf-8')
data_json = data_json.decode("utf-8")
return pd.read_json(data_json, orient='records')
return pd.read_json(StringIO(data_json), orient="records")
def publish_pricing_model(self,
def publish_pricing_model(
self,
pricing_function,
model_name: str = 'latest',
metadata: Optional[Dict[str, Any]] = None):
model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None,
):
"""
Store a fitted pricing function object.
@@ -95,17 +103,19 @@ class ModelRegistry:
# store metadata
meta = metadata or {}
meta.update({
'model_class': pricing_function.__class__.__name__,
'model_type': 'pricing_function'
})
meta.update(
{
"model_class": pricing_function.__class__.__name__,
"model_type": "pricing_function",
}
)
meta_key = f"{self.metadata_prefix}{model_name}"
self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published pricing model '{model_name}' ({meta['model_class']})")
def get_pricing_model(self, model_name: str = 'latest'):
def get_pricing_model(self, model_name: str = "latest"):
"""Retrieve a pricing function from registry."""
key = f"{self.data_prefix}{model_name}"
model_bytes = self.redis_client.get(key)
@@ -120,21 +130,23 @@ class ModelRegistry:
models = {}
for key in self.redis_client.scan_iter(f"{self.metadata_prefix}*"):
key_str = key.decode('utf-8') if isinstance(key, bytes) else key
model_name = key_str.replace(self.metadata_prefix, '')
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
model_name = key_str.replace(self.metadata_prefix, "")
meta_json = self.redis_client.get(key)
if meta_json:
if isinstance(meta_json, bytes):
meta_json = meta_json.decode('utf-8')
meta_json = meta_json.decode("utf-8")
models[model_name] = json.loads(meta_json)
return models
def publish_prices(self,
def publish_prices(
self,
prices_df: pd.DataFrame,
model_name: str = 'latest',
metadata: Optional[Dict[str, Any]] = None):
model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None,
):
"""Store predicted prices in registry.
Args:
@@ -143,22 +155,19 @@ class ModelRegistry:
metadata: additional info
"""
key = f"{self.prices_prefix}{model_name}"
data_json = prices_df.to_json(orient='records')
data_json = prices_df.to_json(orient="records")
self.redis_client.set(key, data_json)
meta = metadata or {}
meta.update({
'n_products': len(prices_df),
'model_type': 'predicted_prices'
})
meta.update({"n_products": len(prices_df), "model_type": "predicted_prices"})
meta_key = f"{self.metadata_prefix}prices_{model_name}"
self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published prices '{model_name}' for {len(prices_df)} products")
def get_prices(self, model_name: str = 'latest') -> Optional[pd.DataFrame]:
def get_prices(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
"""Retrieve predicted prices from registry."""
key = f"{self.prices_prefix}{model_name}"
data_json = self.redis_client.get(key)
@@ -167,9 +176,9 @@ class ModelRegistry:
return None
if isinstance(data_json, bytes):
data_json = data_json.decode('utf-8')
data_json = data_json.decode("utf-8")
return pd.read_json(data_json, orient='records')
return pd.read_json(StringIO(data_json), orient="records")
def health_check(self) -> bool:
"""Check if Redis connection is alive."""
@@ -179,7 +188,9 @@ class ModelRegistry:
except:
return False
def set_session_prices(self, session_id: str, prices: Dict[str, float], ttl: int = 1800):
def set_session_prices(
self, session_id: str, prices: Dict[str, float], ttl: int = 1800
):
"""
Store prices for a specific session.
THIS is the write path for session-aware pricing.
@@ -210,7 +221,9 @@ class ModelRegistry:
if price_str is None:
return None
return float(price_str.decode('utf-8') if isinstance(price_str, bytes) else price_str)
return float(
price_str.decode("utf-8") if isinstance(price_str, bytes) else price_str
)
def get_session_all_prices(self, session_id: str) -> Dict[str, float]:
"""Get all prices for a session."""
@@ -221,6 +234,8 @@ class ModelRegistry:
return {}
return {
(k.decode('utf-8') if isinstance(k, bytes) else k): float(v.decode('utf-8') if isinstance(v, bytes) else v)
(k.decode("utf-8") if isinstance(k, bytes) else k): float(
v.decode("utf-8") if isinstance(v, bytes) else v
)
for k, v in prices_raw.items()
}