# File: PHANTOM/lib/model_registry.py
# (repository viewer header: 242 lines, 7.6 KiB, Python, executable file)

import redis
import pickle
import json
import pandas as pd
from io import StringIO
from typing import Optional, Dict, Any
import os
import logging
log = logging.getLogger(__name__)


class ModelRegistry:
    """
    Lightweight model registry using Redis for storing pricing models and elasticity data.

    Storage layout:
        * DataFrames (elasticities, predicted prices) are stored as JSON records.
        * Fitted pricing functions are stored as pickles.
        * Per-model metadata is stored as JSON under ``model:meta:``.
        * Session prices are stored in a Redis hash per session, with a TTL.

    NOTE(review): elasticity snapshots and pricing models both write their
    metadata to ``model:meta:<model_name>``, so publishing both under the same
    name (for example the default "latest") overwrites the other's metadata.
    Price snapshots avoid this by using ``model:meta:prices_<model_name>``.
    """

    def __init__(
        self, redis_host: Optional[str] = None, redis_port: Optional[int] = None
    ):
        """
        Create the Redis client and the key prefixes used by the registry.

        Args:
            redis_host: Redis host; falls back to $REDIS_HOST, then "localhost"
            redis_port: Redis port; falls back to $REDIS_PORT, then 6378
        """
        host = redis_host or os.getenv("REDIS_HOST", "localhost")
        # NOTE(review): stock Redis listens on 6379. The 6378 fallback is kept
        # as-is because deployments may rely on it - confirm it is intentional.
        port = redis_port or int(os.getenv("REDIS_PORT", "6378"))
        # decode_responses=False keeps replies as bytes, which pickle.loads needs.
        self.redis_client = redis.Redis(
            host=host, port=port, db=0, decode_responses=False
        )
        self.metadata_prefix = "model:meta:"
        self.data_prefix = "model:data:"
        self.elasticity_prefix = "elasticity:"
        self.prices_prefix = "prices:"

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_text(value):
        """Decode a bytes reply as UTF-8; any other value is returned unchanged."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _write_metadata(
        self,
        meta_key: str,
        metadata: Optional[Dict[str, Any]],
        extra: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Merge ``extra`` into a copy of ``metadata``, store it as JSON, return it."""
        # Copy first: updating the caller's dict in place would leak the
        # registry's bookkeeping fields back into the caller's object.
        meta = dict(metadata or {})
        meta.update(extra)
        self.redis_client.set(meta_key, json.dumps(meta))
        return meta

    def _read_frame(self, key: str) -> Optional[pd.DataFrame]:
        """Load a DataFrame stored as JSON records under ``key`` (None if absent)."""
        payload = self.redis_client.get(key)
        if payload is None:
            return None
        # NOTE(review): pd.read_json infers dtypes by default, so ID-like string
        # columns may come back numeric - confirm this is acceptable to callers.
        return pd.read_json(StringIO(self._to_text(payload)), orient="records")

    # ------------------------------------------------------------------
    # elasticity snapshots
    # ------------------------------------------------------------------
    def publish_elasticity(
        self,
        elasticity_df: pd.DataFrame,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Store elasticity estimates in registry.

        Args:
            elasticity_df: df with [productId, elasticity, std_error, n_obs]
            model_name: identifier for this elasticity snapshot
            metadata: additional info (timestamp, window_size, etc); not modified

        Raises:
            KeyError: if ``elasticity_df`` has no "elasticity" column
        """
        # Compute the summary before writing anything, so a frame without an
        # "elasticity" column fails without leaving a half-published snapshot.
        # For an empty frame the mean is NaN, which json.dumps writes as the
        # non-standard token ``NaN``.
        summary = {
            "n_products": len(elasticity_df),
            "mean_elasticity": float(elasticity_df["elasticity"].mean()),
            "model_type": "elasticity_snapshot",
        }
        key = f"{self.elasticity_prefix}{model_name}"
        self.redis_client.set(key, elasticity_df.to_json(orient="records"))
        self._write_metadata(f"{self.metadata_prefix}{model_name}", metadata, summary)
        log.info(
            "Published elasticity model '%s' with %s products",
            model_name,
            len(elasticity_df),
        )

    def get_elasticity(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
        """Retrieve elasticity estimates from registry (None if not published)."""
        return self._read_frame(f"{self.elasticity_prefix}{model_name}")

    # ------------------------------------------------------------------
    # pricing models
    # ------------------------------------------------------------------
    def publish_pricing_model(
        self,
        pricing_function,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Store a fitted pricing function object.

        Args:
            pricing_function: fitted PricingFunction instance (must be picklable)
            model_name: identifier
            metadata: additional info; not modified
        """
        key = f"{self.data_prefix}{model_name}"
        self.redis_client.set(key, pickle.dumps(pricing_function))
        meta = self._write_metadata(
            f"{self.metadata_prefix}{model_name}",
            metadata,
            {
                "model_class": pricing_function.__class__.__name__,
                "model_type": "pricing_function",
            },
        )
        log.info("Published pricing model '%s' (%s)", model_name, meta["model_class"])

    def get_pricing_model(self, model_name: str = "latest"):
        """Retrieve a pricing function from registry (None if not published)."""
        model_bytes = self.redis_client.get(f"{self.data_prefix}{model_name}")
        if model_bytes is None:
            return None
        # SECURITY: unpickling executes arbitrary code embedded in the payload.
        # This is only safe while every writer to this Redis instance is trusted.
        return pickle.loads(model_bytes)

    def list_models(self) -> Dict[str, Any]:
        """
        List all registered models with metadata.

        Returns:
            dict mapping the key suffix after ``model:meta:`` to its decoded
            metadata; price snapshots therefore appear as ``prices_<name>``
        """
        prefix = self.metadata_prefix
        models = {}
        for raw_key in self.redis_client.scan_iter(f"{prefix}*"):
            key_str = self._to_text(raw_key)
            # Strip only the leading prefix; str.replace would also remove the
            # prefix text from inside the model name itself.
            model_name = key_str[len(prefix):] if key_str.startswith(prefix) else key_str
            meta_json = self.redis_client.get(raw_key)
            if meta_json:  # the key can vanish between SCAN and GET
                models[model_name] = json.loads(self._to_text(meta_json))
        return models

    # ------------------------------------------------------------------
    # predicted prices
    # ------------------------------------------------------------------
    def publish_prices(
        self,
        prices_df: pd.DataFrame,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """Store predicted prices in registry.

        Args:
            prices_df: df with [productId, predicted_price, ...]
            model_name: identifier for this price snapshot
            metadata: additional info; not modified
        """
        key = f"{self.prices_prefix}{model_name}"
        self.redis_client.set(key, prices_df.to_json(orient="records"))
        self._write_metadata(
            f"{self.metadata_prefix}prices_{model_name}",
            metadata,
            {"n_products": len(prices_df), "model_type": "predicted_prices"},
        )
        log.info("Published prices '%s' for %s products", model_name, len(prices_df))

    def get_prices(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
        """Retrieve predicted prices from registry (None if not published)."""
        return self._read_frame(f"{self.prices_prefix}{model_name}")

    # ------------------------------------------------------------------
    # connectivity
    # ------------------------------------------------------------------
    def health_check(self) -> bool:
        """Check if Redis connection is alive; returns False instead of raising."""
        try:
            self.redis_client.ping()
        except Exception as exc:
            # Exception (not a bare except) so SystemExit/KeyboardInterrupt
            # still propagate instead of being reported as "unhealthy".
            log.warning("Redis health check failed: %s", exc)
            return False
        return True

    # ------------------------------------------------------------------
    # session-scoped prices
    # ------------------------------------------------------------------
    def set_session_prices(
        self, session_id: str, prices: Dict[str, float], ttl: int = 1800
    ):
        """
        Store prices for a specific session.
        THIS is the write path for session-aware pricing.

        Fields are merged into any hash already stored for the session, and
        the TTL is reset on every call. An empty ``prices`` is a no-op.

        Args:
            session_id: session identifier
            prices: dict of {productId: price}
            ttl: time-to-live in seconds (default 30min)
        """
        if not prices:
            return
        key = f"session:{session_id}:prices"
        fields = {product_id: str(price) for product_id, price in prices.items()}
        # One MULTI/EXEC round trip: the hash can never be left without a TTL
        # because the write succeeded and the separate EXPIRE call did not.
        pipe = self.redis_client.pipeline(transaction=True)
        # use Redis hash for O(1) lookup per product
        pipe.hset(key, mapping=fields)
        pipe.expire(key, ttl)
        pipe.execute()

    def get_session_price(self, session_id: str, product_id: str) -> Optional[float]:
        """
        Lookup price for (sessionId, productId).
        THIS is the read path for fast provider lookup.

        Returns: price or None if not found
        """
        price_str = self.redis_client.hget(f"session:{session_id}:prices", product_id)
        if price_str is None:
            return None
        return float(self._to_text(price_str))

    def get_session_all_prices(self, session_id: str) -> Dict[str, float]:
        """Get all prices for a session ({} if the session has none)."""
        prices_raw = self.redis_client.hgetall(f"session:{session_id}:prices")
        if not prices_raw:
            return {}
        return {
            self._to_text(field): float(self._to_text(value))
            for field, value in prices_raw.items()
        }