mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
chor: fixing cross product missing data
This commit is contained in:
@@ -88,17 +88,28 @@ if __name__ == "__main__":
|
||||
interaction_data = interaction_pipeline.fit_transform(None)
|
||||
price_data = price_data_pipeline.fit_transform(None)
|
||||
|
||||
price_elasticity = elasticity_pipeline(interaction_data, price_data, window_size="30s")
|
||||
price_elasticity = price_elasticity['elasticity'].values if price_elasticity is not None and not price_elasticity.empty else np.array([])
|
||||
elasticity_df = elasticity_pipeline(interaction_data, price_data, window_size="30s")
|
||||
|
||||
price_data = price_data['price'].values if not price_data.empty else np.array([])
|
||||
# align elasticity with price data by productId, fill missing with 0
|
||||
if not price_data.empty and elasticity_df is not None and not elasticity_df.empty:
|
||||
price_data_merged = price_data.merge(
|
||||
elasticity_df[['productId', 'elasticity']],
|
||||
on='productId',
|
||||
how='left'
|
||||
).fillna({'elasticity': 0.0})
|
||||
|
||||
print(price_elasticity)
|
||||
print(price_data)
|
||||
prices = price_data_merged['price'].values
|
||||
elasticities = price_data_merged['elasticity'].values
|
||||
else:
|
||||
prices = np.array([])
|
||||
elasticities = np.array([])
|
||||
|
||||
print(elasticities)
|
||||
print(prices)
|
||||
|
||||
state_space = StateSpace(
|
||||
demand=price_elasticity,
|
||||
prices=price_data,
|
||||
demand=elasticities,
|
||||
prices=prices,
|
||||
session_features=interaction_data
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user