Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import pandas as pd
- import numpy as np
- from sklearn.model_selection import train_test_split
- from statsmodels.tsa.statespace.sarimax import SARIMAX
- from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
- import joblib # For saving and loading the model
- import warnings
- from datetime import timedelta
- from bokeh.plotting import figure, show
- from bokeh.io import output_file
- from bokeh.layouts import gridplot
- from bokeh.models import ColumnDataSource
- from keras.models import Sequential
- from keras.layers import LSTM, Dense
- from keras.callbacks import EarlyStopping
- from sklearn.preprocessing import StandardScaler
- from sklearn.model_selection import cross_val_score
- # Suppress warnings
- warnings.filterwarnings("ignore")
- # Function to load data
def load_data(data_path):
    """Read the transactions CSV and derive a monthly period column.

    Args:
        data_path: Location of a CSV file (anything ``pandas.read_csv``
            accepts) that contains a ``transaction_date`` column written
            as ``YYYY-MM-DD``.

    Returns:
        The loaded DataFrame with ``transaction_date`` parsed to datetimes
        and an added ``year_month`` column holding each date's monthly
        period.
    """
    frame = pd.read_csv(data_path)
    parsed_dates = pd.to_datetime(frame['transaction_date'], format='%Y-%m-%d')
    frame['transaction_date'] = parsed_dates
    frame['year_month'] = parsed_dates.dt.to_period('M')
    return frame
- # Function to engineer features
def engineer_features(df):
    """Add seasonal-decomposition and exponential-smoothing columns to *df*.

    The frame is modified in place and also returned.

    Args:
        df: DataFrame with a numeric ``quantity`` column. The additive
            decomposition uses a period of 12, and statsmodels raises
            ``ValueError`` when fewer than two complete cycles (24 rows)
            are supplied.

    Returns:
        The same DataFrame with four new columns: ``trend``, ``seasonal``
        and ``residual`` from the decomposition, and ``es``, the
        exponentially weighted mean of ``quantity`` with ``span=12``.
    """
    # Imported locally: the module never binds ``sm`` (there is no
    # ``import statsmodels.api as sm``), so the previous
    # ``sm.tsa.seasonal_decompose`` call failed with NameError.
    from statsmodels.tsa.seasonal import seasonal_decompose

    # Seasonal decomposition. With the default centred moving average the
    # leading and trailing rows of trend/residual are NaN.
    decomposition = seasonal_decompose(df['quantity'], model='additive', period=12)
    df['trend'] = decomposition.trend
    df['seasonal'] = decomposition.seasonal
    df['residual'] = decomposition.resid

    # Exponential smoothing
    df['es'] = df['quantity'].ewm(span=12).mean()
    return df
- # Function to prepare the dataset
def prepare_data(df):
    """Aggregate transactions to monthly demand per item and add features.

    Args:
        df: Transaction-level DataFrame with ``item_id``, ``quantity`` and a
            monthly-period ``year_month`` column.

    Returns:
        A DataFrame indexed by ``date`` (the month-start timestamp of
        ``year_month``) with one row per item and month. It holds the summed
        ``quantity``, calendar ``month``/``year``, ``lag_1`` .. ``lag_4``,
        and the columns added by ``engineer_features``. Rows lacking a full
        set of four lags are dropped.
    """
    monthly_demand = (
        df.groupby(['item_id', 'year_month'])['quantity'].sum().reset_index()
    )
    monthly_demand['month'] = monthly_demand['year_month'].dt.month
    monthly_demand['year'] = monthly_demand['year_month'].dt.year

    # Create lag features *within each item*. A plain frame-wide shift would
    # feed the last months of one item into the first lags of the next item,
    # because the grouped frame is ordered by item_id and then by month.
    # Lags are positional: they refer to the item's previous observed months.
    lag_columns = []
    quantity_by_item = monthly_demand.groupby('item_id')['quantity']
    for lag in range(1, 5):
        column = f'lag_{lag}'
        monthly_demand[column] = quantity_by_item.shift(lag)
        lag_columns.append(column)
    monthly_demand = monthly_demand.dropna(subset=lag_columns)

    # Index by the month-start timestamp. The index repeats across items,
    # so it carries no single frequency.
    monthly_demand = monthly_demand.assign(
        date=monthly_demand['year_month'].dt.to_timestamp()
    ).set_index('date')

    return engineer_features(monthly_demand)
- # Function to train SARIMA model
def train_sarima_model(X_train, y_train):
    """Fit a seasonal ARIMA model to the target series.

    Args:
        X_train: Accepted for signature symmetry with the other trainer;
            it is not used by this function.
        y_train: The series the SARIMAX model is fitted on.

    Returns:
        The fitted results object produced by ``SARIMAX(...).fit``.
    """
    non_seasonal_order = (1, 1, 1)
    yearly_seasonal_order = (1, 1, 1, 12)
    sarima_spec = SARIMAX(
        y_train,
        order=non_seasonal_order,
        seasonal_order=yearly_seasonal_order,
    )
    # disp=False silences the optimizer's convergence output.
    return sarima_spec.fit(disp=False)
- # Function to train LSTM model
def train_lstm_model(X_train, y_train):
    """Train a single-layer LSTM regressor on tabular features.

    Each row of ``X_train`` is treated as a sequence of length one, so the
    network input has shape ``(samples, 1, n_features)``.

    Args:
        X_train: 2-D feature table (DataFrame or array-like) of numeric
            values, one row per sample.
        y_train: Target values, one per row of ``X_train``.

    Returns:
        The fitted Keras ``Sequential`` model.
    """
    # Convert explicitly before reshaping. Calling np.reshape on a DataFrame
    # relies on NumPy's fallback wrapping, which is not dependable when the
    # result is 3-D.
    features = np.asarray(X_train, dtype=float)
    targets = np.asarray(y_train, dtype=float)
    features = features.reshape(features.shape[0], 1, features.shape[1])

    model = Sequential()
    model.add(LSTM(50, input_shape=(features.shape[1], features.shape[2])))
    model.add(Dense(1))
    model.compile(loss='mean_squared_error', optimizer='adam')

    # Stop once the training loss has not improved by 0.001 for 5 epochs.
    early_stopping = EarlyStopping(monitor='loss', patience=5, min_delta=0.001)
    model.fit(features, targets, epochs=50, batch_size=1, verbose=2,
              callbacks=[early_stopping])
    return model
- # Function to predict future demand
def predict_demand(model, item_data, current_stock_level, months_ahead=6):
    """Forecast demand for the months after ``item_data`` and size reorders.

    Args:
        model: Either a fitted statsmodels time-series results object
            (recognised by its ``get_forecast`` method) or a feature-based
            regressor whose ``predict`` accepts an array of shape
            ``(rows, 1, 7)``.
        item_data: Monthly history of a single item, indexed by date, with
            ``item_id``, ``quantity``, ``month`` and ``year`` columns.
        current_stock_level: Units currently in stock. Each month's demand
            is compared against this same level independently.
        months_ahead: Number of future months to forecast.

    Returns:
        A list of ``(month, demand, reorder_amount)`` tuples, where ``month``
        is the month-start ``Timestamp`` and the other two are floats.
    """
    # Forecast months start with the month *after* the last observed one.
    # The previous month-end range began one day after the last index entry,
    # so its first label fell inside the last observed month.
    next_month = pd.Timestamp(item_data.index[-1]).to_period('M') + 1
    future_dates = pd.period_range(
        start=next_month, periods=months_ahead, freq='M'
    ).to_timestamp()

    # Prepare future DataFrame with lagged values.
    # NOTE(review): every future row reuses the same last-known lags; there
    # is no recursive roll-forward of predictions.
    future_df = pd.DataFrame({'year_month': future_dates})
    history = item_data['quantity']
    for lag in range(1, 5):
        future_df[f'lag_{lag}'] = history.iloc[-lag] if len(history) >= lag else 0
    future_df['item_id'] = item_data['item_id'].iloc[0]
    # Calendar features describe the forecast month, not the last history row.
    future_df['month'] = np.asarray(future_dates.month)
    future_df['year'] = np.asarray(future_dates.year)

    feature_columns = ['item_id', 'month', 'year', 'lag_1', 'lag_2', 'lag_3', 'lag_4']

    # A fitted SARIMAX is a results object, never a SARIMAX instance, so the
    # old isinstance(model, SARIMAX) test could not select this branch.
    if hasattr(model, 'get_forecast'):
        # NOTE(review): start/end are positions in the series the model was
        # fitted on; confirm the caller fitted it on this item's history.
        raw_demand = model.predict(start=len(item_data),
                                   end=len(item_data) + months_ahead - 1)
    else:
        features = future_df[feature_columns].to_numpy(dtype=float)
        features = features.reshape(features.shape[0], 1, features.shape[1])
        raw_demand = model.predict(features)

    # Flatten (n, 1) network output so each demand is a plain float that the
    # caller can format with ':.2f'.
    future_demand = np.asarray(raw_demand, dtype=float).ravel()

    alerts = []
    for month, demand in zip(future_dates, future_demand):
        demand = float(demand)
        reorder_amount = max(0.0, demand - current_stock_level)
        alerts.append((month, demand, reorder_amount))
    return alerts
- # Function to plot results
def plot_results(item_data, alerts, item_id):
    """Render actual and predicted demand side by side in one HTML page.

    Args:
        item_data: History of one item with ``year_month`` and ``quantity``
            columns; drawn as a line chart.
        alerts: Sequence of tuples whose first element is the forecast month
            (an object with ``strftime``) and whose second element is the
            predicted quantity; drawn as a bar chart.
        item_id: Identifier shown in both chart titles.

    The output is written to ``demand_forecasting.html`` and shown.
    """
    output_file("demand_forecasting.html")

    # Historical series.
    source_actual = ColumnDataSource(data={
        'months': item_data['year_month'].astype(str).tolist(),
        'actual_quantities': item_data['quantity'].tolist(),
    })

    # Forecast series taken from the first two fields of every alert.
    forecast_labels = [entry[0].strftime('%Y-%m') for entry in alerts]
    forecast_values = [entry[1] for entry in alerts]
    source_predicted = ColumnDataSource(data={
        'months': forecast_labels,
        'predicted_quantities': forecast_values,
    })

    # Both figures share axis labels and size.
    common_options = dict(
        x_axis_label='Months',
        y_axis_label='Quantity',
        height=400,
        width=350,
    )
    p_actual = figure(
        title=f"Actual Demand for Item ID {item_id}",
        x_range=source_actual.data['months'],
        **common_options,
    )
    p_predicted = figure(
        title=f"Predicted Demand for Item ID {item_id}",
        x_range=source_predicted.data['months'],
        **common_options,
    )

    p_actual.line(
        'months', 'actual_quantities',
        source=source_actual,
        line_width=2,
        color='green',
        legend_label="Actual Demand",
    )
    p_predicted.vbar(
        x='months', top='predicted_quantities',
        source=source_predicted,
        width=0.9,
        color='blue',
        legend_label="Predicted Demand",
    )

    show(gridplot([[p_actual, p_predicted]]))
- # Function to calculate MASE
def calculate_mase(y_true, y_pred, y_train, seasonality=1):
    """Return the Mean Absolute Scaled Error of a forecast.

    MASE divides the forecast's mean absolute error by the in-sample mean
    absolute error of the (seasonal) naive forecast on the training series,
    i.e. ``mean(|y_train[t] - y_train[t - seasonality]|)``. The earlier
    implementation scaled by the mean absolute deviation from the training
    mean, which is a different quantity.

    Args:
        y_true: Observed values for the forecast horizon.
        y_pred: Forecast values; an ``(n, 1)`` column is flattened.
        y_train: Training series used to compute the naive-forecast scale.
        seasonality: Lag of the naive forecast (1 = previous observation).

    Returns:
        The MASE as a float, or ``nan`` when the scale cannot be computed
        (training series not longer than ``seasonality``, or a zero scale).

    Raises:
        ValueError: If ``seasonality`` is smaller than 1.
    """
    if seasonality < 1:
        raise ValueError("seasonality must be a positive integer")

    # Work on flat arrays so pandas index alignment and (n, 1) model output
    # cannot distort the element-wise differences.
    actual = np.asarray(y_true, dtype=float).ravel()
    predicted = np.asarray(y_pred, dtype=float).ravel()
    history = np.asarray(y_train, dtype=float).ravel()

    if history.size <= seasonality:
        return float('nan')

    naive_mae = np.mean(np.abs(history[seasonality:] - history[:-seasonality]))
    if naive_mae == 0:
        return float('nan')

    forecast_mae = np.mean(np.abs(actual - predicted))
    return float(forecast_mae / naive_mae)
- # Main function to run the script
def _print_metrics(model_name, y_true, y_pred, y_train):
    """Print MAE, MSE, R² and MASE for one model's test predictions."""
    print(f"{model_name} Model Evaluation Metrics:")
    print(f"Mean Absolute Error (MAE): {mean_absolute_error(y_true, y_pred):.2f}")
    print(f"Mean Squared Error (MSE): {mean_squared_error(y_true, y_pred):.2f}")
    print(f"R² Score: {r2_score(y_true, y_pred):.2f}")
    print(f"MASE: {calculate_mase(y_true, y_pred, y_train):.2f}")


def _report_alerts(alerts, item_id):
    """Print projected demand and reorder advice for each forecast month."""
    for month, demand, reorder_amount in alerts:
        # Collapse 1-element arrays (raw network output) to plain floats so
        # the ':.2f' format specifier works.
        demand = float(np.ravel(demand)[0])
        reorder_amount = float(np.ravel(reorder_amount)[0])
        month_label = month.strftime('%Y-%m')
        print(f"Projected demand for item ID {item_id} in {month_label} is {demand:.2f}.")
        if reorder_amount > 0:
            print(
                f"Alert: You may need to reorder {reorder_amount:.2f} units of item ID {item_id} by {month_label} as the stock might run out.")
        else:
            print(f"No reorder necessary for item ID {item_id}. Sufficient stock available.")


def main():
    """Train and evaluate both models, then report on a user-chosen item.

    Loads ``data/orders2.csv``, builds the monthly feature table, fits the
    SARIMA and LSTM models on the first 80% of rows, and prints their error
    metrics on the remaining 20%. It then asks for an item id and stock
    level, prints reorder alerts from both models and plots the forecasts.
    """
    data_path = 'data/orders2.csv'  # Update this path if necessary
    df = load_data(data_path)
    monthly_demand = prepare_data(df)

    # Train the LSTM on exactly the columns predict_demand builds for future
    # months. The decomposition columns cannot be produced for the future,
    # and their trend/residual values are NaN at the series edges.
    feature_columns = ['item_id', 'month', 'year', 'lag_1', 'lag_2', 'lag_3', 'lag_4']
    X = monthly_demand[feature_columns].to_numpy(dtype=float)
    y = monthly_demand['quantity']

    # Keep row order: a shuffled split would scramble the sequence the SARIMA
    # model is fitted on and make "predict past the end of y_train" meaningless.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, shuffle=False)

    # Train SARIMA model
    sarima_model = train_sarima_model(X_train, y_train)
    # Train LSTM model
    lstm_model = train_lstm_model(X_train, y_train.to_numpy(dtype=float))

    # Make predictions and evaluate. Everything is reduced to flat arrays so
    # neither index alignment nor an (n, 1) output shape skews the metrics.
    y_train_values = y_train.to_numpy(dtype=float)
    y_test_values = y_test.to_numpy(dtype=float)
    sarima_pred = np.asarray(
        sarima_model.predict(start=len(y_train), end=len(y_train) + len(y_test) - 1),
        dtype=float,
    ).ravel()
    lstm_pred = np.asarray(
        lstm_model.predict(X_test.reshape(X_test.shape[0], 1, X_test.shape[1])),
        dtype=float,
    ).ravel()

    _print_metrics("SARIMA", y_test_values, sarima_pred, y_train_values)
    _print_metrics("LSTM", y_test_values, lstm_pred, y_train_values)

    # User input for analysis
    item_id = int(input("Enter the item_id you want to analyze: "))
    current_stock_level = int(input("Enter the current stock level for this item: "))

    # Filter data for the selected item
    item_data = monthly_demand[monthly_demand['item_id'] == item_id].copy()

    # Check for existing entries
    if item_data.empty:
        print(f"No historical data available for item ID {item_id}.")
        return

    sarima_alerts = predict_demand(sarima_model, item_data, current_stock_level)
    lstm_alerts = predict_demand(lstm_model, item_data, current_stock_level)
    _report_alerts(sarima_alerts, item_id)
    _report_alerts(lstm_alerts, item_id)

    # Plot the results. Both calls write the same HTML file, so the second
    # plot replaces the first on disk.
    plot_results(item_data, sarima_alerts, item_id)
    plot_results(item_data, lstm_alerts, item_id)
# Run the forecasting workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment