mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
Minor refactors to the codebase to implement DRO
This commit is contained in:
@@ -6,11 +6,26 @@ from .lib import EconomicMetricsWrapper, MetricsCallback
|
||||
|
||||
wandb.init(
|
||||
project="phantom-pricing",
|
||||
config={"alpha": 0.3, "n_products": 10, "total_timesteps": 50000}
|
||||
config={
|
||||
"alpha": 0.3,
|
||||
"n_products": 10,
|
||||
"total_timesteps": 50000,
|
||||
"robust_radius": 0.15,
|
||||
"robust_points": 5,
|
||||
"lambda_coi": 0.2,
|
||||
},
|
||||
)
|
||||
|
||||
env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||
eval_env = EconomicMetricsWrapper(PHANTOM(n_products=10, alpha=0.3, render_mode=None))
|
||||
env_kwargs = {
|
||||
"n_products": 10,
|
||||
"alpha": 0.3,
|
||||
"lambda_coi": 0.2,
|
||||
"robust_radius": 0.15,
|
||||
"robust_points": 5,
|
||||
"render_mode": None,
|
||||
}
|
||||
env = EconomicMetricsWrapper(PHANTOM(**env_kwargs))
|
||||
eval_env = EconomicMetricsWrapper(PHANTOM(**env_kwargs))
|
||||
|
||||
model = SAC(
|
||||
"MultiInputPolicy",
|
||||
@@ -31,11 +46,12 @@ model.save("phantom_sac")
|
||||
wandb.finish()
|
||||
|
||||
# test trained policy
|
||||
env = PHANTOM(n_products=10, alpha=0.3, render_mode=None)
|
||||
env = PHANTOM(**env_kwargs)
|
||||
obs, _ = env.reset()
|
||||
for _ in range(100):
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, reward, term, trunc, _ = env.step(action)
|
||||
env.render()
|
||||
if term or trunc: break
|
||||
if term or trunc:
|
||||
break
|
||||
env.close()
|
||||
|
||||
Reference in New Issue
Block a user