hotfix: updating pricing provider to better read data

This commit is contained in:
2026-02-06 12:01:12 +01:00
parent e7cb48e9cd
commit 29a13340b9

View File

@@ -2,11 +2,14 @@ import redis
import pickle import pickle
import json import json
import pandas as pd import pandas as pd
from io import StringIO
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
import os import os
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ModelRegistry: class ModelRegistry:
""" """
Lightweight model registry using Redis for storing pricing models and elasticity data. Lightweight model registry using Redis for storing pricing models and elasticity data.
@@ -14,24 +17,23 @@ class ModelRegistry:
""" """
def __init__(self, redis_host: str = None, redis_port: int = None): def __init__(self, redis_host: str = None, redis_port: int = None):
host = redis_host or os.getenv('REDIS_HOST', 'localhost') host = redis_host or os.getenv("REDIS_HOST", "localhost")
port = redis_port or int(os.getenv('REDIS_PORT', '6378')) port = redis_port or int(os.getenv("REDIS_PORT", "6378"))
self.redis_client = redis.Redis( self.redis_client = redis.Redis(
host=host, host=host, port=port, db=0, decode_responses=False
port=port,
db=0,
decode_responses=False
) )
self.metadata_prefix = "model:meta:" self.metadata_prefix = "model:meta:"
self.data_prefix = "model:data:" self.data_prefix = "model:data:"
self.elasticity_prefix = "elasticity:" self.elasticity_prefix = "elasticity:"
self.prices_prefix = "prices:" self.prices_prefix = "prices:"
def publish_elasticity(self, def publish_elasticity(
self,
elasticity_df: pd.DataFrame, elasticity_df: pd.DataFrame,
model_name: str = 'latest', model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None): metadata: Optional[Dict[str, Any]] = None,
):
""" """
Store elasticity estimates in registry. Store elasticity estimates in registry.
@@ -43,25 +45,29 @@ class ModelRegistry:
key = f"{self.elasticity_prefix}{model_name}" key = f"{self.elasticity_prefix}{model_name}"
# serialize dataframe as JSON # serialize dataframe as JSON
data_json = elasticity_df.to_json(orient='records') data_json = elasticity_df.to_json(orient="records")
# store data # store data
self.redis_client.set(key, data_json) self.redis_client.set(key, data_json)
# store metadata # store metadata
meta = metadata or {} meta = metadata or {}
meta.update({ meta.update(
'n_products': len(elasticity_df), {
'mean_elasticity': float(elasticity_df['elasticity'].mean()), "n_products": len(elasticity_df),
'model_type': 'elasticity_snapshot' "mean_elasticity": float(elasticity_df["elasticity"].mean()),
}) "model_type": "elasticity_snapshot",
}
)
meta_key = f"{self.metadata_prefix}{model_name}" meta_key = f"{self.metadata_prefix}{model_name}"
self.redis_client.set(meta_key, json.dumps(meta)) self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published elasticity model '{model_name}' with {len(elasticity_df)} products") log.info(
f"Published elasticity model '{model_name}' with {len(elasticity_df)} products"
)
def get_elasticity(self, model_name: str = 'latest') -> Optional[pd.DataFrame]: def get_elasticity(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
"""Retrieve elasticity estimates from registry.""" """Retrieve elasticity estimates from registry."""
key = f"{self.elasticity_prefix}{model_name}" key = f"{self.elasticity_prefix}{model_name}"
data_json = self.redis_client.get(key) data_json = self.redis_client.get(key)
@@ -71,14 +77,16 @@ class ModelRegistry:
# decode bytes to string if needed # decode bytes to string if needed
if isinstance(data_json, bytes): if isinstance(data_json, bytes):
data_json = data_json.decode('utf-8') data_json = data_json.decode("utf-8")
return pd.read_json(data_json, orient='records') return pd.read_json(StringIO(data_json), orient="records")
def publish_pricing_model(self, def publish_pricing_model(
self,
pricing_function, pricing_function,
model_name: str = 'latest', model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None): metadata: Optional[Dict[str, Any]] = None,
):
""" """
Store a fitted pricing function object. Store a fitted pricing function object.
@@ -95,17 +103,19 @@ class ModelRegistry:
# store metadata # store metadata
meta = metadata or {} meta = metadata or {}
meta.update({ meta.update(
'model_class': pricing_function.__class__.__name__, {
'model_type': 'pricing_function' "model_class": pricing_function.__class__.__name__,
}) "model_type": "pricing_function",
}
)
meta_key = f"{self.metadata_prefix}{model_name}" meta_key = f"{self.metadata_prefix}{model_name}"
self.redis_client.set(meta_key, json.dumps(meta)) self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published pricing model '{model_name}' ({meta['model_class']})") log.info(f"Published pricing model '{model_name}' ({meta['model_class']})")
def get_pricing_model(self, model_name: str = 'latest'): def get_pricing_model(self, model_name: str = "latest"):
"""Retrieve a pricing function from registry.""" """Retrieve a pricing function from registry."""
key = f"{self.data_prefix}{model_name}" key = f"{self.data_prefix}{model_name}"
model_bytes = self.redis_client.get(key) model_bytes = self.redis_client.get(key)
@@ -120,21 +130,23 @@ class ModelRegistry:
models = {} models = {}
for key in self.redis_client.scan_iter(f"{self.metadata_prefix}*"): for key in self.redis_client.scan_iter(f"{self.metadata_prefix}*"):
key_str = key.decode('utf-8') if isinstance(key, bytes) else key key_str = key.decode("utf-8") if isinstance(key, bytes) else key
model_name = key_str.replace(self.metadata_prefix, '') model_name = key_str.replace(self.metadata_prefix, "")
meta_json = self.redis_client.get(key) meta_json = self.redis_client.get(key)
if meta_json: if meta_json:
if isinstance(meta_json, bytes): if isinstance(meta_json, bytes):
meta_json = meta_json.decode('utf-8') meta_json = meta_json.decode("utf-8")
models[model_name] = json.loads(meta_json) models[model_name] = json.loads(meta_json)
return models return models
def publish_prices(self, def publish_prices(
self,
prices_df: pd.DataFrame, prices_df: pd.DataFrame,
model_name: str = 'latest', model_name: str = "latest",
metadata: Optional[Dict[str, Any]] = None): metadata: Optional[Dict[str, Any]] = None,
):
"""Store predicted prices in registry. """Store predicted prices in registry.
Args: Args:
@@ -143,22 +155,19 @@ class ModelRegistry:
metadata: additional info metadata: additional info
""" """
key = f"{self.prices_prefix}{model_name}" key = f"{self.prices_prefix}{model_name}"
data_json = prices_df.to_json(orient='records') data_json = prices_df.to_json(orient="records")
self.redis_client.set(key, data_json) self.redis_client.set(key, data_json)
meta = metadata or {} meta = metadata or {}
meta.update({ meta.update({"n_products": len(prices_df), "model_type": "predicted_prices"})
'n_products': len(prices_df),
'model_type': 'predicted_prices'
})
meta_key = f"{self.metadata_prefix}prices_{model_name}" meta_key = f"{self.metadata_prefix}prices_{model_name}"
self.redis_client.set(meta_key, json.dumps(meta)) self.redis_client.set(meta_key, json.dumps(meta))
log.info(f"Published prices '{model_name}' for {len(prices_df)} products") log.info(f"Published prices '{model_name}' for {len(prices_df)} products")
def get_prices(self, model_name: str = 'latest') -> Optional[pd.DataFrame]: def get_prices(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
"""Retrieve predicted prices from registry.""" """Retrieve predicted prices from registry."""
key = f"{self.prices_prefix}{model_name}" key = f"{self.prices_prefix}{model_name}"
data_json = self.redis_client.get(key) data_json = self.redis_client.get(key)
@@ -167,9 +176,9 @@ class ModelRegistry:
return None return None
if isinstance(data_json, bytes): if isinstance(data_json, bytes):
data_json = data_json.decode('utf-8') data_json = data_json.decode("utf-8")
return pd.read_json(data_json, orient='records') return pd.read_json(StringIO(data_json), orient="records")
def health_check(self) -> bool: def health_check(self) -> bool:
"""Check if Redis connection is alive.""" """Check if Redis connection is alive."""
@@ -179,7 +188,9 @@ class ModelRegistry:
except: except:
return False return False
def set_session_prices(self, session_id: str, prices: Dict[str, float], ttl: int = 1800): def set_session_prices(
self, session_id: str, prices: Dict[str, float], ttl: int = 1800
):
""" """
Store prices for a specific session. Store prices for a specific session.
THIS is the write path for session-aware pricing. THIS is the write path for session-aware pricing.
@@ -210,7 +221,9 @@ class ModelRegistry:
if price_str is None: if price_str is None:
return None return None
return float(price_str.decode('utf-8') if isinstance(price_str, bytes) else price_str) return float(
price_str.decode("utf-8") if isinstance(price_str, bytes) else price_str
)
def get_session_all_prices(self, session_id: str) -> Dict[str, float]: def get_session_all_prices(self, session_id: str) -> Dict[str, float]:
"""Get all prices for a session.""" """Get all prices for a session."""
@@ -221,6 +234,8 @@ class ModelRegistry:
return {} return {}
return { return {
(k.decode('utf-8') if isinstance(k, bytes) else k): float(v.decode('utf-8') if isinstance(v, bytes) else v) (k.decode("utf-8") if isinstance(k, bytes) else k): float(
v.decode("utf-8") if isinstance(v, bytes) else v
)
for k, v in prices_raw.items() for k, v in prices_raw.items()
} }