sol4r

Demand Forecast

Aug 17th, 2024 (edited)
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.31 KB | None | 0 0
  1. import pandas as pd
  2. from sklearn.model_selection import train_test_split, GridSearchCV
  3. from xgboost import XGBRegressor
  4. from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
  5. import joblib  # For saving and loading the model
  6. import warnings
  7. from datetime import timedelta
  8. from bokeh.plotting import figure, show
  9. from bokeh.io import output_file
  10. from bokeh.layouts import gridplot
  11. from bokeh.models import ColumnDataSource
  12.  
  13. # Suppress warnings
  14. warnings.filterwarnings("ignore")
  15.  
  16. # Function to load data
  def load_data(data_path):
      """Read the transactions CSV and tag every row with its calendar month.

      Args:
          data_path: Location of a CSV file containing a ``transaction_date``
              column written as ``YYYY-MM-DD``.

      Returns:
          The loaded DataFrame, with ``transaction_date`` parsed to datetimes
          and an extra ``year_month`` column holding monthly periods.
      """
      frame = pd.read_csv(data_path)
      # Parse once and reuse the parsed series for the derived month column.
      parsed_dates = pd.to_datetime(frame['transaction_date'], format='%Y-%m-%d')
      frame['transaction_date'] = parsed_dates
      frame['year_month'] = parsed_dates.dt.to_period('M')
      return frame
  22.  
  23. # Function to prepare the dataset
  def prepare_data(df):
      """Aggregate transactions to monthly demand per item and add lag features.

      Args:
          df: Transactions with ``item_id``, ``year_month`` (monthly periods)
              and ``quantity`` columns.

      Returns:
          A DataFrame with one row per (item, month) holding the summed
          ``quantity``, the calendar ``month`` and ``year``, and ``lag_1`` ..
          ``lag_4``: that same item's quantity 1..4 rows earlier. Rows
          without a complete set of lags (each item's first four months) are
          dropped.
      """
      monthly_demand = (
          df.groupby(['item_id', 'year_month'])['quantity'].sum().reset_index()
      )
      monthly_demand['month'] = monthly_demand['year_month'].dt.month
      monthly_demand['year'] = monthly_demand['year_month'].dt.year

      # The frame stacks every item's history one after another (groupby sorts
      # by item_id, then year_month). Shifting must therefore happen inside
      # each item; a plain shift() would fill an item's first months with the
      # tail of the previous item's history.
      # NOTE: lags count rows, not calendar months - a month with no
      # transactions produces no row and is silently skipped over.
      for lag in range(1, 5):
          monthly_demand[f'lag_{lag}'] = (
              monthly_demand.groupby('item_id')['quantity'].shift(lag)
          )

      return monthly_demand.dropna()
  35.  
  36. # Function to train the model
  def train_model(X_train, y_train):
      """Tune an XGBoost regressor with a grid search and return the winner.

      Args:
          X_train: Training feature matrix.
          y_train: Training target values.

      Returns:
          The estimator refit with the hyper-parameter combination that
          achieved the best 5-fold cross-validated R^2 score.
      """
      search_space = dict(
          n_estimators=[50, 100],
          max_depth=[3, 5],
          learning_rate=[0.1, 0.2],
          subsample=[0.8, 1.0],
      )
      base_regressor = XGBRegressor(random_state=42)
      tuner = GridSearchCV(
          estimator=base_regressor,
          param_grid=search_space,
          scoring='r2',
          cv=5,
          n_jobs=-1,  # use every available core
      )
      tuner.fit(X_train, y_train)
      return tuner.best_estimator_
  47.  
  48. # Function to predict future demand
  def predict_demand(best_model, item_data, current_stock_level, months_ahead=6):
      """Forecast an item's demand month by month and derive reorder amounts.

      The forecast is recursive: each future month is predicted from its own
      calendar month/year and the four most recent demand values, where
      "recent" includes the predictions already made for earlier future
      months.

      Args:
          best_model: Fitted regressor exposing ``predict``; it is called with
              a one-row DataFrame whose columns are ``item_id``, ``month``,
              ``year``, ``lag_1`` .. ``lag_4`` (in that order).
          item_data: Monthly history of a single item, ordered oldest to
              newest, with ``item_id``, ``year_month`` (monthly periods) and
              ``quantity`` columns.
          current_stock_level: Units currently in stock.
          months_ahead: Number of future months to forecast.

      Returns:
          A list of ``(month, demand, reorder_amount)`` tuples, one per future
          month. ``month`` is a Timestamp at the end of that month and
          ``reorder_amount`` is ``max(0, demand - current_stock_level)``; each
          month is compared against the full stock level, stock is not
          depleted from one month to the next.
      """
      n_lags = 4
      lag_columns = [f'lag_{k}' for k in range(1, n_lags + 1)]
      feature_columns = ['item_id', 'month', 'year'] + lag_columns

      item_id = item_data['item_id'].iloc[0]
      last_period = item_data['year_month'].iloc[-1]
      # Frequency is taken from the start Period, i.e. monthly steps.
      future_periods = pd.period_range(start=last_period + 1, periods=months_ahead)

      # Most recent demand first; pad with zeros when the history is short.
      recent = [float(q) for q in item_data['quantity'].tolist()[-n_lags:]][::-1]
      recent += [0.0] * (n_lags - len(recent))

      alerts = []
      for period in future_periods:
          row = {'item_id': item_id, 'month': period.month, 'year': period.year}
          row.update(zip(lag_columns, recent))
          features = pd.DataFrame([row], columns=feature_columns)

          demand = float(best_model.predict(features)[0])
          reorder_amount = max(0, demand - current_stock_level)
          alerts.append((period.to_timestamp(how='end'), demand, reorder_amount))

          # The fresh prediction becomes lag_1 for the following month.
          recent = [demand] + recent[:-1]

      return alerts
  72.  
  73. # Function to plot results
  def plot_results(item_data, alerts, item_id):
      """Render historical and forecast demand side by side with Bokeh.

      Writes ``demand_forecasting.html`` and opens it. The left panel is a
      line chart of the item's historical monthly quantities; the right panel
      is a bar chart of the forecast quantities.

      Args:
          item_data: Monthly history of the item with ``year_month`` and
              ``quantity`` columns.
          alerts: Sequence of tuples whose first element is the forecast month
              (must support ``strftime``) and whose second is the predicted
              demand.
          item_id: Identifier shown in both chart titles.
      """
      output_file("demand_forecasting.html")

      # Historical series.
      source_actual = ColumnDataSource(data={
          'months': item_data['year_month'].astype(str).tolist(),
          'actual_quantities': item_data['quantity'].tolist(),
      })

      # Forecast series, taken from the first two fields of every alert.
      source_predicted = ColumnDataSource(data={
          'months': [entry[0].strftime('%Y-%m') for entry in alerts],
          'predicted_quantities': [entry[1] for entry in alerts],
      })

      # Both panels share the same axis labels and size.
      shared_opts = dict(x_axis_label='Months', y_axis_label='Quantity',
                         height=400, width=350)
      p_actual = figure(
          title=f"Actual Demand for Item ID {item_id}",
          x_range=source_actual.data['months'],
          **shared_opts,
      )
      p_predicted = figure(
          title=f"Predicted Demand for Item ID {item_id}",
          x_range=source_predicted.data['months'],
          **shared_opts,
      )

      p_actual.line(
          'months', 'actual_quantities',
          source=source_actual, line_width=2, color='green',
          legend_label="Actual Demand",
      )
      p_predicted.vbar(
          x='months', top='predicted_quantities',
          source=source_predicted, width=0.9, color='blue',
          legend_label="Predicted Demand",
      )

      show(gridplot([[p_actual, p_predicted]]))
  103.  
  104. # Main function to run the script
  def _prompt_int(prompt):
      """Ask on stdin until the user enters a valid integer, then return it."""
      while True:
          raw = input(prompt)
          try:
              return int(raw)
          except ValueError:
              print(f"'{raw}' is not a whole number, please try again.")


  def main():
      """Run the demand-forecast workflow end to end.

      Loads the order history, builds the monthly feature table, loads a
      cached model (or trains and caches one), prints evaluation metrics on a
      held-out split, then asks for an item and its stock level and prints and
      plots that item's forecast and reorder alerts.
      """
      data_path = 'data/orders2.csv'  # Update this path if necessary
      df = load_data(data_path)
      monthly_demand = prepare_data(df)

      # Split the data into features and target variable
      X = monthly_demand.drop(columns=['quantity', 'year_month'])
      y = monthly_demand['quantity']
      X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

      # Reuse a cached model when one exists. NOTE: the cached file is used
      # as-is even if the CSV has changed since; delete it to force retraining.
      try:
          best_model = joblib.load('best_xgb_model.joblib')
          print("Loaded pre-trained model.")
      except FileNotFoundError:
          best_model = train_model(X_train, y_train)
          joblib.dump(best_model, 'best_xgb_model.joblib')
          print("Trained and saved model.")

      # Make predictions and evaluate
      y_pred = best_model.predict(X_test)
      print("Model Evaluation Metrics:")
      print(f"Mean Absolute Error (MAE): {mean_absolute_error(y_test, y_pred):.2f}")
      print(f"Mean Squared Error (MSE): {mean_squared_error(y_test, y_pred):.2f}")
      print(f"R² Score: {r2_score(y_test, y_pred):.2f}")

      # User input for analysis; re-prompt instead of crashing on a typo.
      item_id = _prompt_int("Enter the item_id you want to analyze: ")
      current_stock_level = _prompt_int("Enter the current stock level for this item: ")

      # Filter data for the selected item
      item_data = monthly_demand[monthly_demand['item_id'] == item_id].copy()

      # Nothing to forecast from when the item has no usable history.
      if item_data.empty:
          print(f"No historical data available for item ID {item_id}.")
          return

      alerts = predict_demand(best_model, item_data, current_stock_level)
      for month, demand, reorder_amount in alerts:
          print(f"Projected demand for item ID {item_id} in {month.strftime('%Y-%m')} is {demand:.2f}.")
          if reorder_amount > 0:
              print(
                  f"Alert: You may need to reorder {reorder_amount:.2f} units of item ID {item_id} by {month.strftime('%Y-%m')} as the stock might run out.")
          else:
              print(f"No reorder necessary for item ID {item_id}. Sufficient stock available.")

      # Plot the results
      plot_results(item_data, alerts, item_id)


  if __name__ == "__main__":
      main()
  156.  
Advertisement
Add Comment
Please, Sign In to add comment