diff --git a/lib/model_registry.py b/lib/model_registry.py index e833a1a..bb04dfc 100755 --- a/lib/model_registry.py +++ b/lib/model_registry.py @@ -2,11 +2,14 @@ import redis import pickle import json import pandas as pd +from io import StringIO from typing import Optional, Dict, Any import os import logging + log = logging.getLogger(__name__) + class ModelRegistry: """ Lightweight model registry using Redis for storing pricing models and elasticity data. @@ -14,24 +17,23 @@ class ModelRegistry: """ def __init__(self, redis_host: str = None, redis_port: int = None): - host = redis_host or os.getenv('REDIS_HOST', 'localhost') - port = redis_port or int(os.getenv('REDIS_PORT', '6378')) + host = redis_host or os.getenv("REDIS_HOST", "localhost") + port = redis_port or int(os.getenv("REDIS_PORT", "6378")) self.redis_client = redis.Redis( - host=host, - port=port, - db=0, - decode_responses=False + host=host, port=port, db=0, decode_responses=False ) self.metadata_prefix = "model:meta:" self.data_prefix = "model:data:" self.elasticity_prefix = "elasticity:" self.prices_prefix = "prices:" - def publish_elasticity(self, - elasticity_df: pd.DataFrame, - model_name: str = 'latest', - metadata: Optional[Dict[str, Any]] = None): + def publish_elasticity( + self, + elasticity_df: pd.DataFrame, + model_name: str = "latest", + metadata: Optional[Dict[str, Any]] = None, + ): """ Store elasticity estimates in registry. @@ -43,25 +45,29 @@ class ModelRegistry: key = f"{self.elasticity_prefix}{model_name}" # serialize dataframe as JSON - data_json = elasticity_df.to_json(orient='records') + data_json = elasticity_df.to_json(orient="records") # store data self.redis_client.set(key, data_json) # store metadata meta = metadata or {} - meta.update({ - 'n_products': len(elasticity_df), - 'mean_elasticity': float(elasticity_df['elasticity'].mean()), - 'model_type': 'elasticity_snapshot' - }) + meta.update( + { + "n_products": len(elasticity_df), + "mean_elasticity": float(elasticity_df["elasticity"].mean()), + "model_type": "elasticity_snapshot", + } + ) meta_key = f"{self.metadata_prefix}{model_name}" self.redis_client.set(meta_key, json.dumps(meta)) - log.info(f"Published elasticity model '{model_name}' with {len(elasticity_df)} products") + log.info( + f"Published elasticity model '{model_name}' with {len(elasticity_df)} products" + ) - def get_elasticity(self, model_name: str = 'latest') -> Optional[pd.DataFrame]: + def get_elasticity(self, model_name: str = "latest") -> Optional[pd.DataFrame]: """Retrieve elasticity estimates from registry.""" key = f"{self.elasticity_prefix}{model_name}" data_json = self.redis_client.get(key) @@ -71,14 +77,16 @@ class ModelRegistry: # decode bytes to string if needed if isinstance(data_json, bytes): - data_json = data_json.decode('utf-8') + data_json = data_json.decode("utf-8") - return pd.read_json(data_json, orient='records') + return pd.read_json(StringIO(data_json), orient="records") - def publish_pricing_model(self, - pricing_function, - model_name: str = 'latest', - metadata: Optional[Dict[str, Any]] = None): + def publish_pricing_model( + self, + pricing_function, + model_name: str = "latest", + metadata: Optional[Dict[str, Any]] = None, + ): """ Store a fitted pricing function object. @@ -95,17 +103,19 @@ class ModelRegistry: # store metadata meta = metadata or {} - meta.update({ - 'model_class': pricing_function.__class__.__name__, - 'model_type': 'pricing_function' - }) + meta.update( + { + "model_class": pricing_function.__class__.__name__, + "model_type": "pricing_function", + } + ) meta_key = f"{self.metadata_prefix}{model_name}" self.redis_client.set(meta_key, json.dumps(meta)) log.info(f"Published pricing model '{model_name}' ({meta['model_class']})") - def get_pricing_model(self, model_name: str = 'latest'): + def get_pricing_model(self, model_name: str = "latest"): """Retrieve a pricing function from registry.""" key = f"{self.data_prefix}{model_name}" model_bytes = self.redis_client.get(key) @@ -120,21 +130,23 @@ class ModelRegistry: models = {} for key in self.redis_client.scan_iter(f"{self.metadata_prefix}*"): - key_str = key.decode('utf-8') if isinstance(key, bytes) else key - model_name = key_str.replace(self.metadata_prefix, '') + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + model_name = key_str.replace(self.metadata_prefix, "") meta_json = self.redis_client.get(key) if meta_json: if isinstance(meta_json, bytes): - meta_json = meta_json.decode('utf-8') + meta_json = meta_json.decode("utf-8") models[model_name] = json.loads(meta_json) return models - def publish_prices(self, - prices_df: pd.DataFrame, - model_name: str = 'latest', - metadata: Optional[Dict[str, Any]] = None): + def publish_prices( + self, + prices_df: pd.DataFrame, + model_name: str = "latest", + metadata: Optional[Dict[str, Any]] = None, + ): """Store predicted prices in registry. Args: @@ -143,22 +155,19 @@ class ModelRegistry: metadata: additional info """ key = f"{self.prices_prefix}{model_name}" - data_json = prices_df.to_json(orient='records') + data_json = prices_df.to_json(orient="records") self.redis_client.set(key, data_json) meta = metadata or {} - meta.update({ - 'n_products': len(prices_df), - 'model_type': 'predicted_prices' - }) + meta.update({"n_products": len(prices_df), "model_type": "predicted_prices"}) meta_key = f"{self.metadata_prefix}prices_{model_name}" self.redis_client.set(meta_key, json.dumps(meta)) log.info(f"Published prices '{model_name}' for {len(prices_df)} products") - def get_prices(self, model_name: str = 'latest') -> Optional[pd.DataFrame]: + def get_prices(self, model_name: str = "latest") -> Optional[pd.DataFrame]: """Retrieve predicted prices from registry.""" key = f"{self.prices_prefix}{model_name}" data_json = self.redis_client.get(key) @@ -167,9 +176,9 @@ class ModelRegistry: return None if isinstance(data_json, bytes): - data_json = data_json.decode('utf-8') + data_json = data_json.decode("utf-8") - return pd.read_json(data_json, orient='records') + return pd.read_json(StringIO(data_json), orient="records") def health_check(self) -> bool: """Check if Redis connection is alive.""" @@ -179,7 +188,9 @@ class ModelRegistry: except: return False - def set_session_prices(self, session_id: str, prices: Dict[str, float], ttl: int = 1800): + def set_session_prices( + self, session_id: str, prices: Dict[str, float], ttl: int = 1800 + ): """ Store prices for a specific session. THIS is the write path for session-aware pricing. @@ -210,7 +221,9 @@ class ModelRegistry: if price_str is None: return None - return float(price_str.decode('utf-8') if isinstance(price_str, bytes) else price_str) + return float( + price_str.decode("utf-8") if isinstance(price_str, bytes) else price_str + ) def get_session_all_prices(self, session_id: str) -> Dict[str, float]: """Get all prices for a session.""" @@ -221,6 +234,8 @@ class ModelRegistry: return {} return { - (k.decode('utf-8') if isinstance(k, bytes) else k): float(v.decode('utf-8') if isinstance(v, bytes) else v) + (k.decode("utf-8") if isinstance(k, bytes) else k): float( + v.decode("utf-8") if isinstance(v, bytes) else v + ) for k, v in prices_raw.items() }