mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
adding naive jax and libraries and make adjustments
This commit is contained in:
70
engine/lib/discrete.py
Normal file
70
engine/lib/discrete.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from collections import defaultdict
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DiscretePriceActionWrapper(gym.ActionWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
n_levels: int = 9,
|
||||
min_scale: float = 0.8,
|
||||
max_scale: float = 1.2,
|
||||
):
|
||||
super().__init__(env)
|
||||
self.scales = np.linspace(min_scale, max_scale, n_levels, dtype=np.float32)
|
||||
self.action_space = spaces.Discrete(n_levels)
|
||||
|
||||
def action(self, action: int):
|
||||
scale = float(self.scales[int(action)])
|
||||
cur = np.asarray(self.env.unwrapped._prices, dtype=np.float32)
|
||||
lo, hi = self.env.unwrapped.price_bounds
|
||||
return np.clip(cur * scale, lo, hi).astype(np.float32)
|
||||
|
||||
|
||||
class EventQTable:
|
||||
def __init__(
|
||||
self,
|
||||
n_actions: int,
|
||||
n_products: int,
|
||||
price_bounds: tuple,
|
||||
lr: float = 0.1,
|
||||
gamma: float = 0.99,
|
||||
n_bins: int = 6,
|
||||
):
|
||||
self.n_actions = int(n_actions)
|
||||
self.n_products = int(n_products)
|
||||
self.lr = float(lr)
|
||||
self.gamma = float(gamma)
|
||||
self.q = defaultdict(lambda: np.zeros(self.n_actions, dtype=np.float32))
|
||||
lo, hi = price_bounds
|
||||
self.demand_bins = np.linspace(0.0, 100.0, n_bins + 1)[1:-1]
|
||||
self.price_bins = np.linspace(lo, hi, n_bins + 1)[1:-1]
|
||||
|
||||
def encode(self, obs: np.ndarray) -> tuple:
|
||||
obs = np.asarray(obs, dtype=np.float32)
|
||||
d = obs[: self.n_products]
|
||||
p = obs[self.n_products : 2 * self.n_products]
|
||||
d_mean = float(np.mean(d)) if d.size else 0.0
|
||||
d_std = float(np.std(d)) if d.size else 0.0
|
||||
p_mean = float(np.mean(p)) if p.size else 0.0
|
||||
return (
|
||||
int(np.digitize(d_mean, self.demand_bins)),
|
||||
int(np.digitize(d_std, self.demand_bins)),
|
||||
int(np.digitize(p_mean, self.price_bins)),
|
||||
)
|
||||
|
||||
def act(self, obs: np.ndarray, eps: float = 0.0) -> tuple[int, tuple]:
|
||||
s = self.encode(obs)
|
||||
if np.random.random() < eps:
|
||||
return int(np.random.randint(self.n_actions)), s
|
||||
return int(np.argmax(self.q[s])), s
|
||||
|
||||
def update(self, s: tuple, a: int, r: float, s2: tuple, done: bool):
|
||||
target = r + (0.0 if done else self.gamma * float(np.max(self.q[s2])))
|
||||
self.q[s][a] += self.lr * (target - self.q[s][a])
|
||||
|
||||
def predict(self, obs: np.ndarray, deterministic: bool = True):
|
||||
a, _ = self.act(obs, 0.0 if deterministic else 0.05)
|
||||
return a, None
|
||||
Reference in New Issue
Block a user