import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
import matplotlib.pyplot as plt
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.stattools import adfuller
from sklearn.metrics import mean_absolute_error, mean_squared_error
import numpy as np
# Banner for the case study. This statement sits at module top level, so it
# runs on import as well as when the file is executed as a script.
print("Case Study: Profit Forecast using ARIMA")
def load_data(file_path):
    """Read historical profit data from a CSV source.

    Args:
        file_path: Location of the CSV data; passed straight through to
            ``pandas.read_csv``.

    Returns:
        The parsed contents as a pandas DataFrame.
    """
    frame = pd.read_csv(file_path)
    return frame
def preprocess_data(data):
    """Prepare the profit data for modelling.

    The ``'Date'`` column is converted to datetime values and written back
    into ``data`` itself, so the caller's DataFrame is modified in place.
    After the conversion, a warning line is printed if any cell in the frame
    is null.

    Args:
        data: DataFrame containing a ``'Date'`` column.

    Returns:
        The same DataFrame object that was passed in.
    """
    parsed_dates = pd.to_datetime(data['Date'])
    data['Date'] = parsed_dates

    # Flatten the per-cell null flags to a plain array and test them at once.
    null_flags = data.isnull().values
    if null_flags.any():
        print("Warning: Missing values found in the data.")

    return data
def fit_arima_model(data, p, d, q):
    """Fit an ARIMA model to the profit series.

    Args:
        data: The observations handed to ``ARIMA`` as the series to model.
        p: First element of the ARIMA ``order`` tuple.
        d: Second element of the ARIMA ``order`` tuple.
        q: Third element of the ARIMA ``order`` tuple.

    Returns:
        The object returned by calling ``fit()`` on the constructed model.
    """
    arima_order = (p, d, q)
    return ARIMA(data, order=arima_order).fit()
def evaluate_model(model_fit, test_data):
    """Score a fitted model's forecast against observed values.

    The model is asked to forecast as many steps as ``test_data`` has
    entries, and that forecast is compared with ``test_data``.

    Args:
        model_fit: Fitted model exposing ``forecast(steps=...)``.
        test_data: Observed values to compare the forecast against.

    Returns:
        A ``(mae, mse, rmse)`` tuple: mean absolute error, mean squared
        error, and the square root of the mean squared error.
    """
    horizon = len(test_data)
    predicted = model_fit.forecast(steps=horizon)

    abs_error = mean_absolute_error(test_data, predicted)
    sq_error = mean_squared_error(test_data, predicted)
    return abs_error, sq_error, np.sqrt(sq_error)
def forecast_profits(model_fit, periods):
    """Forecast future profits from a fitted model.

    Args:
        model_fit: Fitted model exposing ``forecast(steps=...)``.
        periods: Number of steps to request from the model.

    Returns:
        Whatever ``model_fit.forecast`` returns for ``periods`` steps.
    """
    return model_fit.forecast(steps=periods)
def visualize_data(data):
    """Plot the historical profit series and display the figure.

    Args:
        data: DataFrame with ``'Date'`` (x-axis) and ``'Profit'`` (y-axis)
            columns.
    """
    # Object-oriented Matplotlib API: one 10x6 figure holding a single axes.
    _, axes = plt.subplots(figsize=(10, 6))
    axes.plot(data['Date'], data['Profit'], label='Historical Profit Data')
    axes.set_xlabel('Date')
    axes.set_ylabel('Profit')
    axes.set_title('Historical Profit Data')
    axes.legend()
    plt.show()
def main():
    """Run the profit-forecasting case study end to end.

    Steps:
        1. Load ``profit_data.csv`` and preprocess it.
        2. Hold out the final 12 observations of the ``'Profit'`` column,
           fit an ARIMA(2, 1, 2) model on the earlier observations only, and
           print MAE / MSE / RMSE for its forecast of the held-out values.
        3. Refit the same model order on the full series, forecast the next
           12 periods and print that forecast.
        4. Plot the historical data.

    Raises:
        ValueError: If the series has no observations left for training once
            the hold-out set is removed.
    """
    file_path = 'profit_data.csv'
    profit_data = load_data(file_path)
    profit_data = preprocess_data(profit_data)
    profits = profit_data['Profit']

    p, d, q = 2, 1, 2
    test_size = 12

    if len(profits) <= test_size:
        raise ValueError(
            f"Need more than {test_size} observations to hold out a test "
            f"set, got {len(profits)}."
        )

    # Evaluate on data the model has NOT seen. Fitting on the whole series and
    # then scoring its forecast against the last rows of that same series
    # would compare predictions for future periods with in-sample values.
    train_data = profits.iloc[:-test_size]
    test_data = profits.iloc[-test_size:]
    evaluation_fit = fit_arima_model(train_data, p, d, q)
    mae, mse, rmse = evaluate_model(evaluation_fit, test_data)
    print("Evaluation Metrics:")
    print("Mean Absolute Error (MAE):", mae)
    print("Mean Squared Error (MSE):", mse)
    print("Root Mean Squared Error (RMSE):", rmse)

    # For the real forecast, refit on every available observation so the
    # prediction starts from the end of the full history.
    model_fit = fit_arima_model(profits, p, d, q)
    forecast_period = 12
    forecast = forecast_profits(model_fit, forecast_period)
    print("Forecast for the next", forecast_period, "periods:")
    print(forecast)

    visualize_data(profit_data)
# Run the case study only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()