mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
242 lines
7.6 KiB
Python
Executable File
242 lines
7.6 KiB
Python
Executable File
import redis
|
|
import pickle
|
|
import json
|
|
import pandas as pd
|
|
from io import StringIO
|
|
from typing import Optional, Dict, Any
|
|
import os
|
|
import logging
|
|
|
|
# Module-level logger, named after this module's import path.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ModelRegistry:
    """
    Lightweight model registry using Redis for storing pricing models and elasticity data.

    Models are serialized using pickle, metadata stored as JSON.

    Key layout (Redis db 0):
        elasticity:<name>           JSON records of an elasticity snapshot
        prices:<name>               JSON records of predicted prices
        model:data:<name>           pickled pricing-function object
        model:meta:<name>           JSON metadata for elasticity snapshots / pricing models
        model:meta:prices_<name>    JSON metadata for price snapshots
        session:<id>:prices         hash of {productId: price} with a TTL
    """

    def __init__(
        self, redis_host: Optional[str] = None, redis_port: Optional[int] = None
    ):
        """
        Create the Redis client used by the registry.

        Args:
            redis_host: hostname; falls back to $REDIS_HOST, then "localhost".
            redis_port: port; falls back to $REDIS_PORT, then 6378.
                NOTE(review): stock Redis listens on 6379 -- presumably 6378 is
                deliberate for this deployment; confirm.
        """
        host = redis_host or os.getenv("REDIS_HOST", "localhost")
        port = redis_port or int(os.getenv("REDIS_PORT", "6378"))

        # decode_responses=False because pickled model blobs are raw bytes;
        # text values are decoded explicitly with _decode().
        self.redis_client = redis.Redis(
            host=host, port=port, db=0, decode_responses=False
        )
        self.metadata_prefix = "model:meta:"
        self.data_prefix = "model:data:"
        self.elasticity_prefix = "elasticity:"
        self.prices_prefix = "prices:"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _decode(value):
        """Return *value* as str, decoding UTF-8 if Redis handed back bytes."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    @staticmethod
    def _merge_meta(
        metadata: Optional[Dict[str, Any]], extra: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return a new dict: a shallow copy of *metadata* updated with *extra*.

        Copying first keeps the caller's metadata dict unmodified.
        """
        meta = dict(metadata) if metadata else {}
        meta.update(extra)
        return meta

    @staticmethod
    def _session_key(session_id: str) -> str:
        """Redis hash key holding the per-session price map."""
        return f"session:{session_id}:prices"

    def _store_snapshot(
        self, key: str, payload, meta_key: str, meta: Dict[str, Any]
    ) -> None:
        """Write a payload and its JSON metadata in a single MULTI/EXEC transaction.

        json.dumps runs before anything is sent, so unserializable metadata
        raises without leaving a payload behind.
        """
        meta_json = json.dumps(meta)
        with self.redis_client.pipeline() as pipe:
            pipe.set(key, payload)
            pipe.set(meta_key, meta_json)
            pipe.execute()

    def _read_records(self, key: str) -> Optional[pd.DataFrame]:
        """Load a DataFrame stored as JSON records under *key*; None if absent.

        NOTE(review): pd.read_json infers dtypes by default, which may coerce
        string IDs that look numeric (e.g. "00123") -- confirm productId
        round-trips with the expected dtype.
        """
        raw = self.redis_client.get(key)
        if raw is None:
            return None
        return pd.read_json(StringIO(self._decode(raw)), orient="records")

    # --------------------------------------------------------------- elasticity

    def publish_elasticity(
        self,
        elasticity_df: pd.DataFrame,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store elasticity estimates in registry.

        Args:
            elasticity_df: df with [productId, elasticity, std_error, n_obs]
            model_name: identifier for this elasticity snapshot
            metadata: additional info (timestamp, window_size, etc); not modified

        Raises:
            KeyError: if elasticity_df has no "elasticity" column.
        """
        key = f"{self.elasticity_prefix}{model_name}"
        data_json = elasticity_df.to_json(orient="records")

        mean = elasticity_df["elasticity"].mean()
        meta = self._merge_meta(
            metadata,
            {
                "n_products": len(elasticity_df),
                # An empty frame or all-missing column yields NaN/NA; json.dumps
                # would emit a non-standard `NaN` token, so store null instead.
                "mean_elasticity": None if pd.isna(mean) else float(mean),
                "model_type": "elasticity_snapshot",
            },
        )

        # NOTE(review): this shares the model:meta:<name> key with
        # publish_pricing_model, so publishing both under the same name
        # overwrites one metadata record with the other (publish_prices
        # avoids this with a "prices_" prefix).
        meta_key = f"{self.metadata_prefix}{model_name}"
        self._store_snapshot(key, data_json, meta_key, meta)

        log.info(
            "Published elasticity model '%s' with %d products",
            model_name,
            len(elasticity_df),
        )

    def get_elasticity(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
        """Retrieve elasticity estimates from registry (None if not published)."""
        return self._read_records(f"{self.elasticity_prefix}{model_name}")

    # ----------------------------------------------------------- pricing model

    def publish_pricing_model(
        self,
        pricing_function,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Store a fitted pricing function object.

        Args:
            pricing_function: fitted PricingFunction instance (must be picklable)
            model_name: identifier
            metadata: additional info; not modified
        """
        key = f"{self.data_prefix}{model_name}"
        model_bytes = pickle.dumps(pricing_function)

        meta = self._merge_meta(
            metadata,
            {
                "model_class": pricing_function.__class__.__name__,
                "model_type": "pricing_function",
            },
        )

        # NOTE(review): same metadata key as publish_elasticity for the same
        # model_name -- the later publish overwrites the earlier metadata.
        meta_key = f"{self.metadata_prefix}{model_name}"
        self._store_snapshot(key, model_bytes, meta_key, meta)

        log.info("Published pricing model '%s' (%s)", model_name, meta["model_class"])

    def get_pricing_model(self, model_name: str = "latest") -> Any:
        """Retrieve a pricing function from registry (None if not published).

        SECURITY: pickle.loads can execute arbitrary code embedded in the
        payload. Only use this against a Redis instance that untrusted parties
        cannot write to.
        """
        model_bytes = self.redis_client.get(f"{self.data_prefix}{model_name}")
        if model_bytes is None:
            return None
        return pickle.loads(model_bytes)

    def list_models(self) -> Dict[str, Any]:
        """List all registered models with metadata.

        Returns:
            {name: metadata_dict}, where name is the key with the metadata
            prefix removed (price snapshots appear as "prices_<name>").
        """
        models = {}
        prefix_len = len(self.metadata_prefix)

        for key in self.redis_client.scan_iter(match=f"{self.metadata_prefix}*"):
            # Slice off only the leading prefix; str.replace would also strip
            # any later occurrence inside the model name itself.
            model_name = self._decode(key)[prefix_len:]
            meta_json = self.redis_client.get(key)
            # The key can vanish between SCAN and GET; skip it in that case.
            if meta_json:
                models[model_name] = json.loads(self._decode(meta_json))

        return models

    # ------------------------------------------------------------------ prices

    def publish_prices(
        self,
        prices_df: pd.DataFrame,
        model_name: str = "latest",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store predicted prices in registry.

        Args:
            prices_df: df with [productId, predicted_price, ...]
            model_name: identifier for this price snapshot
            metadata: additional info; not modified
        """
        key = f"{self.prices_prefix}{model_name}"
        data_json = prices_df.to_json(orient="records")

        meta = self._merge_meta(
            metadata, {"n_products": len(prices_df), "model_type": "predicted_prices"}
        )

        meta_key = f"{self.metadata_prefix}prices_{model_name}"
        self._store_snapshot(key, data_json, meta_key, meta)

        log.info("Published prices '%s' for %d products", model_name, len(prices_df))

    def get_prices(self, model_name: str = "latest") -> Optional[pd.DataFrame]:
        """Retrieve predicted prices from registry (None if not published)."""
        return self._read_records(f"{self.prices_prefix}{model_name}")

    def health_check(self) -> bool:
        """Check if Redis connection is alive; never raises on Redis failure."""
        try:
            self.redis_client.ping()
        except Exception as exc:  # best-effort probe: any failure means "unhealthy"
            log.warning("Redis health check failed: %s", exc)
            return False
        return True

    # ---------------------------------------------------------- session prices

    def set_session_prices(
        self, session_id: str, prices: Dict[str, float], ttl: int = 1800
    ) -> None:
        """
        Store prices for a specific session.
        THIS is the write path for session-aware pricing.

        Entries are merged into any existing hash for the session, and the
        TTL is reset on every call.

        Args:
            session_id: session identifier
            prices: dict of {productId: price}; no-op when empty
            ttl: time-to-live in seconds (default 30min)
        """
        if not prices:
            return

        key = self._session_key(session_id)
        # Redis hash gives O(1) lookup per product. HSET and EXPIRE go in one
        # transaction so the hash can never be left without a TTL.
        with self.redis_client.pipeline() as pipe:
            pipe.hset(key, mapping={k: str(v) for k, v in prices.items()})
            pipe.expire(key, ttl)
            pipe.execute()

    def get_session_price(self, session_id: str, product_id: str) -> Optional[float]:
        """
        Lookup price for (sessionId, productId).
        THIS is the read path for fast provider lookup.

        Returns: price or None if not found
        """
        price_raw = self.redis_client.hget(self._session_key(session_id), product_id)
        if price_raw is None:
            return None
        return float(self._decode(price_raw))

    def get_session_all_prices(self, session_id: str) -> Dict[str, float]:
        """Get all prices for a session ({} if the session has none or expired)."""
        prices_raw = self.redis_client.hgetall(self._session_key(session_id))
        return {
            self._decode(product): float(self._decode(price))
            for product, price in prices_raw.items()
        }
|