minor refactors to codebase to implement DRO

This commit is contained in:
2026-02-14 14:53:30 +01:00
parent 895eea5674
commit bc6c481d03
6 changed files with 195 additions and 75 deletions

View File

@@ -6,11 +6,26 @@ from .lib import EconomicMetricsWrapper, MetricsCallback
wandb.init(
project="phantom-pricing",
config={"alpha": 0.3, "n_products": 10, "total_timesteps": 50000}
config={
"alpha": 0.3,
"n_products": 10,
"total_timesteps": 50000,
"robust_radius": 0.15,
"robust_points": 5,
"lambda_coi": 0.2,
},
)
env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
eval_env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
env_kwargs = {
"n_products": 10,
"alpha": 0.3,
"lambda_coi": 0.2,
"robust_radius": 0.15,
"robust_points": 5,
"render_mode": None,
}
env = EconomicMetricsWrapper(PHANTOM(**env_kwargs))
eval_env = EconomicMetricsWrapper(PHANTOM(**env_kwargs))
model = SAC(
"MultiInputPolicy",
@@ -31,11 +46,12 @@ model.save("phantom_sac")
wandb.finish()
# test trained policy
env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
env = PHANTOM(**env_kwargs)
obs, _ = env.reset()
for _ in range(100):
action, _ = model.predict(obs, deterministic=True)
obs, reward, term, trunc, _ = env.step(action)
env.render()
if term or trunc: break
if term or trunc:
break
env.close()