sol4r

Demand Forecast 2

Aug 18th, 2024 (edited)
159
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.99 KB | None | 0 0
  1. import pandas as pd
  2. import numpy as np
  3. from sklearn.model_selection import train_test_split
  4. from statsmodels.tsa.statespace.sarimax import SARIMAX
  5. from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
  6. import joblib  # For saving and loading the model
  7. import warnings
  8. from datetime import timedelta
  9. from bokeh.plotting import figure, show
  10. from bokeh.io import output_file
  11. from bokeh.layouts import gridplot
  12. from bokeh.models import ColumnDataSource
  13. from keras.models import Sequential
  14. from keras.layers import LSTM, Dense
  15. from keras.callbacks import EarlyStopping
  16. from sklearn.preprocessing import StandardScaler
  17. from sklearn.model_selection import cross_val_score
  18.  
  19. # Suppress warnings
  20. warnings.filterwarnings("ignore")
  21.  
  22.  
  23. # Function to load data
  24. def load_data(data_path):
  25.     df = pd.read_csv(data_path)
  26.     df['transaction_date'] = pd.to_datetime(df['transaction_date'], format='%Y-%m-%d')
  27.     df['year_month'] = df['transaction_date'].dt.to_period('M')
  28.     return df
  29.  
  30.  
  31. # Function to engineer features
  32. def engineer_features(df):
  33.     # Seasonal decomposition
  34.     decomposition = sm.tsa.seasonal_decompose(df['quantity'], model='additive', period=12)
  35.     df['trend'] = decomposition.trend
  36.     df['seasonal'] = decomposition.seasonal
  37.     df['residual'] = decomposition.resid
  38.  
  39.     # Exponential Smoothing
  40.     df['es'] = df['quantity'].ewm(span=12).mean()
  41.  
  42.     return df
  43.  
  44.  
  45. # Function to prepare the dataset
  46. def prepare_data(df):
  47.     monthly_demand = df.groupby(['item_id', 'year_month'])['quantity'].sum().reset_index()
  48.     monthly_demand['month'] = monthly_demand['year_month'].dt.month
  49.     monthly_demand['year'] = monthly_demand['year_month'].dt.year
  50.  
  51.     # Create lag features
  52.     for lag in range(1, 5):
  53.         monthly_demand[f'lag_{lag}'] = monthly_demand['quantity'].shift(lag)
  54.  
  55.     monthly_demand.dropna(inplace=True)
  56.  
  57.     # Set the index to a DatetimeIndex with a freq
  58.     monthly_demand['date'] = monthly_demand['year_month'].apply(lambda x: x.to_timestamp())
  59.     monthly_demand.set_index('date', inplace=True)
  60.     monthly_demand = engineer_features(monthly_demand)
  61.  
  62.     return monthly_demand
  63.  
  64.  
  65. # Function to train SARIMA model
  66. def train_sarima_model(X_train, y_train):
  67.     model = SARIMAX(y_train, order=(1,1,1), seasonal_order=(1,1,1,12))
  68.     model_fit = model.fit(disp=False)
  69.     return model_fit
  70.  
  71.  
  72. # Function to train LSTM model
  73. def train_lstm_model(X_train, y_train):
  74.     X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
  75.     model = Sequential()
  76.     model.add(LSTM(50, input_shape=(X_train.shape[1], X_train.shape[2])))
  77.     model.add(Dense(1))
  78.     model.compile(loss='mean_squared_error', optimizer='adam')
  79.     early_stopping = EarlyStopping(monitor='loss', patience=5, min_delta=0.001)
  80.     model.fit(X_train, y_train, epochs=50, batch_size=1, verbose=2, callbacks=[early_stopping])
  81.     return model
  82.  
  83.  
  84. # Function to predict future demand
  85. def predict_demand(model, item_data, current_stock_level, months_ahead=6):
  86.     future_dates = pd.date_range(start=item_data.index[-1] + timedelta(days=1),
  87.                                  periods=months_ahead, freq='M')
  88.     future_df = pd.DataFrame({'year_month': future_dates})
  89.  
  90.     # Prepare future DataFrame with lagged values
  91.     for lag in range(1, 5):
  92.         future_df[f'lag_{lag}'] = item_data['quantity'].iloc[-lag] if len(item_data) >= lag else 0
  93.  
  94.     future_df['item_id'] = item_data['item_id'].iloc[0]
  95.     last_entry = item_data.iloc[-1]
  96.     future_df['month'] = last_entry['month']
  97.     future_df['year'] = last_entry['year']
  98.     future_X = future_df[['item_id', 'month', 'year', 'lag_1', 'lag_2', 'lag_3', 'lag_4']]
  99.  
  100.     if isinstance(model, SARIMAX):
  101.         future_demand = model.predict(start=len(item_data), end=len(item_data)+months_ahead-1)
  102.     else:
  103.         future_X = np.reshape(future_X, (future_X.shape[0], 1, future_X.shape[1]))
  104.         future_demand = model.predict(future_X)
  105.  
  106.     alerts = []
  107.  
  108.     for month, demand in zip(future_dates, future_demand):
  109.         reorder_amount = max(0, demand - current_stock_level)
  110.         alerts.append((month, demand, reorder_amount))
  111.  
  112.     return alerts
  113.  
  114.  
  115. # Function to plot results
  116. def plot_results(item_data, alerts, item_id):
  117.     output_file("demand_forecasting.html")
  118.  
  119.     source_actual = ColumnDataSource(data=dict(
  120.         months=item_data['year_month'].astype(str).tolist(),
  121.         actual_quantities=item_data['quantity'].tolist(),
  122.     ))
  123.  
  124.     future_dates = [alert[0] for alert in alerts]
  125.     predicted_quantities = [alert[1] for alert in alerts]
  126.  
  127.     source_predicted = ColumnDataSource(data=dict(
  128.         months=[date.strftime('%Y-%m') for date in future_dates],
  129.         predicted_quantities=predicted_quantities,
  130.     ))
  131.  
  132.     p_actual = figure(title=f"Actual Demand for Item ID {item_id}", x_axis_label='Months', y_axis_label='Quantity',
  133.                       x_range=source_actual.data['months'], height=400, width=350)
  134.     p_predicted = figure(title=f"Predicted Demand for Item ID {item_id}", x_axis_label='Months',
  135.                          y_axis_label='Quantity',
  136.                          x_range=source_predicted.data['months'], height=400, width=350)
  137.  
  138.     p_actual.line('months', 'actual_quantities', source=source_actual, line_width=2, color='green',
  139.                   legend_label="Actual Demand")
  140.     p_predicted.vbar(x='months', top='predicted_quantities', source=source_predicted, width=0.9, color='blue',
  141.                      legend_label="Predicted Demand")
  142.  
  143.     grid = gridplot([[p_actual, p_predicted]])
  144.     show(grid)
  145.  
  146.  
  147. # Function to calculate MASE
  148. def calculate_mase(y_true, y_pred, y_train):
  149.     mase = np.mean(np.abs(y_true - y_pred)) / np.mean(np.abs(y_train - np.mean(y_train)))
  150.     return mase
  151.  
  152.  
  153. # Main function to run the script
  154. def main():
  155.     data_path = 'data/orders2.csv'  # Update this path if necessary
  156.     df = load_data(data_path)
  157.     monthly_demand = prepare_data(df)
  158.  
  159.     # Split the data into features and target variable
  160.     X = monthly_demand.drop(columns=['quantity', 'year_month'])
  161.     y = monthly_demand['quantity']
  162.     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  163.  
  164.     # Train SARIMA model
  165.     sarima_model = train_sarima_model(X_train, y_train)
  166.  
  167.     # Train LSTM model
  168.     lstm_model = train_lstm_model(X_train, y_train)
  169.  
  170.     # Make predictions and evaluate
  171.     sarima_pred = sarima_model.predict(start=len(y_train), end=len(y_train) + len(y_test) - 1)
  172.     lstm_pred = lstm_model.predict(np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1])))
  173.  
  174.     print("SARIMA Model Evaluation Metrics:")
  175.     print(f"Mean Absolute Error (MAE): {mean_absolute_error(y_test, sarima_pred):.2f}")
  176.     print(f"Mean Squared Error (MSE): {mean_squared_error(y_test, sarima_pred):.2f}")
  177.     print(f"R² Score: {r2_score(y_test, sarima_pred):.2f}")
  178.     print(f"MASE: {calculate_mase(y_test, sarima_pred, y_train):.2f}")
  179.  
  180.     print("LSTM Model Evaluation Metrics:")
  181.     print(f"Mean Absolute Error (MAE): {mean_absolute_error(y_test, lstm_pred):.2f}")
  182.     print(f"Mean Squared Error (MSE): {mean_squared_error(y_test, lstm_pred):.2f}")
  183.     print(f"R² Score: {r2_score(y_test, lstm_pred):.2f}")
  184.     print(f"MASE: {calculate_mase(y_test, lstm_pred, y_train):.2f}")
  185.  
  186.     # User input for analysis
  187.     item_id = int(input("Enter the item_id you want to analyze: "))
  188.     current_stock_level = int(input("Enter the current stock level for this item: "))
  189.  
  190.     # Filter data for the selected item
  191.     item_data = monthly_demand[monthly_demand['item_id'] == item_id].copy()
  192.  
  193.     # Check for existing entries
  194.     if item_data.empty:
  195.         print(f"No historical data available for item ID {item_id}.")
  196.     else:
  197.         sarima_alerts = predict_demand(sarima_model, item_data, current_stock_level)
  198.         lstm_alerts = predict_demand(lstm_model, item_data, current_stock_level)
  199.  
  200.         for month, demand, reorder_amount in sarima_alerts:
  201.             print(f"Projected demand for item ID {item_id} in {month.strftime('%Y-%m')} is {demand:.2f}.")
  202.             if reorder_amount > 0:
  203.                 print(
  204.                     f"Alert: You may need to reorder {reorder_amount:.2f} units of item ID {item_id} by {month.strftime('%Y-%m')} as the stock might run out.")
  205.             else:
  206.                 print(f"No reorder necessary for item ID {item_id}. Sufficient stock available.")
  207.  
  208.         for month, demand, reorder_amount in lstm_alerts:
  209.             print(f"Projected demand for item ID {item_id} in {month.strftime('%Y-%m')} is {demand:.2f}.")
  210.             if reorder_amount > 0:
  211.                 print(
  212.                     f"Alert: You may need to reorder {reorder_amount:.2f} units of item ID {item_id} by {month.strftime('%Y-%m')} as the stock might run out.")
  213.             else:
  214.                 print(f"No reorder necessary for item ID {item_id}. Sufficient stock available.")
  215.  
  216.             # Plot the results
  217.         plot_results(item_data, sarima_alerts, item_id)
  218.         plot_results(item_data, lstm_alerts, item_id)
  219.  
  220. if __name__ == "__main__":
  221.     main()
Advertisement
Add Comment
Please, Sign In to add comment