mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
shock: defining new lab environment and formulation
This commit is contained in:
17
lab/outlet/__init__.py
Normal file
17
lab/outlet/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .constants import Side, MechanismType, InstrumentType, OpportunityType, EventType, LogLevel
|
||||
from .types import (Instrument, InstrumentSet, Quote, Opportunity, Execution,
|
||||
StepEvent, StepLogs, StepMetrics, MarketState, HiddenState, Observation, StepResult)
|
||||
from .stock import PositionModel, PositionConfig, make_instruments
|
||||
from .platform import Platform, PlatformConfig
|
||||
from .observation import DefaultObservationBuilder, ObservationConfig
|
||||
from .mechanisms import PostedPriceMechanism, TwoSidedMechanism, AuctionMechanism
|
||||
|
||||
__all__ = [
|
||||
'Side', 'MechanismType', 'InstrumentType', 'OpportunityType', 'EventType', 'LogLevel',
|
||||
'Instrument', 'InstrumentSet', 'Quote', 'Opportunity', 'Execution',
|
||||
'StepEvent', 'StepLogs', 'StepMetrics', 'MarketState', 'HiddenState', 'Observation', 'StepResult',
|
||||
'PositionModel', 'PositionConfig', 'make_instruments',
|
||||
'Platform', 'PlatformConfig',
|
||||
'DefaultObservationBuilder', 'ObservationConfig',
|
||||
'PostedPriceMechanism', 'TwoSidedMechanism', 'AuctionMechanism',
|
||||
]
|
||||
83
lab/outlet/constants.py
Normal file
83
lab/outlet/constants.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Constants and enumerations for the Quote-Control simulator.
|
||||
|
||||
This module defines the core enums used throughout the platform to ensure
|
||||
type safety and consistent semantics across different pricing mechanisms.
|
||||
"""
|
||||
from enum import Enum, auto
|
||||
|
||||
class Side(Enum):
    """Which side of a transaction an opportunity or execution is on.

    Attributes:
        BUY: Buyer-initiated transaction (customer purchases, market buy order)
        SELL: Seller-initiated transaction (market sell order, short sale)
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    BUY = 1
    SELL = 2
|
||||
|
||||
class MechanismType(Enum):
    """How an agent's quote is turned into executions.

    Attributes:
        POSTED_PRICE: Single posted price per instrument (retail dynamic pricing)
        TWO_SIDED_QUOTE: Bid-ask spread quoting (market making, liquidity provision)
        AUCTION: Reserve price or bid shading (ad auctions, marketplaces)
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    POSTED_PRICE = 1
    TWO_SIDED_QUOTE = 2
    AUCTION = 3
|
||||
|
||||
class InstrumentType(Enum):
    """Kind of instrument whose price is being controlled.

    Attributes:
        SKU: Retail product with inventory constraints
        ASSET: Financial instrument with position limits
        LOAN: Credit product with interest rate pricing
        SUBSCRIPTION: Recurring service with periodic fees
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    SKU = 1
    ASSET = 2
    LOAN = 3
    SUBSCRIPTION = 4
|
||||
|
||||
class OpportunityType(Enum):
    """Kind of arrival that gives the agent a chance to transact.

    Attributes:
        SESSION: Retail browsing session with potential purchase intent
        MARKET_ORDER: Financial market order arrival (buy or sell)
        REQUEST: Service or credit request requiring quote response
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    SESSION = 1
    MARKET_ORDER = 2
    REQUEST = 3
|
||||
|
||||
class EventType(Enum):
    """Category of an event recorded in the simulation log.

    Attributes:
        ARRIVAL: New opportunity arrived in the system
        EXPOSURE: Quote was shown to an arrival
        EXECUTION: Transaction was executed
        ABANDON: Opportunity abandoned without execution
        CANCEL: Pending order was cancelled
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    ARRIVAL = 1
    EXPOSURE = 2
    EXECUTION = 3
    ABANDON = 4
    CANCEL = 5
|
||||
|
||||
class LogLevel(Enum):
    """How much detail is recorded for each simulation step.

    Attributes:
        NONE: No logging, fastest execution
        AGG_ONLY: Only aggregate statistics per step
        FULL: Full event-level logging with propensities for OPE
    """

    # Explicit values match what auto() would assign (1, 2, ...).
    NONE = 1
    AGG_ONLY = 2
    FULL = 3
|
||||
86
lab/outlet/gym_wrapper.py
Normal file
86
lab/outlet/gym_wrapper.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Gymnasium-compatible wrapper for the Quote-Control platform.
|
||||
|
||||
Provides a standard Gym interface for RL training:
|
||||
- observation_space: Box space with flattened observation
|
||||
- action_space: Box space with price multipliers [0.5, 2.0]
|
||||
- reset(), step(), render(), close() methods
|
||||
|
||||
Example:
|
||||
>>> from lab.config import make_retail_platform
|
||||
>>> from lab.outlet.gym_wrapper import QuoteGymEnv
|
||||
>>> env = QuoteGymEnv(make_retail_platform())
|
||||
>>> obs, info = env.reset()
|
||||
>>> obs, reward, done, truncated, info = env.step(env.action_space.sample())
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
HAS_GYM = True
|
||||
except ImportError:
|
||||
HAS_GYM = False
|
||||
|
||||
from .platform import Platform, PlatformConfig
|
||||
from .types import Quote, InstrumentSet, StepResult
|
||||
|
||||
class QuoteGymEnv:
    """Adapter exposing a Platform through the Gymnasium reset/step API.

    Actions are per-instrument price multipliers (Box bounds [0.5, 2.0])
    that are multiplied with the platform's reference prices before being
    forwarded to ``Platform.step``. Observations are whatever
    ``result.obs.to_flat()`` returns, cast to float32.
    """

    def __init__(self, platform: Platform):
        """Build the action and observation spaces for ``platform``.

        Raises:
            ImportError: If gymnasium could not be imported.
        """
        if not HAS_GYM:
            raise ImportError("gymnasium required for QuoteGymEnv")

        self.platform = platform
        self.n = platform.instruments.n
        self._last_result: StepResult | None = None

        # One multiplier per instrument, bounded to [0.5, 2.0].
        self.action_space = spaces.Box(low=0.5, high=2.0, shape=(self.n,), dtype=np.float32)

        # Four per-instrument groups (quotes, fills, exposures, position),
        # plus a fifth (competitor quotes) when the platform has a market.
        # NOTE(review): assumes Observation.to_flat() produces this layout.
        groups = 5 if platform.market else 4
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.n * groups,), dtype=np.float32)

    def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[np.ndarray, dict]:
        """Reset the platform and return ``(observation, info)``.

        ``options`` is accepted for Gymnasium signature compatibility but
        is not used.
        """
        outcome = self.platform.reset(seed)
        self._last_result = outcome
        flat_obs = outcome.obs.to_flat().astype(np.float32)
        return flat_obs, outcome.info

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Apply multipliers to reference prices and advance one step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        # Multipliers become absolute prices via the reference prices.
        absolute_prices = self.platform.instruments.refs * action
        outcome = self.platform.step(absolute_prices)
        self._last_result = outcome
        flat_obs = outcome.obs.to_flat().astype(np.float32)
        return (flat_obs, outcome.reward, outcome.terminated,
                outcome.truncated, outcome.info)

    def render(self) -> None:
        """Print a one-line summary of the latest step's metrics, if any."""
        if not self._last_result:
            return
        m = self._last_result.metrics
        print(f"t={self.platform._t} pnl={m.pnl:.2f} units={m.units_traded:.0f} "
              f"conv={m.conversion:.3f} vol={m.volatility:.3f}")

    def close(self) -> None:
        """No resources to release."""
        pass
|
||||
|
||||
def make_env(platform: Platform) -> QuoteGymEnv:
    """Wrap ``platform`` in a QuoteGymEnv and return the environment."""
    env = QuoteGymEnv(platform)
    return env
|
||||
|
||||
if HAS_GYM:
    # Best-effort registration of the environment id with gymnasium.
    # The entry point is derived from this module's actual import path so
    # it resolves whether the package is imported as ``outlet`` or as
    # ``lab.outlet``.
    try:
        gym.register(id='QuoteControl-v0', entry_point=f'{__name__}:QuoteGymEnv')
    except Exception as exc:  # never let registration break importing the module
        # Only Exception is caught, so KeyboardInterrupt/SystemExit still
        # propagate; the failure is reported instead of silently dropped.
        import warnings
        warnings.warn(f"QuoteControl-v0 registration skipped: {exc}", RuntimeWarning)
|
||||
57
lab/outlet/math_util.py
Normal file
57
lab/outlet/math_util.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Numerical utilities for stable computation.
|
||||
|
||||
This module provides numerically stable implementations of common operations:
|
||||
- safe_exp, safe_log: Avoid overflow/underflow
|
||||
- softmax: Numerically stable softmax
|
||||
- sigmoid, clamp: Standard transformations
|
||||
- intensity_decay: Avellaneda-Stoikov fill intensity
|
||||
- inventory_penalty: Quadratic inventory risk
|
||||
- poisson_arrivals, hawkes_intensity: Arrival process helpers
|
||||
|
||||
All functions accept both scalars and numpy arrays.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
EPS = 1e-8 # small constant to avoid division by zero
|
||||
MAX_EXP = 700.0 # maximum safe exponent to avoid overflow
|
||||
|
||||
def safe_exp(x: np.ndarray | float) -> np.ndarray | float:
    """Exponentiate ``x`` after clipping it to ``[-MAX_EXP, MAX_EXP]``.

    The clip keeps the exponent inside the range where ``np.exp`` does
    not overflow.
    """
    bounded = np.clip(x, -MAX_EXP, MAX_EXP)
    return np.exp(bounded)
|
||||
|
||||
def safe_log(x: np.ndarray | float) -> np.ndarray | float:
    """Natural log of ``x`` with inputs floored at ``EPS``.

    Flooring avoids ``log(0)`` and logs of negative values.
    """
    floored = np.maximum(x, EPS)
    return np.log(floored)
|
||||
|
||||
def clamp(x: np.ndarray | float, lo: float, hi: float) -> np.ndarray | float:
|
||||
return np.clip(x, lo, hi)
|
||||
|
||||
def sigmoid(x: np.ndarray | float) -> np.ndarray | float:
    """Logistic function ``1 / (1 + exp(-x))`` using the clipped exponential."""
    denominator = 1.0 + safe_exp(-x)
    return 1.0 / denominator
|
||||
|
||||
def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along ``axis``.

    Args:
        x: Input scores.
        axis: Axis along which the probabilities are normalised.

    Returns:
        Array of the same shape as ``x`` whose entries along ``axis``
        sum to 1.
    """
    # Subtracting the per-axis max makes every exponent <= 0, so exp()
    # cannot overflow and the entry at the max contributes exactly 1.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    weights = np.exp(shifted)
    # The denominator is therefore >= 1: no epsilon guard is needed, and
    # adding one would keep the result from summing to exactly 1.
    return weights / np.sum(weights, axis=axis, keepdims=True)
|
||||
|
||||
def geometric_series(base: float, ratio: float, n: int) -> np.ndarray:
    """Return the first ``n`` terms ``base * ratio**k`` for ``k = 0..n-1``."""
    exponents = np.arange(n)
    return base * np.power(ratio, exponents)
|
||||
|
||||
def ema(old: float, new: float, alpha: float = 0.1) -> float:
    """One exponential-moving-average update.

    ``alpha`` is the weight given to ``new``; ``1 - alpha`` stays on ``old``.
    """
    fresh_part = alpha * new
    kept_part = (1 - alpha) * old
    return fresh_part + kept_part
|
||||
|
||||
def intensity_decay(distance: float, kappa: float = 1.0) -> float:
    """Avellaneda-Stoikov style fill intensity: ``exp(-kappa * distance)``.

    Larger ``kappa`` makes the intensity fall off faster with distance.
    """
    exponent = -kappa * distance
    return safe_exp(exponent)
|
||||
|
||||
def inventory_penalty(q: float, gamma: float = 0.1, sigma: float = 1.0) -> float:
    """Quadratic inventory risk penalty ``gamma * sigma^2 * q^2 / 2``."""
    variance = sigma**2
    squared_position = q**2
    return gamma * variance * squared_position / 2
|
||||
|
||||
def poisson_arrivals(rate: float, dt: float, rng: np.random.Generator) -> int:
    """Draw the number of arrivals in an interval ``dt`` at the given rate.

    The count is sampled from a Poisson distribution with mean ``rate * dt``.
    """
    expected_count = rate * dt
    return rng.poisson(expected_count)
|
||||
|
||||
def hawkes_intensity(base: float, history: np.ndarray, alpha: float, beta: float, t: float) -> float:
    """Self-exciting Hawkes process intensity at time ``t``.

    Each event in ``history`` strictly before ``t`` adds
    ``alpha * exp(-beta * (t - t_event))`` on top of ``base``.
    """
    if len(history) == 0:
        return base
    # Only events that happened before t contribute.
    elapsed = t - history[history < t]
    excitation = np.sum(safe_exp(-beta * elapsed))
    return base + alpha * excitation
|
||||
5
lab/outlet/mechanisms/__init__.py
Normal file
5
lab/outlet/mechanisms/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .posted_price import PostedPriceMechanism
|
||||
from .two_sided import TwoSidedMechanism
|
||||
from .auction import AuctionMechanism
|
||||
|
||||
__all__ = ['PostedPriceMechanism', 'TwoSidedMechanism', 'AuctionMechanism']
|
||||
73
lab/outlet/mechanisms/auction.py
Normal file
73
lab/outlet/mechanisms/auction.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Auction mechanism for reserve pricing and bid shading.
|
||||
|
||||
In this mechanism, the agent sets reserve prices that affect
|
||||
win probability and clearing prices. Used for ad auctions,
|
||||
marketplace auctions, and similar settings.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from ..types import Quote, Opportunity, Execution, InstrumentSet, MarketState
|
||||
from ..constants import Side
|
||||
from ..math_util import clamp, sigmoid
|
||||
|
||||
@dataclass
class AuctionConfig:
    """Tunable parameters of the auction mechanism.

    Attributes:
        min_reserve: Lower bound applied to submitted reserve prices
        max_reserve: Upper bound applied to submitted reserve prices
        base_win_prob: Scale factor on the win probability; AuctionMechanism
            multiplies it by a sigmoid term that equals 0.5 when the
            reserve matches the reference price
        sensitivity: How strongly a reserve above reference lowers the
            win probability
    """

    min_reserve: float = 0.0        # reserve floor
    max_reserve: float = 100.0      # reserve ceiling
    base_win_prob: float = 0.3      # win-probability scale
    sensitivity: float = 2.0        # slope of the reserve penalty
|
||||
|
||||
class AuctionMechanism:
    """Reserve-price auction mechanism.

    The agent's quote prices are interpreted as reserve prices:
    - Win probability: base_win_prob * sigmoid(-sensitivity * (reserve - ref) / ref),
      so a reserve above the reference price lowers the chance of winning.
    - Clearing price: max(reserve, min(max_bid, reserve + random_increment)),
      where max_bid and the increment are drawn from exponential distributions.

    Only BUY-side opportunities can produce an execution.
    """

    def __init__(self, cfg: AuctionConfig | None = None):
        """Store ``cfg``, falling back to a default AuctionConfig."""
        self.cfg = cfg if cfg else AuctionConfig()

    def apply_quote(self, quote: Quote, instruments: InstrumentSet,
                    rng: np.random.Generator) -> Quote:
        """Return a copy of ``quote`` with reserves clipped to the configured range."""
        bounded = clamp(quote.prices, self.cfg.min_reserve, self.cfg.max_reserve)
        return Quote(prices=bounded, propensity=quote.propensity, metadata=quote.metadata)

    def process_opportunity(self, opp: Opportunity, quote: Quote,
                            instruments: InstrumentSet, market: MarketState | None,
                            rng: np.random.Generator) -> Execution | None:
        """Simulate one auction; return an Execution on a win, else None.

        ``market`` is accepted for interface parity but is not used here.
        """
        if opp.side != Side.BUY:
            return None

        i = int(opp.instrument_id)
        reserve = float(quote.prices[i])
        ref = instruments.refs[i]

        # Reserve expressed relative to the reference price; the small
        # constant guards against a zero reference.
        relative_reserve = (reserve - ref) / (ref + 1e-8)
        win_prob = self.cfg.base_win_prob * sigmoid(-self.cfg.sensitivity * relative_reserve)

        lost = rng.random() > win_prob
        if lost:
            return None

        # Simulated highest competing bid, then a random increment above
        # the reserve; the clearing price never falls below the reserve.
        max_bid = ref * (1 + rng.exponential(0.2))
        increment = rng.exponential(0.1) * ref
        clearing = max(reserve, min(max_bid, reserve + increment))

        return Execution(
            opportunity_id=opp.id,
            instrument_id=opp.instrument_id,
            side=opp.side,
            size_requested=opp.size,
            size_filled=opp.size,
            price=clearing,
            propensity=quote.propensity * win_prob,
            t=opp.t,
        )
|
||||
84
lab/outlet/mechanisms/posted_price.py
Normal file
84
lab/outlet/mechanisms/posted_price.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Posted price mechanism for retail dynamic pricing.
|
||||
|
||||
In this mechanism, the agent posts a single price per instrument.
|
||||
Buyers decide whether to purchase based on the posted price.
|
||||
This is the standard e-commerce dynamic pricing model.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from ..types import Quote, Opportunity, Execution, InstrumentSet, MarketState
|
||||
from ..constants import Side
|
||||
from ..math_util import clamp
|
||||
|
||||
@dataclass
class PostedPriceConfig:
    """Tunable constraints of the posted price mechanism.

    Attributes:
        min_price: Absolute minimum price
        max_price: Absolute maximum price
        max_delta_pct: Largest allowed per-step price move, as a fraction
            of the previous price
        min_margin_pct: Minimum margin over cost basis
        round_to: Price rounding granularity (None = no rounding)
    """

    min_price: float = 0.01             # absolute floor
    max_price: float = 1000.0           # absolute ceiling
    max_delta_pct: float = 0.2          # per-step change limit
    min_margin_pct: float = 0.05        # markup over cost
    round_to: float | None = 0.01       # rounding step, falsy disables
|
||||
|
||||
class PostedPriceMechanism:
    """Posted price mechanism for retail dynamic pricing.

    One price is posted per product. ``apply_quote`` enforces, in order:
    - Margin of at least min_margin_pct above cost
    - Prices within [min_price, max_price]
    - Per-step change of at most max_delta_pct (when a previous price is known)
    - Rounding to round_to granularity

    Only BUY-side opportunities are processed (customers purchasing).
    """

    def __init__(self, cfg: PostedPriceConfig | None = None):
        """Store ``cfg``, falling back to a default PostedPriceConfig."""
        self.cfg = cfg if cfg else PostedPriceConfig()

    def apply_quote(self, quote: Quote, instruments: InstrumentSet,
                    rng: np.random.Generator) -> Quote:
        """Return a new Quote whose prices satisfy the configured constraints.

        The constrained prices are also stored under ``'prev_prices'`` in
        the returned quote's metadata.
        """
        cfg = self.cfg
        adjusted = quote.prices.copy()

        # Margin floor derived from the instruments' costs.
        margin_floor = instruments.costs * (1 + cfg.min_margin_pct)
        adjusted = np.maximum(adjusted, margin_floor)

        # Absolute price bounds.
        adjusted = clamp(adjusted, cfg.min_price, cfg.max_price)

        # Rate limit relative to the previous prices, if the caller supplied them.
        if 'prev_prices' in quote.metadata:
            previous = quote.metadata['prev_prices']
            band = previous * cfg.max_delta_pct
            adjusted = clamp(adjusted, previous - band, previous + band)

        # Snap to the rounding grid; a falsy round_to disables rounding.
        if cfg.round_to:
            adjusted = np.round(adjusted / cfg.round_to) * cfg.round_to

        carried = {**quote.metadata, 'prev_prices': adjusted}
        return Quote(prices=adjusted, propensity=quote.propensity, metadata=carried)

    def process_opportunity(self, opp: Opportunity, quote: Quote,
                            instruments: InstrumentSet, market: MarketState | None,
                            rng: np.random.Generator) -> Execution | None:
        """Fill a BUY opportunity in full at the posted price; ignore other sides."""
        # posted price is buy-only
        if opp.side != Side.BUY:
            return None

        posted = float(quote.prices[int(opp.instrument_id)])
        return Execution(
            opportunity_id=opp.id,
            instrument_id=opp.instrument_id,
            side=opp.side,
            size_requested=opp.size,
            size_filled=opp.size,
            price=posted,
            propensity=quote.propensity,
            t=opp.t,
        )
|
||||
89
lab/outlet/mechanisms/two_sided.py
Normal file
89
lab/outlet/mechanisms/two_sided.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Two-sided quoting mechanism for market making.
|
||||
|
||||
In this mechanism, the agent posts both bid and ask prices.
|
||||
Execution depends on the distance from the market mid-price.
|
||||
This models liquidity provision in financial markets.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from ..types import Quote, Opportunity, Execution, InstrumentSet, MarketState
|
||||
from ..constants import Side
|
||||
from ..math_util import clamp, intensity_decay
|
||||
|
||||
@dataclass
class TwoSidedConfig:
    """Tunable parameters of the two-sided quoting mechanism.

    Attributes:
        min_spread: Minimum bid-ask spread
        max_spread: Maximum bid-ask spread
        min_price: Absolute minimum price
        max_price: Absolute maximum price
        fill_kappa: Intensity decay parameter (higher = faster decay with distance)
    """

    min_spread: float = 0.01        # narrowest allowed spread
    max_spread: float = 0.5         # widest allowed spread
    min_price: float = 0.01         # absolute floor
    max_price: float = 10000.0      # absolute ceiling
    fill_kappa: float = 1.5         # fill-intensity decay rate
|
||||
|
||||
class TwoSidedMechanism:
    """Two-sided quoting mechanism for market making.

    The agent quotes a mid price and a spread per instrument, giving a
    bid and an ask. An arriving order is filled with probability
    ``exp(-fill_kappa * |quoted price - mid|)`` (Avellaneda-Stoikov style
    intensity).

    Both sides are processed:
    - BUY: customer buys at the agent's ask price
    - SELL: customer sells at the agent's bid price
    """

    def __init__(self, cfg: TwoSidedConfig | None = None):
        """Store ``cfg``, falling back to a default TwoSidedConfig."""
        self.cfg = cfg if cfg else TwoSidedConfig()

    def apply_quote(self, quote: Quote, instruments: InstrumentSet,
                    rng: np.random.Generator) -> Quote:
        """Return a new Quote with mids and spreads forced into the configured bounds."""
        cfg = self.cfg
        mids = quote.prices.copy()
        if quote.spreads is not None:
            widths = quote.spreads.copy()
        else:
            # No spread supplied: start from a default of 0.02 everywhere.
            widths = np.full_like(mids, 0.02)

        mids = clamp(mids, cfg.min_price, cfg.max_price)
        widths = clamp(widths, cfg.min_spread, cfg.max_spread)

        # Build bid/ask around the mid, keep each inside the price bounds,
        # then recompute spread and mid from the bounded pair.
        half = widths / 2
        bids = np.maximum(mids - half, cfg.min_price)
        asks = np.minimum(mids + half, cfg.max_price)

        return Quote(prices=(bids + asks) / 2, spreads=asks - bids,
                     propensity=quote.propensity, metadata=quote.metadata)

    def process_opportunity(self, opp: Opportunity, quote: Quote,
                            instruments: InstrumentSet, market: MarketState | None,
                            rng: np.random.Generator) -> Execution | None:
        """Probabilistically fill ``opp`` at the ask (BUY) or bid (otherwise).

        Returns:
            An Execution for the full requested size, or None when the
            random draw exceeds the fill probability.
        """
        i = int(opp.instrument_id)

        # Use the market mid when one is available, otherwise the quoted mid.
        if market and market.mid_prices is not None:
            mid = market.mid_prices[i]
        else:
            mid = quote.prices[i]

        # BUY orders trade against the ask, everything else against the bid;
        # fall back to the quoted mid when that side is missing.
        if opp.side == Side.BUY:
            side_prices = quote.asks
        else:
            side_prices = quote.bids
        if side_prices is not None:
            price = float(side_prices[i])
        else:
            price = float(quote.prices[i])

        # probabilistic fill based on absolute distance from mid
        fill_prob = intensity_decay(abs(price - mid), self.cfg.fill_kappa)
        if rng.random() > fill_prob:
            return None

        return Execution(
            opportunity_id=opp.id,
            instrument_id=opp.instrument_id,
            side=opp.side,
            size_requested=opp.size,
            size_filled=opp.size,
            price=price,
            propensity=quote.propensity * fill_prob,
            t=opp.t,
        )
|
||||
11
lab/outlet/objectives/__init__.py
Normal file
11
lab/outlet/objectives/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .base import BaseObjective, CompositeObjective
|
||||
from .penalties import (PnLObjective, VolatilityPenalty, HoldingCostPenalty,
|
||||
LostOpportunityCostPenalty, InventoryRiskPenalty, SpreadCaptureReward)
|
||||
from .factory import make_objective, make_composite, retail_objective, market_making_objective
|
||||
|
||||
__all__ = [
|
||||
'BaseObjective', 'CompositeObjective',
|
||||
'PnLObjective', 'VolatilityPenalty', 'HoldingCostPenalty',
|
||||
'LostOpportunityCostPenalty', 'InventoryRiskPenalty', 'SpreadCaptureReward',
|
||||
'make_objective', 'make_composite', 'retail_objective', 'market_making_objective',
|
||||
]
|
||||
48
lab/outlet/objectives/base.py
Normal file
48
lab/outlet/objectives/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Base classes for reward objectives.
|
||||
|
||||
Objectives compute scalar rewards from step metrics. The CompositeObjective
|
||||
allows combining multiple objectives with weights for multi-objective optimization.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from ..types import Quote, InstrumentSet, StepMetrics, HiddenState, Observation
|
||||
|
||||
class BaseObjective(ABC):
    """Abstract base class for reward objectives.

    Subclasses must implement both methods:
    - reward(): the scalar reward for one step
    - breakdown(): named components of that reward
    """

    @abstractmethod
    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return the scalar reward for the step."""
        ...

    @abstractmethod
    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return a mapping of component name to its contribution."""
        ...


class CompositeObjective(BaseObjective):
    """Weighted sum of multiple objectives.

    Allows combining multiple reward terms (e.g., PnL - holding_cost - volatility).

    Args:
        objectives: List of (objective, weight) tuples
    """

    def __init__(self, objectives: list[tuple[BaseObjective, float]]):
        self.objectives = objectives

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return the weight-scaled sum of every component's reward."""
        return sum(w * obj.reward(quote, instruments, metrics, hidden, obs)
                   for obj, w in self.objectives)

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the weight-scaled breakdown of every component, merged by key.

        When several components report the same key (for example two
        instances of the same objective class), their weighted values are
        added together rather than the last one overwriting the others.
        """
        merged: dict[str, float] = {}
        for obj, w in self.objectives:
            parts = obj.breakdown(quote, instruments, metrics, hidden, obs)
            for key, value in parts.items():
                merged[key] = merged.get(key, 0.0) + w * value
        return merged
|
||||
82
lab/outlet/objectives/factory.py
Normal file
82
lab/outlet/objectives/factory.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Factory functions for creating objectives.
|
||||
|
||||
Provides:
|
||||
- make_objective: Create single objective by name
|
||||
- make_composite: Create weighted combination of objectives
|
||||
- retail_objective: Default objective for retail pricing
|
||||
- market_making_objective: Default objective for market making
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from .base import BaseObjective, CompositeObjective
|
||||
from .penalties import (PnLObjective, VolatilityPenalty, HoldingCostPenalty,
|
||||
LostOpportunityCostPenalty, InventoryRiskPenalty, SpreadCaptureReward)
|
||||
|
||||
# Maps each objective name accepted by make_objective to the class that
# implements it.
REGISTRY: dict[str, type[BaseObjective]] = {
    'pnl': PnLObjective,
    'volatility': VolatilityPenalty,
    'holding_cost': HoldingCostPenalty,
    'lost_opportunity': LostOpportunityCostPenalty,
    'inventory_risk': InventoryRiskPenalty,
    'spread_capture': SpreadCaptureReward,
}
|
||||
|
||||
def make_objective(name: str, **kwargs) -> BaseObjective:
    """Instantiate the objective registered under ``name``.

    Args:
        name: Registry key (pnl, volatility, holding_cost, lost_opportunity,
              inventory_risk, spread_capture)
        **kwargs: Forwarded unchanged to the objective's constructor

    Returns:
        A new objective instance

    Raises:
        ValueError: If ``name`` is not a registry key
    """
    if name in REGISTRY:
        objective_cls = REGISTRY[name]
        return objective_cls(**kwargs)
    raise ValueError(f"Unknown objective: {name}. Available: {list(REGISTRY.keys())}")
|
||||
|
||||
def make_composite(spec: list[tuple[str, float, dict]] | dict[str, float]) -> CompositeObjective:
    """Build a CompositeObjective from a declarative specification.

    Args:
        spec: Either:
            - dict of {name: weight}; each objective is built with its
              default constructor arguments
            - list of (name, weight, kwargs) tuples; kwargs are passed to
              the objective's constructor

    Returns:
        CompositeObjective holding the (objective, weight) pairs in spec order
    """
    if isinstance(spec, dict):
        weighted = [(make_objective(name), weight) for name, weight in spec.items()]
    else:
        weighted = [(make_objective(name, **kwargs), weight)
                    for name, weight, kwargs in spec]
    return CompositeObjective(weighted)
|
||||
|
||||
def retail_objective(volatility_weight: float = 0.1, holding_weight: float = 0.5,
                     stockout_weight: float = 0.3) -> CompositeObjective:
    """Default objective for retail dynamic pricing.

    Reward = PnL - volatility_weight*volatility - holding_weight*holding_cost
             - stockout_weight*lost_opportunity

    The weights are passed as positive numbers; the penalty objectives
    themselves return negative values.
    """
    weights = {
        'pnl': 1.0,
        'volatility': volatility_weight,
        'holding_cost': holding_weight,
        'lost_opportunity': stockout_weight,
    }
    return make_composite(weights)
|
||||
|
||||
def market_making_objective(gamma: float = 0.1, sigma: float = 1.0) -> CompositeObjective:
    """Default objective for market making.

    Reward = PnL + 0.5*spread_capture - inventory_risk(gamma, sigma)

    ``gamma`` and ``sigma`` are forwarded to InventoryRiskPenalty.
    """
    pnl_term = (PnLObjective(), 1.0)
    spread_term = (SpreadCaptureReward(), 0.5)
    risk_term = (InventoryRiskPenalty(gamma=gamma, sigma=sigma), 1.0)
    return CompositeObjective([pnl_term, spread_term, risk_term])
|
||||
101
lab/outlet/objectives/penalties.py
Normal file
101
lab/outlet/objectives/penalties.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Standard objective components and penalties.
|
||||
|
||||
This module provides common reward terms:
|
||||
- PnLObjective: Basic profit and loss
|
||||
- VolatilityPenalty: Penalize price volatility for UX
|
||||
- HoldingCostPenalty: Inventory holding cost
|
||||
- LostOpportunityCostPenalty: Stockout/missed fill cost
|
||||
- InventoryRiskPenalty: Quadratic inventory risk (market making)
|
||||
- SpreadCaptureReward: Bid-ask spread capture (market making)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from .base import BaseObjective
|
||||
from ..types import Quote, InstrumentSet, StepMetrics, HiddenState, Observation
|
||||
from ..math_util import inventory_penalty
|
||||
|
||||
class PnLObjective(BaseObjective):
    """Reward equal to the step's profit and loss (``metrics.pnl``)."""

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return ``metrics.pnl`` unchanged."""
        return metrics.pnl

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return pnl together with the revenue and cost reported in ``metrics``."""
        return {
            'pnl': metrics.pnl,
            'revenue': metrics.revenue,
            'cost': metrics.cost,
        }
|
||||
|
||||
class VolatilityPenalty(BaseObjective):
    """Negative reward proportional to ``metrics.volatility``.

    Args:
        scale: Multiplier applied to the volatility before negation
    """

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return ``-scale * metrics.volatility``."""
        penalty = -self.scale * metrics.volatility
        return penalty

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the penalty under the key ``'volatility_penalty'``."""
        penalty = -self.scale * metrics.volatility
        return {'volatility_penalty': penalty}
|
||||
|
||||
class HoldingCostPenalty(BaseObjective):
    """Negative reward proportional to ``metrics.position_cost``.

    Args:
        scale: Multiplier applied to the position cost before negation
    """

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return ``-scale * metrics.position_cost``."""
        penalty = -self.scale * metrics.position_cost
        return penalty

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the penalty under the key ``'holding_cost_penalty'``."""
        penalty = -self.scale * metrics.position_cost
        return {'holding_cost_penalty': penalty}
|
||||
|
||||
class LostOpportunityCostPenalty(BaseObjective):
    """Negative reward proportional to ``metrics.lost_opportunity``.

    Intended to penalise lost sales due to stockouts or missed fills.

    Args:
        scale: Multiplier applied to the lost-opportunity metric before negation
    """

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return ``-scale * metrics.lost_opportunity``."""
        penalty = -self.scale * metrics.lost_opportunity
        return penalty

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the penalty under the key ``'lost_opportunity_penalty'``."""
        penalty = -self.scale * metrics.lost_opportunity
        return {'lost_opportunity_penalty': penalty}
|
||||
|
||||
class InventoryRiskPenalty(BaseObjective):
    """Quadratic inventory risk penalty (Avellaneda-Stoikov style).

    Penalty = gamma * sigma^2 * q^2 / 2, where q is the sum of
    ``obs.position`` over all entries. Encourages market makers to keep
    inventory near zero.

    Args:
        gamma: Coefficient forwarded to ``inventory_penalty``
        sigma: Coefficient forwarded to ``inventory_penalty`` (squared there)
    """

    def __init__(self, gamma: float = 0.1, sigma: float = 1.0):
        self.gamma = gamma
        self.sigma = sigma

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return the negated penalty, or 0.0 when the observation has no position."""
        if obs.position is None:
            return 0.0
        total_position = np.sum(obs.position)
        return -inventory_penalty(total_position, self.gamma, self.sigma)

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the reward under the key ``'inventory_risk_penalty'``."""
        value = self.reward(quote, instruments, metrics, hidden, obs)
        return {'inventory_risk_penalty': value}
|
||||
|
||||
class SpreadCaptureReward(BaseObjective):
    """Reward equal to ``metrics.spread_capture`` (market making)."""

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> float:
        """Return ``metrics.spread_capture`` unchanged."""
        captured = metrics.spread_capture
        return captured

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState, obs: Observation) -> dict[str, float]:
        """Return the value under the key ``'spread_capture'``."""
        captured = metrics.spread_capture
        return {'spread_capture': captured}
|
||||
92
lab/outlet/observation.py
Normal file
92
lab/outlet/observation.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Observation construction with demand censoring.
|
||||
|
||||
This module provides the ObservationBuilder that constructs agent observations
|
||||
from step data. The key invariant is that observations only contain censored
|
||||
data (fills) and never true demand, ensuring proper research conditions.
|
||||
|
||||
The ObservationConfig controls what is included in observations:
|
||||
- Position visibility
|
||||
- Market/competitor visibility
|
||||
- Demand proxy method
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from .types import Quote, InstrumentSet, StepLogs, StepMetrics, MarketState, HiddenState, Observation
|
||||
|
||||
@dataclass
class ObservationConfig:
    """Switches controlling what an observation builder exposes to the agent.

    Attributes:
        include_position: Put the current position vector in the observation.
        include_market: Put the market/competitor state in the observation.
        mask_true_demand: Research-mode flag; when True the observation is
            meant to exclude true demand.
        demand_proxy: Name of the demand proxy method
            ('fills', 'exposures' or 'weighted').
        exposure_weights: Weights used by the 'weighted' demand proxy.
    """
    # Visibility toggles.
    include_position: bool = True
    include_market: bool = True
    # Demand censoring / proxy options.
    mask_true_demand: bool = True
    demand_proxy: str = 'fills'
    exposure_weights: dict[str, float] | None = None
|
||||
|
||||
class DefaultObservationBuilder:
    """Constructs censored observations for the agent.

    Ensures the key research invariant: observations contain only
    censored fills (realized sales), never true demand. True demand
    is left to the caller to expose (e.g. through an info dict).
    """

    def __init__(self, cfg: ObservationConfig | None = None):
        self.cfg = cfg or ObservationConfig()

    def build(self, quote: Quote, instruments: InstrumentSet, logs: StepLogs,
              metrics: StepMetrics, market: MarketState | None,
              hidden: HiddenState, mask_demand: bool, t: int) -> Observation:
        """Build the agent-visible observation for one step.

        Args:
            quote: Quote in force this step; its prices are copied into the observation.
            instruments: Instrument set; ``n`` sizes the vectors and ``position``
                is copied when position visibility is enabled.
            logs: Step logs providing ``censored_fills``, ``events`` and ``aggregates``.
            metrics: Step metrics (accepted for interface compatibility, not read here).
            market: Market state, forwarded only when ``cfg.include_market`` is set.
            hidden: Hidden state (accepted for interface compatibility, not read here).
            mask_demand: Accepted for interface compatibility; true demand is
                never placed in the observation by this builder.
            t: Current timestep.

        Returns:
            Observation holding quotes, censored fills, exposures and the
            optional position / market state.
        """
        # Function-scope import: this module's top-level imports do not bring in EventType.
        from .constants import EventType

        n = instruments.n
        cfg = self.cfg

        # Always show censored fills, never true demand.
        fills = logs.censored_fills if logs.censored_fills is not None else np.zeros(n)

        if logs.events:
            # Count EXPOSURE events only. The event log also holds EXECUTION
            # events for the same instruments; counting those as well would
            # inflate exposures compared with the aggregate count.
            exposures = np.zeros(n)
            for event in logs.events:
                if event.type == EventType.EXPOSURE and event.instrument_id is not None:
                    exposures[event.instrument_id] += 1
        else:
            exposures = logs.aggregates.get('exposures', np.zeros(n))

        # Position is included only if configured and available.
        position = None
        if cfg.include_position and instruments.position is not None:
            position = instruments.position.copy()

        # Market state is included only if configured.
        obs_market = market if cfg.include_market else None

        return Observation(
            quotes=quote.prices.copy(),
            position=position,
            fills=fills,
            exposures=exposures,
            market=obs_market,
            t=t,
        )

    def make_space(self, n_instruments: int, include_market: bool = True) -> dict:
        """Return a dict describing the observation space for gym.

        Each entry maps a field name to its ``shape``, ``low`` and ``high``.
        ``position`` is present when ``cfg.include_position`` is set and
        ``competitor_quotes`` when ``include_market`` is True.
        """
        space = {
            'quotes': {'shape': (n_instruments,), 'low': 0, 'high': np.inf},
            'fills': {'shape': (n_instruments,), 'low': 0, 'high': np.inf},
            'exposures': {'shape': (n_instruments,), 'low': 0, 'high': np.inf},
        }
        if self.cfg.include_position:
            space['position'] = {'shape': (n_instruments,), 'low': -np.inf, 'high': np.inf}
        if include_market:
            space['competitor_quotes'] = {'shape': (n_instruments,), 'low': 0, 'high': np.inf}
        return space
|
||||
285
lab/outlet/platform.py
Normal file
285
lab/outlet/platform.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Main simulation platform orchestrating the Quote-Control loop.
|
||||
|
||||
The Platform class is the central coordinator that:
|
||||
1. Receives pricing actions (quotes) from the agent
|
||||
2. Generates arrivals via the ArrivalModel
|
||||
3. Processes executions via Mechanism and ExecutionModel
|
||||
4. Applies position censorship via PositionModel
|
||||
5. Computes metrics and reward via Objective
|
||||
6. Returns censored observations
|
||||
|
||||
Example:
|
||||
>>> from lab.config import make_retail_platform
|
||||
>>> platform = make_retail_platform()
|
||||
>>> result = platform.reset(seed=42)
|
||||
>>> result = platform.step(platform.instruments.refs * 1.1)
|
||||
>>> print(f"PnL: {result.metrics.pnl:.2f}")
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
from .types import (Quote, Opportunity, Execution, InstrumentSet, StepLogs, StepMetrics,
|
||||
StepEvent, MarketState, HiddenState, Observation, StepResult)
|
||||
from .constants import LogLevel, EventType, Side
|
||||
from .protocols import Mechanism, ArrivalModel, ExecutionModel, PositionModel, MarketModel, ObservationBuilder, Objective
|
||||
from .stock import PositionModel as DefaultPositionModel, PositionConfig
|
||||
from .observation import DefaultObservationBuilder, ObservationConfig
|
||||
from .objectives.factory import retail_objective
|
||||
|
||||
@dataclass
class PlatformConfig:
    """Run-level settings for the simulation platform.

    Attributes:
        n_instruments: How many instruments the simulation prices.
        max_steps: Step count at which the episode terminates.
        dt: Duration of one step (passed to the arrival model).
        log_level: Logging verbosity (NONE, AGG_ONLY, FULL).
        mask_demand: Research-mode flag; when True observations exclude true demand.
        seed: Random seed for reproducibility (None = unseeded).
    """
    # Episode dimensions.
    n_instruments: int = 10
    max_steps: int = 1000
    dt: float = 1.0
    # Logging and censoring.
    log_level: LogLevel = LogLevel.AGG_ONLY
    mask_demand: bool = True
    # Reproducibility.
    seed: int | None = None
|
||||
|
||||
class Platform:
    """Main simulation orchestrator implementing Quote -> Arrival -> Execution -> Position.

    The Platform coordinates all components to simulate a pricing environment:
    - Mechanism: validates quotes and determines execution logic
    - ArrivalModel: generates demand opportunities
    - ExecutionModel: computes acceptance probabilities
    - PositionModel: manages inventory/position and censorship
    - MarketModel: updates competitor/market state
    - ObservationBuilder: constructs censored observations
    - Objective: computes reward from metrics

    Attributes:
        instruments: The instrument set being priced
        mechanism: Quote validation and execution mechanism
        arrival: Demand arrival generator
        execution: Acceptance probability model
        position: Inventory/position manager
        market: Competitor/market dynamics (optional)
        obs_builder: Observation constructor
        objective: Reward function
        cfg: Platform configuration
    """

    def __init__(self, instruments: InstrumentSet, mechanism: Mechanism,
                 arrival: ArrivalModel, execution: ExecutionModel,
                 position: PositionModel | None = None,
                 market: MarketModel | None = None,
                 obs_builder: ObservationBuilder | None = None,
                 objective: Objective | None = None,
                 cfg: PlatformConfig | None = None):
        """Wire the components together, filling in defaults for omitted ones.

        Defaults: a ``DefaultPositionModel`` with a default ``PositionConfig``,
        a ``DefaultObservationBuilder``, ``retail_objective()`` and a
        ``PlatformConfig`` sized to ``instruments.n``.
        """
        self.instruments = instruments
        self.mechanism = mechanism
        self.arrival = arrival
        self.execution = execution
        self.position = position or DefaultPositionModel(PositionConfig())
        self.market = market
        self.obs_builder = obs_builder or DefaultObservationBuilder()
        self.objective = objective or retail_objective()
        self.cfg = cfg or PlatformConfig(n_instruments=instruments.n)

        self._t: int = 0
        self._rng: np.random.Generator = np.random.default_rng(self.cfg.seed)
        self._quote: Quote | None = None
        self._market_state: MarketState | None = None
        self._hidden: HiddenState = HiddenState()
        self._prev_prices: np.ndarray | None = None

    def reset(self, seed: int | None = None) -> StepResult:
        """Reset the platform to initial state.

        Args:
            seed: Random seed (overrides config seed if provided, including 0)

        Returns:
            Initial StepResult with zeroed metrics and initial observation
        """
        self._t = 0
        # Explicit None check: `seed or cfg.seed` would silently drop seed=0.
        effective_seed = seed if seed is not None else self.cfg.seed
        self._rng = np.random.default_rng(effective_seed)
        self._hidden = HiddenState()
        self._prev_prices = self.instruments.refs.copy()

        # Reset position and mirror it onto the instrument set.
        self.position.reset(self.instruments, self._rng)
        self.instruments.position = self.position.position

        # Initial quote sits at the reference prices.
        self._quote = Quote(prices=self.instruments.refs.copy(), propensity=1.0,
                            metadata={'prev_prices': self._prev_prices})
        self._quote = self.mechanism.apply_quote(self._quote, self.instruments, self._rng)

        # Initial market state.
        if self.market:
            self._market_state = self.market.step(0, self._quote, self._hidden, self._rng)

        # Initial observation is built from empty logs and metrics.
        n = self.instruments.n
        logs = StepLogs(aggregates={'reset': True},
                        true_demand=np.zeros(n),
                        censored_fills=np.zeros(n))
        metrics = StepMetrics()
        obs = self.obs_builder.build(self._quote, self.instruments, logs, metrics,
                                     self._market_state, self._hidden, self.cfg.mask_demand, 0)

        return StepResult(obs=obs, reward=0.0, terminated=False, truncated=False,
                          info={'true_demand': logs.true_demand}, metrics=metrics,
                          logs=logs, hidden=self._hidden)

    def step(self, action: np.ndarray, propensity: float = 1.0) -> StepResult:
        """Execute one simulation step with the given pricing action.

        The step proceeds as follows:
        1. Apply quote constraints via mechanism
        2. Update market/competitor state
        3. Generate arrivals
        4. Process arrivals -> executions (acceptance check, position censorship)
        5. Update position state
        6. Compute metrics (PnL, costs, etc.)
        7. Build logs
        8. Construct censored observation
        9. Compute reward and its breakdown
        10. Check termination

        Args:
            action: Price vector for all instruments
            propensity: P(action | behavior policy) for OPE logging

        Returns:
            StepResult containing observation, reward, metrics, logs, and hidden state
        """
        self._t += 1
        cfg = self.cfg
        n = self.instruments.n
        full_logging = cfg.log_level == LogLevel.FULL

        # 1. apply quote from action (asarray is a no-op for ndarrays, converts sequences)
        self._quote = Quote(prices=np.asarray(action), propensity=propensity,
                            metadata={'prev_prices': self._prev_prices})
        self._quote = self.mechanism.apply_quote(self._quote, self.instruments, self._rng)
        quote = self._quote
        self._prev_prices = quote.prices.copy()
        self._hidden.quote_history.append(quote.prices.copy())

        # 2. update market/competitors
        if self.market:
            self._market_state = self.market.step(self._t, quote, self._hidden, self._rng)
            self._hidden.market_history.append(self._market_state)
        market_state = self._market_state

        # 3. generate arrivals
        opps = self.arrival.sample(self._t, cfg.dt, self.instruments,
                                   market_state, self._hidden, self._rng)

        # 4. process opportunities -> executions
        executions: list[Execution] = []
        events: list[StepEvent] = []
        true_demand = np.zeros(n)

        for opp in opps:
            if full_logging:
                events.append(StepEvent(t=opp.t, type=EventType.EXPOSURE,
                                        instrument_id=opp.instrument_id,
                                        opportunity_id=opp.id,
                                        price=float(quote.prices[opp.instrument_id]),
                                        propensity=quote.propensity))

            # Acceptance check; RNG call order (prob, then uniform draw) is kept
            # fixed so seeded runs stay reproducible.
            prob = self.execution.prob(opp, quote, self.instruments, market_state, self._rng)
            if not (self._rng.random() < prob):
                continue

            exe = self.mechanism.process_opportunity(opp, quote, self.instruments,
                                                     market_state, self._rng)
            if not exe:
                continue

            # True demand is recorded before the position model censors the fill.
            true_demand[exe.instrument_id] += exe.size_requested
            exe = self.position.apply_execution(exe)
            executions.append(exe)
            if full_logging:
                events.append(StepEvent(t=exe.t, type=EventType.EXECUTION,
                                        instrument_id=exe.instrument_id,
                                        opportunity_id=exe.opportunity_id,
                                        price=exe.price, size=exe.size_filled,
                                        propensity=exe.propensity))

        # 5. update position state
        self.position.step(self._t)
        self.instruments.position = self.position.position

        # 6. compute metrics
        censored_fills = np.zeros(n)
        revenue = 0.0
        cost = 0.0
        spread_capture = 0.0

        # Spread capture applies only to two-sided quotes with known mid prices;
        # the condition does not depend on the execution, so resolve it once.
        mids = None
        if quote.spreads is not None and market_state and market_state.mid_prices is not None:
            mids = market_state.mid_prices

        for exe in executions:
            censored_fills[exe.instrument_id] += exe.size_filled
            # BUY executions add revenue and cost; any other side subtracts them.
            sign = 1.0 if exe.side == Side.BUY else -1.0
            revenue += sign * (exe.price * exe.size_filled)
            cost += sign * (self.instruments.costs[exe.instrument_id] * exe.size_filled)
            if mids is not None:
                mid = mids[exe.instrument_id]
                edge = (exe.price - mid) if exe.side == Side.BUY else (mid - exe.price)
                spread_capture += edge * exe.size_filled

        pnl = revenue - cost
        units = float(np.sum(censored_fills))
        lost = float(np.sum(true_demand - censored_fills))

        # Volatility: mean relative price change versus the previous quote.
        volatility = 0.0
        if len(self._hidden.quote_history) > 1:
            prev = self._hidden.quote_history[-2]
            volatility = float(np.mean(np.abs(quote.prices - prev) / (prev + 1e-8)))

        # shortage_cost is not part of the PositionModel protocol as declared,
        # so tolerate position models that do not expose it.
        shortage_cost = getattr(self.position, 'shortage_cost', 0.0)

        metrics = StepMetrics(
            pnl=pnl, revenue=revenue, cost=cost, units_traded=units,
            position_cost=self.position.holding_cost,
            lost_opportunity=shortage_cost + lost * np.mean(quote.prices) * 0.1,
            spread_capture=spread_capture, volatility=volatility,
            conversion=units / (len(opps) + 1e-8),
            per_instrument={'fills': censored_fills, 'demand': true_demand}
        )

        # 7. build logs (explicit integer dtype keeps bincount valid for empty arrivals)
        arrival_ids = np.asarray([o.instrument_id for o in opps], dtype=np.intp)
        logs = StepLogs(
            events=events if full_logging else None,
            executions=executions if full_logging else None,
            aggregates={'n_arrivals': len(opps), 'n_executions': len(executions),
                        'exposures': np.bincount(arrival_ids, minlength=n).astype(float)},
            true_demand=true_demand,
            censored_fills=censored_fills
        )

        # 8. build observation
        obs = self.obs_builder.build(quote, self.instruments, logs, metrics,
                                     market_state, self._hidden, cfg.mask_demand, self._t)

        # 9. compute reward; the breakdown is computed once and reused in info
        reward = self.objective.reward(quote, self.instruments, metrics, self._hidden, obs)
        breakdown = self.objective.breakdown(quote, self.instruments, metrics, self._hidden, obs)

        # 10. check termination
        terminated = self._t >= cfg.max_steps
        truncated = False

        info = {'true_demand': true_demand, 'breakdown': breakdown}

        return StepResult(obs=obs, reward=reward, terminated=terminated, truncated=truncated,
                          info=info, metrics=metrics, logs=logs, hidden=self._hidden)
|
||||
297
lab/outlet/protocols.py
Normal file
297
lab/outlet/protocols.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Protocol definitions for pluggable simulator components.
|
||||
|
||||
This module defines the interfaces (Protocols) that allow swapping different
|
||||
implementations for each stage of the Quote -> Arrival -> Execution -> Position
|
||||
pipeline. All protocols use structural subtyping (duck typing).
|
||||
|
||||
Protocols:
|
||||
Mechanism: How quotes translate to executions (posted price, two-sided, auction)
|
||||
ArrivalModel: How opportunities arrive (Poisson, Hawkes, sessions)
|
||||
ExecutionModel: Acceptance probability given quote (elasticity, intensity)
|
||||
PositionModel: Inventory/position management and censorship
|
||||
MarketModel: Competitor/market dynamics
|
||||
ObservationBuilder: Constructs agent observations with censoring
|
||||
Objective: Computes reward from metrics
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Protocol, Any, TYPE_CHECKING
|
||||
import numpy as np
|
||||
if TYPE_CHECKING:
|
||||
from .types import (Quote, Opportunity, Execution, InstrumentSet, StepLogs,
|
||||
StepMetrics, HiddenState, Observation, MarketState)
|
||||
from .constants import LogLevel
|
||||
|
||||
class Mechanism(Protocol):
    """Interface for turning quotes into executions.

    This is the abstraction that separates pricing domains:
    - PostedPrice: one price; the buyer either purchases or walks away
    - TwoSided: bid/ask spread; execution depends on distance from mid
    - Auction: the reserve price drives win probability and clearing price

    Methods:
        apply_quote: Enforce constraints and hand back a valid quote
        process_opportunity: Decide the execution for an opportunity/quote pair
    """

    def apply_quote(self, quote: Quote, instruments: InstrumentSet,
                    rng: np.random.Generator) -> Quote:
        """Constrain a raw quote according to the mechanism's rules.

        Args:
            quote: Raw quote produced by the policy
            instruments: Current instrument set (costs, reference prices)
            rng: Random generator for stochastic constraints

        Returns:
            A quote that satisfies the mechanism rules (min margin, max delta, etc.)
        """
        ...

    def process_opportunity(self, opp: Opportunity, quote: Quote,
                            instruments: InstrumentSet, market: MarketState | None,
                            rng: np.random.Generator) -> Execution | None:
        """Resolve one opportunity against the quote currently posted.

        Args:
            opp: Incoming opportunity (session, order, request)
            quote: Quote currently posted
            instruments: Instrument set
            market: Current market state (competitor prices, mid-prices)
            rng: Random generator

        Returns:
            An Execution when the opportunity converts, otherwise None
        """
        ...
|
||||
|
||||
class ArrivalModel(Protocol):
    """Interface for generating the opportunities (demand arrivals) of a step.

    Implementations can express different demand dynamics:
    - Poisson: constant rate, memoryless
    - Hawkes: self-exciting, clustered arrivals
    - Session: retail browsing with multi-product views

    Methods:
        sample: Produce the opportunities for one time interval
    """

    def sample(self, t: float, dt: float, instruments: InstrumentSet,
               market: MarketState | None, hidden: HiddenState,
               rng: np.random.Generator) -> list[Opportunity]:
        """Draw the opportunities arriving in the interval [t, t+dt).

        Args:
            t: Current time
            dt: Length of the interval
            instruments: Instruments available for trading
            market: Current market state
            hidden: Hidden state (demand intensity, contamination)
            rng: Random generator

        Returns:
            The opportunities that arrive during this interval
        """
        ...
|
||||
|
||||
class ExecutionModel(Protocol):
    """Interface for acceptance/execution probabilities given a quote and context.

    Implementations can express different demand responses:
    - Elasticity: price sensitivity with competitor cross-effects
    - Intensity: distance-based fill probability (market making)
    - Logit: discrete choice model

    Methods:
        prob: Acceptance probability for one opportunity
        uncensor: Estimate of true demand from censored fills
    """

    def prob(self, opp: Opportunity, quote: Quote, instruments: InstrumentSet,
             market: MarketState | None, rng: np.random.Generator) -> float:
        """Return the probability that the opportunity accepts the quote.

        Args:
            opp: Opportunity being evaluated
            quote: Current quote
            instruments: Instrument set
            market: Market state (competitor prices feed cross-elasticity)
            rng: Random generator

        Returns:
            A probability in [0, 1] that the opportunity executes
        """
        ...

    def uncensor(self, fills: np.ndarray, instruments: InstrumentSet,
                 context: dict[str, Any] | None = None) -> np.ndarray:
        """Recover an estimate of true demand from censored fills.

        Intended for demand-estimation research under inventory censorship.

        Args:
            fills: Observed (censored) fill counts
            instruments: Instrument set
            context: Optional extra context (exposures, prices shown)

        Returns:
            Estimated true demand counts
        """
        ...
|
||||
|
||||
class PositionModel(Protocol):
    """Manages inventory (retail) or position (finance).

    Handles:
    - Position constraints and censorship
    - Holding costs (retail) or inventory risk (finance)
    - Replenishment and order receipt

    Methods:
        reset: Initialize position state
        available: Query available capacity for a trade
        apply_execution: Censor execution by available position
        step: Process time-based updates (replenishment, holding cost)

    Properties:
        position: Current position vector
        holding_cost: Cost incurred this step from holding position
        shortage_cost: Cost incurred this step from unfilled (censored) demand
    """

    def reset(self, instruments: InstrumentSet, rng: np.random.Generator) -> None:
        """Initialize position state for new episode."""
        ...

    def available(self, instrument_id: int, side: Any) -> float:
        """Query available capacity for a trade.

        Args:
            instrument_id: Which instrument
            side: BUY or SELL

        Returns:
            Maximum tradeable size given current position
        """
        ...

    def apply_execution(self, exe: Execution) -> Execution:
        """Apply position constraints to an execution.

        Args:
            exe: Proposed execution with size_requested

        Returns:
            Censored execution with size_filled <= available capacity
        """
        ...

    def step(self, t: float) -> None:
        """Process time-based position updates.

        Handles replenishment receipt, holding cost calculation, etc.
        """
        ...

    @property
    def position(self) -> np.ndarray:
        """Current position vector (positive=long/inventory, negative=short)."""
        ...

    @property
    def holding_cost(self) -> float:
        """Holding cost incurred this step."""
        ...

    @property
    def shortage_cost(self) -> float:
        """Shortage (stockout) cost incurred this step.

        The platform reads this when computing the lost-opportunity metric.
        Explicit subclasses that do not track shortages inherit 0.0.
        """
        return 0.0
|
||||
|
||||
class MarketModel(Protocol):
    """Interface for external market dynamics and competitor behaviour.

    Retail: competitor price dynamics (static, reactive, stochastic)
    Finance: mid-price process (GBM, mean-reverting)

    Methods:
        step: Advance the market state given the agent's quotes
    """

    def step(self, t: float, self_quotes: Quote, hidden: HiddenState,
             rng: np.random.Generator) -> MarketState:
        """Advance the market state by one timestep.

        Args:
            t: Current time
            self_quotes: The agent's current quotes (competitors may react to them)
            hidden: Hidden state (regime info)
            rng: Random generator

        Returns:
            New market state with competitor prices, mid-prices, volatility
        """
        ...
|
||||
|
||||
class ObservationBuilder(Protocol):
    """Interface for building agent observations with appropriate censoring.

    Research-critical: the agent must only ever see censored fills;
    true demand belongs in the info dict, not in the observation.

    Methods:
        build: Assemble an observation from the step data
    """

    def build(self, quote: Quote, instruments: InstrumentSet, logs: StepLogs,
              metrics: StepMetrics, market: MarketState | None,
              hidden: HiddenState, mask_demand: bool, t: int) -> Observation:
        """Assemble the observation handed to the agent.

        Args:
            quote: Current quote
            instruments: Instrument set, including positions
            logs: Step logs carrying true_demand and censored_fills
            metrics: Metrics computed for the step
            market: Market state
            hidden: Hidden state (must not leak into the observation)
            mask_demand: When True, true demand is excluded from the observation
            t: Current timestep

        Returns:
            An Observation restricted to observable quantities
        """
        ...
|
||||
|
||||
class Objective(Protocol):
    """Interface for turning step metrics into a reward.

    Composite objectives may weight several terms:
    - PnL (profit)
    - Position costs (holding, inventory risk)
    - Lost opportunity (stockouts)
    - Volatility penalty (UX)
    - Spread capture (market making)

    Methods:
        reward: Scalar reward for the step
        breakdown: Per-term contributions, for analysis
    """

    def reward(self, quote: Quote, instruments: InstrumentSet,
               metrics: StepMetrics, hidden: HiddenState,
               obs: Observation) -> float:
        """Return the scalar reward for this step.

        Args:
            quote: Current quote
            instruments: Instrument set
            metrics: Step metrics (pnl, costs, etc.)
            hidden: Hidden state
            obs: Observation given to the agent

        Returns:
            The scalar reward
        """
        ...

    def breakdown(self, quote: Quote, instruments: InstrumentSet,
                  metrics: StepMetrics, hidden: HiddenState,
                  obs: Observation) -> dict[str, float]:
        """Return the reward split by component.

        Helps to see which terms dominate the reward.

        Returns:
            Mapping from term name to that term's contribution
        """
        ...
|
||||
151
lab/outlet/stock.py
Normal file
151
lab/outlet/stock.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Inventory/position management and instrument factories.
|
||||
|
||||
This module provides:
|
||||
- PositionConfig: Configuration for position constraints and costs
|
||||
- PositionModel: Manages inventory (retail) or position (finance)
|
||||
- make_instruments: Factory for creating instrument sets
|
||||
|
||||
The PositionModel handles demand censorship by limiting executions
|
||||
to available inventory, computing holding costs, and managing replenishment.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
import numpy as np
|
||||
from .types import Instrument, InstrumentSet, Execution
|
||||
from .constants import Side, InstrumentType
|
||||
|
||||
@dataclass
class PositionConfig:
    """Settings for position/inventory management.

    Attributes:
        initial_position: Starting inventory (None = unlimited, a scalar
            applies the same level to every instrument)
        max_position: Upper bound on the long position per instrument
        min_position: Largest short position allowed (negative, for finance)
        holding_cost_rate: Per-unit, per-step cost of holding inventory
        shortage_cost_rate: Opportunity-cost rate applied to stockouts
        lead_time: Number of steps before replenishment orders arrive
    """
    # Starting level and bounds.
    initial_position: np.ndarray | float | None = None
    max_position: float = 1000.0
    min_position: float = -1000.0
    # Cost rates.
    holding_cost_rate: float = 0.001
    shortage_cost_rate: float = 0.05
    # Replenishment delay, in steps.
    lead_time: int = 0
|
||||
|
||||
@dataclass
class PositionModel:
    """Manages inventory (retail) or position (finance) with censorship.

    Key responsibilities:
    - Track current position per instrument
    - Censor executions when position is insufficient
    - Compute holding costs per step
    - Track shortage/stockout costs per step
    - Handle replenishment orders with lead time

    For retail: position is inventory (positive), selling reduces it
    For finance: position can be positive (long) or negative (short)

    ``holding_cost`` and ``shortage_cost`` both describe the step most
    recently closed by ``step()``.
    """
    cfg: PositionConfig
    n: int = 0
    _position: np.ndarray = field(default_factory=lambda: np.array([]))
    _pending_orders: list[tuple[int, np.ndarray]] = field(default_factory=list)
    _step_holding_cost: float = 0.0
    _step_shortage_cost: float = 0.0
    # Shortage cost gathered since the last step(); published by step().
    _shortage_accum: float = 0.0

    def reset(self, instruments: InstrumentSet, rng: np.random.Generator) -> None:
        """Initialise the position vector and clear per-episode state.

        ``cfg.initial_position`` may be None (unlimited, stored as +inf),
        a scalar (same level for all instruments) or an array-like of
        per-instrument levels.
        """
        self.n = instruments.n
        initial = self.cfg.initial_position
        if initial is None:
            self._position = np.full(self.n, np.inf)  # unlimited
        else:
            # np.array copies and also accepts lists and NumPy scalars.
            levels = np.array(initial, dtype=np.float64)
            self._position = np.full(self.n, float(levels)) if levels.ndim == 0 else levels
        self._pending_orders = []
        self._step_holding_cost = 0.0
        self._step_shortage_cost = 0.0
        self._shortage_accum = 0.0

    def available(self, instrument_id: int, side: Side) -> float:
        """Return the largest size that can trade on ``side`` right now.

        BUY (a customer buys from us) is limited by current inventory; any
        other side is limited by the headroom below ``cfg.max_position``.
        Unlimited inventory reports ``inf``.
        """
        pos = self._position[instrument_id]
        if np.isinf(pos):
            return np.inf
        if side == Side.BUY:
            return max(0, pos)  # can sell up to current inventory
        return max(0, self.cfg.max_position - pos)  # can buy up to max

    def apply_execution(self, exe: Execution) -> Execution:
        """Censor ``exe`` by available capacity and update the position.

        Any unfilled remainder is charged at ``price * shortage_cost_rate``
        and added to the current step's shortage cost.

        Returns:
            A new Execution with ``size_filled`` set to the censored size.
        """
        idx = int(exe.instrument_id)
        avail = self.available(idx, exe.side)
        filled = min(exe.size_requested, avail)
        shortage = exe.size_requested - filled

        if exe.side == Side.BUY:
            self._position[idx] -= filled  # sold from inventory
        else:
            self._position[idx] += filled  # bought into inventory

        if shortage > 0:
            self._shortage_accum += shortage * exe.price * self.cfg.shortage_cost_rate

        return Execution(
            opportunity_id=exe.opportunity_id, instrument_id=exe.instrument_id,
            side=exe.side, size_requested=exe.size_requested,
            size_filled=filled, price=exe.price, propensity=exe.propensity, t=exe.t
        )

    def order(self, quantity: np.ndarray) -> None:
        """Place a replenishment order; it arrives after ``cfg.lead_time`` steps.

        With a lead time of 0 the quantity is added to the position at once.
        """
        if self.cfg.lead_time > 0:
            self._pending_orders.append((self.cfg.lead_time, quantity.copy()))
        else:
            self._position += quantity

    def step(self, t: float) -> None:
        """Close the step: price holding, publish shortage cost, receive orders."""
        # Holding cost ignores unlimited (+inf) inventory.
        pos = np.where(np.isinf(self._position), 0, self._position)
        self._step_holding_cost = float(np.sum(np.abs(pos)) * self.cfg.holding_cost_rate)

        # Publish this step's shortage cost and start a fresh accumulator, so
        # the reported value does not grow across the episode.
        self._step_shortage_cost = self._shortage_accum
        self._shortage_accum = 0.0

        # Receive pending orders whose lead time has elapsed.
        still_pending = []
        for remaining, qty in self._pending_orders:
            if remaining <= 1:
                self._position += qty
            else:
                still_pending.append((remaining - 1, qty))
        self._pending_orders = still_pending

    @property
    def position(self) -> np.ndarray:
        """Position vector, with unlimited (+inf) entries reported as -1."""
        return np.where(np.isinf(self._position), -1, self._position)

    @property
    def holding_cost(self) -> float:
        """Holding cost computed by the most recent ``step()``."""
        return self._step_holding_cost

    @property
    def shortage_cost(self) -> float:
        """Shortage cost of the step closed by the most recent ``step()``."""
        return self._step_shortage_cost
|
||||
|
||||
def make_instruments(n: int, cost_range: tuple[float, float] = (1.0, 10.0),
                     margin_range: tuple[float, float] = (0.2, 0.5),
                     inst_type: InstrumentType = InstrumentType.SKU,
                     rng: np.random.Generator | None = None) -> InstrumentSet:
    """Build an InstrumentSet of ``n`` randomly priced instruments.

    Costs and margins are drawn independently from uniform distributions.
    Each instrument gets its index as id and a reference price equal to
    ``cost * (1 + margin)``.

    Args:
        n: Number of instruments to create
        cost_range: (min, max) bounds for the uniform cost draw
        margin_range: (min, max) bounds for the uniform margin draw
        inst_type: Instrument type assigned to every instrument
        rng: Random generator; a fresh default generator is used when omitted

    Returns:
        InstrumentSet holding the ``n`` generated instruments
    """
    rng = rng or np.random.default_rng()

    # Draw all costs first, then all margins, so the random stream is consumed
    # in a fixed order.
    costs = rng.uniform(*cost_range, n)
    margins = rng.uniform(*margin_range, n)

    items = []
    for idx, (cost, margin) in enumerate(zip(costs, margins)):
        marked_up = cost * (1 + margin)
        items.append(
            Instrument(id=idx, type=inst_type, cost_basis=cost, reference_price=marked_up)
        )
    return InstrumentSet(instruments=items)
|
||||
318
lab/outlet/types.py
Normal file
318
lab/outlet/types.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Core data types for the Quote-Control simulator.
|
||||
|
||||
This module defines the fundamental data structures used throughout the platform:
|
||||
- Identifiers (InstrumentId, OpportunityId, AgentId)
|
||||
- Domain objects (Instrument, Quote, Opportunity, Execution)
|
||||
- Logging structures (StepEvent, StepLogs, StepMetrics)
|
||||
- State containers (MarketState, HiddenState, Observation, StepResult)
|
||||
|
||||
All dataclasses are designed to be serializable and numpy-compatible.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, NewType
|
||||
import numpy as np
|
||||
from .constants import Side, InstrumentType, OpportunityType, EventType
|
||||
|
||||
# Distinct identifier types so ids of different kinds are not mixed up.

# Unique instrument index.
InstrumentId = NewType('InstrumentId', int)

# Unique opportunity/session ID.
OpportunityId = NewType('OpportunityId', str)

# Unique agent/actor ID.
AgentId = NewType('AgentId', str)
|
||||
|
||||
@dataclass
class Instrument:
    """A single priceable entity in the simulation.

    Covers retail SKUs, financial assets, loan products and subscriptions.
    What ``cost_basis`` means depends on the domain: marginal cost for retail,
    mid-price for assets, funding rate for loans.
    """
    # Unique identifier for this instrument.
    id: InstrumentId
    # Category of instrument (SKU, ASSET, LOAN, SUBSCRIPTION).
    type: InstrumentType
    # Fundamental cost or value of the instrument.
    cost_basis: float
    # Base or fair price used for action scaling.
    reference_price: float
    # Free-form extras (quality score, category, volatility, etc.).
    attrs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class InstrumentSet:
    """A group of instruments, optionally paired with a position vector.

    Exposes instrument properties as float32 vectors for vectorized
    computation. A position may be positive (long/inventory) or negative
    (short) for financial assets.
    """
    # The Instrument objects in this set.
    instruments: list[Instrument]
    # Current position per instrument (None = unlimited capacity).
    position: np.ndarray | None = None

    @property
    def n(self) -> int:
        """Number of instruments in the set."""
        return len(self.instruments)

    @property
    def costs(self) -> np.ndarray:
        """Cost bases of all instruments as a float32 vector."""
        values = [inst.cost_basis for inst in self.instruments]
        return np.array(values, np.float32)

    @property
    def refs(self) -> np.ndarray:
        """Reference prices of all instruments as a float32 vector."""
        values = [inst.reference_price for inst in self.instruments]
        return np.array(values, np.float32)
|
||||
|
||||
@dataclass
class Quote:
    """The policy's price quote, i.e. the action in the MDP.

    One structure serves several quoting mechanisms:
    - Posted price: only `prices` is used
    - Two-sided: `prices` is the mid, `spreads` the bid-ask width
    - Auction: `prices` are reserve prices

    The propensity is what off-policy evaluation (OPE) relies on.
    """
    # Posted prices (retail) or mid-quotes (market making).
    prices: np.ndarray
    # Bid-ask spread width for two-sided quoting (None for posted price).
    spreads: np.ndarray | None = None
    # P(this quote | behavior policy), used for importance sampling.
    propensity: float = 1.0
    # Additional info (prev_prices for delta constraints, etc.).
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def bids(self) -> np.ndarray | None:
        """Bid prices (mid - spread/2), or None when no spreads are set."""
        if self.spreads is None:
            return None
        return self.prices - self.spreads/2

    @property
    def asks(self) -> np.ndarray | None:
        """Ask prices (mid + spread/2), or None when no spreads are set."""
        if self.spreads is None:
            return None
        return self.prices + self.spreads/2
|
||||
|
||||
@dataclass
class Opportunity:
    """A demand-side arrival that may or may not become a transaction.

    Examples by domain:
    - Retail: browsing session with purchase intent
    - Market making: incoming market order
    - Lending: loan application

    Segment/type information for execution models travels in `context`.
    """
    # Unique identifier for this opportunity.
    id: OpportunityId
    # Category (SESSION, MARKET_ORDER, REQUEST).
    type: OpportunityType
    # BUY or SELL intent.
    side: Side
    # Instrument the opportunity targets.
    instrument_id: InstrumentId
    # Requested transaction size (units, shares, principal).
    size: float = 1.0
    # Arrival timestamp.
    t: float = 0.0
    # Segment info (is_scraper, credit_score, urgency, etc.).
    context: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class Execution:
    """A transaction as realized after acceptance and position censorship.

    The gap between `size_requested` and `size_filled` is the demand that was
    censored by inventory/position constraints.
    """
    # Links back to the originating Opportunity.
    opportunity_id: OpportunityId
    # Instrument that was traded.
    instrument_id: InstrumentId
    # BUY or SELL.
    side: Side
    # Original requested size (true demand).
    size_requested: float
    # Actual filled size after censorship.
    size_filled: float
    # Execution price.
    price: float
    # Combined propensity for OPE (quote * acceptance).
    propensity: float = 1.0
    # Execution timestamp.
    t: float = 0.0
|
||||
|
||||
@dataclass
class StepEvent:
    """A generic logged event.

    Only the timestamp and event type are required; every other field is
    optional.
    """
    # Time of the event.
    t: float
    # Kind of event.
    type: EventType
    # Optional references to the instrument and opportunity involved.
    instrument_id: InstrumentId | None = None
    opportunity_id: OpportunityId | None = None
    # Optional price and size attached to the event.
    price: float | None = None
    size: float | None = None
    # Propensity recorded with the event.
    propensity: float = 1.0
    # Free-form extra data.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class StepLogs:
    """All logging data produced by one simulation step.

    Works in two modes: detailed event logging (for OPE) and aggregate-only
    (for fast simulation). Keeping `true_demand` separate from
    `censored_fills` is what enables research on demand estimation under
    censorship.
    """
    # Detailed event log (None if LogLevel != FULL).
    events: list[StepEvent] | None = None
    # Executed transactions (None if LogLevel != FULL).
    executions: list[Execution] | None = None
    # Aggregate statistics, always available.
    aggregates: dict[str, Any] = field(default_factory=dict)
    # Oracle demand before censorship (for research, not in obs).
    true_demand: np.ndarray | None = None
    # Realized fills after position constraints (observable).
    censored_fills: np.ndarray | None = None
|
||||
|
||||
@dataclass
class StepMetrics:
    """Metrics computed for one simulation step.

    The metrics are domain-aware: retail uses revenue/cost/holding_cost,
    market making uses spread_capture and inventory risk.
    """
    # Profit and loss (revenue - cost for retail, mark-to-market for finance).
    pnl: float = 0.0
    # Gross revenue from sales/executions.
    revenue: float = 0.0
    # Cost of goods sold or position acquisition cost.
    cost: float = 0.0
    # Total units/shares transacted.
    units_traded: float = 0.0
    # Holding cost (retail) or inventory risk penalty (finance).
    position_cost: float = 0.0
    # Cost of stockouts or missed fills.
    lost_opportunity: float = 0.0
    # Bid-ask spread captured (market making).
    spread_capture: float = 0.0
    # Price volatility metric for UX consideration.
    volatility: float = 0.0
    # Fill rate (executions / opportunities).
    conversion: float = 0.0
    # Per-instrument breakdowns (fills, demand, etc.).
    per_instrument: dict[str, np.ndarray] = field(default_factory=dict)
|
||||
|
||||
@dataclass
class MarketState:
    """Conditions of the external market, including competitor state.

    Retail uses `competitor_quotes` to drive cross-elasticity effects;
    finance uses `mid_prices` and `volatility` to drive execution dynamics.
    """
    # Competitor posted prices (retail).
    competitor_quotes: np.ndarray | None = None
    # Market mid-prices for assets (finance).
    mid_prices: np.ndarray | None = None
    # Per-instrument volatility estimate.
    volatility: np.ndarray | None = None
    # Market regime identifier (normal, price_war, high_vol, etc.).
    regime: str = 'normal'
    # Timestamp of this market state.
    t: float = 0.0
|
||||
|
||||
@dataclass
class HiddenState:
    """Simulator-internal state that the agent never sees.

    Holds oracle information for research analysis together with the
    history needed for non-stationary dynamics.
    """
    # Latent demand multiplier.
    true_demand_intensity: float = 1.0
    # Fraction of arrivals that are adversarial/scraper.
    contamination: float = 0.0
    # Current market/competitor regime.
    regime: str = 'normal'
    # History of agent quotes for volatility calculation.
    quote_history: list[np.ndarray] = field(default_factory=list)
    # History of market states for analysis.
    market_history: list[MarketState] = field(default_factory=list)
|
||||
|
||||
@dataclass
class Observation:
    """What the agent is allowed to see - a censored view of the state.

    Critical invariant: an Observation never carries true_demand, only
    censored fills. This is what enforces the censorship research setting.
    """
    # Current posted quotes (the agent's last action).
    quotes: np.ndarray
    # Current inventory/position state.
    position: np.ndarray | None
    # Censored execution counts per instrument.
    fills: np.ndarray
    # Opportunity exposure counts per instrument.
    exposures: np.ndarray
    # Observable market state (competitor prices, volatility).
    market: MarketState | None
    # Current timestep.
    t: int
    # Additional observable features.
    extra: dict[str, Any] = field(default_factory=dict)

    def to_flat(self) -> np.ndarray:
        """Return the observation as one 1D array for gym environments.

        Segments appear in the order quotes, fills, exposures, then position
        (if set), then competitor quotes (if a market with them is present).
        """
        segments = [self.quotes, self.fills, self.exposures]
        if self.position is not None:
            segments.append(self.position)
        market = self.market
        if market:
            rival_quotes = market.competitor_quotes
            if rival_quotes is not None:
                segments.append(rival_quotes)
        return np.concatenate([seg.ravel() for seg in segments])
|
||||
|
||||
@dataclass
class StepResult:
    """Everything returned by one simulation step.

    The first five fields follow the gymnasium convention (obs, reward,
    terminated, truncated, info); metrics, logs and hidden state are added
    for research use.
    """
    # Observable state (censored).
    obs: Observation
    # Scalar reward from objective function.
    reward: float
    # Episode ended naturally (max_steps reached).
    terminated: bool
    # Episode ended early (bankruptcy, constraint violation).
    truncated: bool
    # Additional info dict (contains true_demand for research).
    info: dict[str, Any]
    # Computed metrics for this step.
    metrics: StepMetrics
    # Event logs and aggregates.
    logs: StepLogs
    # Internal simulator state (oracle info).
    hidden: HiddenState
|
||||
Reference in New Issue
Block a user