Commit f10b52bf authored by Barthelet Thibault's avatar Barthelet Thibault
Browse files

added tirex

parent d64ceb70
Loading
Loading
Loading
Loading
+13.9 KiB (71.4 KiB)
Loading image diff...
+169 KiB (742 KiB)
Loading image diff...
+83 −7
Original line number Diff line number Diff line
%% Cell type:code id:9441f141 tags:

``` python
pip install prophet
```

%% Cell type:code id:9c0e1517 tags:

``` python
"""
Forecasting Strategy Comparison
Compares six approaches:
1. Naive: Use Year 10 demand as-is for Year 11
2. Trend-based: Year 10 + average monthly year-over-year growth
3. Autoregressor: Linear model using past 12 months to predict next month
4. Prophet: Facebook's forecasting tool with automatic seasonality detection
5. TiRex: Zero-shot neural forecasting model with xLSTM
6. Ensemble (Mean All): Average of all five approaches
"""

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error
from sklearn.linear_model import LinearRegression
from prophet import Prophet
from tirex import load_model, ForecastModel
import warnings
warnings.filterwarnings('ignore')


def _months_to_frame(months) -> pd.DataFrame:
    """Flatten nested month -> watches records into one row per (month, watch)."""
    columns = ['year', 'month', 'watch_id', 'watch_name', 'demand']
    rows = [
        {
            'year': month['year'],
            'month': month['month'],
            'watch_id': watch['watch_id'],
            'watch_name': watch['watch_name'],
            'demand': watch['demand'],
        }
        for month in months
        for watch in month['watches']
    ]
    # Passing the columns explicitly keeps the schema stable even when there
    # are no rows, so downstream column lookups do not fail on empty input.
    return pd.DataFrame(rows, columns=columns)


def load_data(train_path: str, test_path: str):
    """Load the training and test demand data as flat DataFrames.

    Args:
        train_path: Path to a JSON file holding an object whose
            'historical_data' key is a list of month records.
        test_path: Path to a JSON file holding a list of month records.

    Each month record provides 'year', 'month' and a 'watches' list whose
    items provide 'watch_id', 'watch_name' and 'demand'.

    Returns:
        A ``(train_df, test_df)`` tuple; both frames have the columns
        year, month, watch_id, watch_name and demand.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(train_path, encoding='utf-8') as f:
        train_json = json.load(f)
    with open(test_path, encoding='utf-8') as f:
        test_json = json.load(f)

    train_df = _months_to_frame(train_json['historical_data'])
    test_df = _months_to_frame(test_json)
    return train_df, test_df


def naive_forecast(train_df: pd.DataFrame):
    """Strategy 1: repeat each watch's Year 10 demand, month by month.

    Returns a nested dict ``{watch_id: {month: demand}}`` covering months
    1 through 12 for every watch id present in ``train_df``.
    """
    predictions = {}

    for watch_id in train_df['watch_id'].unique():
        is_base_year = (train_df['watch_id'] == watch_id) & (train_df['year'] == 10)
        base_year = train_df[is_base_year]

        # One lookup per calendar month; the first matching row is used.
        predictions[watch_id] = {
            month: base_year.loc[base_year['month'] == month, 'demand'].values[0]
            for month in range(1, 13)
        }

    return predictions


def calculate_monthly_trends(train_df: pd.DataFrame):
    """Compute the mean year-over-year demand change per watch and month.

    For each watch and calendar month the demands are ordered by year and
    the consecutive differences are averaged. A month with fewer than two
    observations gets a change of 0.

    Returns a nested dict ``{watch_id: {month: average_change}}``.
    """
    trends = {}

    for watch_id in train_df['watch_id'].unique():
        history = train_df[train_df['watch_id'] == watch_id]
        per_month = {}

        for month in range(1, 13):
            by_year = history[history['month'] == month].sort_values('year')
            series = by_year['demand'].values

            # A single point (or none) carries no trend information.
            per_month[month] = np.mean(np.diff(series)) if len(series) > 1 else 0

        trends[watch_id] = per_month

    return trends


def trend_forecast(train_df: pd.DataFrame, trends: dict):
    """Strategy 2: Year 10 demand shifted by the average monthly trend.

    ``trends`` is a nested dict ``{watch_id: {month: average_change}}``.
    Each forecast is rounded to an integer and floored at zero.

    Returns a nested dict ``{watch_id: {month: demand}}``.
    """
    predictions = {}

    for watch_id in train_df['watch_id'].unique():
        is_base_year = (train_df['watch_id'] == watch_id) & (train_df['year'] == 10)
        base_year = train_df[is_base_year]
        monthly = {}

        for month in range(1, 13):
            baseline = base_year.loc[base_year['month'] == month, 'demand'].values[0]
            shifted = baseline + trends[watch_id][month]
            # Demand cannot be negative, so clamp after rounding.
            monthly[month] = max(0, round(shifted))

        predictions[watch_id] = monthly

    return predictions


def autoregressor_forecast(train_df: pd.DataFrame, lags: int = 12):
    """Strategy 3: per-watch linear autoregression on the previous ``lags`` months.

    A separate LinearRegression is fitted for every watch on sliding
    windows of its chronologically ordered demand. The 12 forecast months
    are then produced one step at a time, with each rounded, non-negative
    prediction appended to the window used for the next step.

    Returns ``(predictions, models)`` where ``predictions`` is
    ``{watch_id: {month: demand}}`` and ``models`` maps each watch id to
    its fitted regressor.
    """
    predictions = {}
    models = {}

    for watch_id in train_df['watch_id'].unique():
        ordered = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month'])
        series = ordered['demand'].values

        # Each target is paired with the `lags` observations preceding it.
        target_positions = range(lags, len(series))
        X_train = np.array([series[pos - lags:pos] for pos in target_positions])
        y_train = np.array([series[pos] for pos in target_positions])

        regressor = LinearRegression()
        regressor.fit(X_train, y_train)
        models[watch_id] = regressor

        # Recursive multi-step forecast: predictions feed later inputs.
        window = list(series[-lags:])
        monthly = {}

        for month in range(1, 13):
            features = np.array([window[-lags:]]).reshape(1, -1)
            estimate = max(0, round(regressor.predict(features)[0]))
            monthly[month] = estimate
            window.append(estimate)

        predictions[watch_id] = monthly

    return predictions, models


def prophet_forecast(train_df: pd.DataFrame):
    """Strategy 4: one Prophet model per watch with yearly seasonality.

    The demand history of each watch is placed on a synthetic
    month-start calendar beginning 2014-01-01, a Prophet model is fitted,
    and the last 12 rows of its forecast are taken as the predictions
    (rounded and floored at zero).

    Returns ``(predictions, models)`` where ``predictions`` is
    ``{watch_id: {month: demand}}`` and ``models`` maps each watch id to
    its fitted Prophet model.
    """
    predictions, models = {}, {}

    for watch_id in train_df['watch_id'].unique():
        ordered = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month'])

        # Prophet expects a frame with a 'ds' timestamp and a 'y' value column.
        timeline = pd.date_range(start='2014-01-01', periods=len(ordered), freq='MS')
        history = pd.DataFrame({'ds': timeline, 'y': ordered['demand'].values})

        settings = dict(
            yearly_seasonality=True,
            weekly_seasonality=False,
            daily_seasonality=False,
            changepoint_prior_scale=0.05,
            seasonality_prior_scale=10.0,
        )
        model = Prophet(**settings)
        model.fit(history)
        models[watch_id] = model

        # Extend the calendar by 12 month-starts and keep only those rows.
        horizon = model.make_future_dataframe(periods=12, freq='MS')
        next_year = model.predict(horizon).tail(12)['yhat'].values

        predictions[watch_id] = {
            month: max(0, round(value))
            for month, value in enumerate(next_year, start=1)
        }

    return predictions, models


def tirex_forecast(train_df: pd.DataFrame, model: ForecastModel):
    """Strategy 5: TiRex zero-shot neural forecasting.

    Args:
        train_df: Training data with 'watch_id', 'year', 'month' and
            'demand' columns.
        model: A loaded TiRex model exposing
            ``forecast(context=..., prediction_length=...)``.

    Returns:
        A nested dict ``{watch_id: {month: demand}}`` with months 1-12,
        each value rounded to an integer and floored at zero.
    """
    predictions = {}

    for watch_id in train_df['watch_id'].unique():
        watch_data = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month'])
        demands = watch_data['demand'].values

        # Shape the full history as a batch containing a single series.
        context = torch.tensor(demands, dtype=torch.float32).unsqueeze(0)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            quantiles, mean = model.forecast(context=context, prediction_length=12)

        # Move the result to host memory before converting: calling
        # .numpy() directly on a tensor that lives on a CUDA device raises.
        mean_preds = mean.squeeze().detach().cpu().numpy()

        predictions[watch_id] = {
            month_idx + 1: max(0, round(mean_preds[month_idx]))
            for month_idx in range(12)
        }

    return predictions


def ensemble_forecast(naive_pred: dict, trend_pred: dict, ar_pred: dict,
                      prophet_pred: dict, tirex_pred: dict):
    """Strategy 6: equal-weight average of the five individual forecasts.

    Every argument is a nested dict ``{watch_id: {month: demand}}``. The
    watch ids are taken from ``naive_pred``; the averaged value is rounded
    and floored at zero.

    Returns a nested dict of the same shape.
    """
    members = (naive_pred, trend_pred, ar_pred, prophet_pred, tirex_pred)
    predictions = {}

    for watch_id in naive_pred:
        blended = {}
        for month in range(1, 13):
            total = sum(member[watch_id][month] for member in members)
            blended[month] = max(0, round(total / 5.0))
        predictions[watch_id] = blended

    return predictions


def calculate_metrics(actual: np.ndarray, predicted: np.ndarray):
    """Compute the error metrics used to compare forecasting strategies.

    Args:
        actual: Observed demand values (array-like).
        predicted: Forecast demand values, aligned with ``actual``.

    Returns:
        A dict with the keys 'RMSE', 'MAE', 'MAPE' (in percent) and
        'Accuracy_10%' (share of forecasts, in percent, whose absolute
        error is at most 10% of the actual value).
    """
    # Accept plain sequences as well as arrays.
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)

    rmse = np.sqrt(mean_squared_error(actual, predicted))
    mae = mean_absolute_error(actual, predicted)
    mape = mean_absolute_percentage_error(actual, predicted) * 100

    # Compare |error| against 10% of |actual| instead of dividing by actual:
    # a zero actual no longer yields NaN/inf (an exact 0-vs-0 forecast now
    # counts as a hit), and a negative actual no longer passes automatically.
    within_tolerance = np.abs(actual - predicted) <= 0.10 * np.abs(actual)
    within_10pct = np.mean(within_tolerance) * 100

    return {
        'RMSE': rmse,
        'MAE': mae,
        'MAPE': mape,
        'Accuracy_10%': within_10pct
    }


def evaluate_and_compare(test_df: pd.DataFrame, naive_pred: dict, trend_pred: dict,
                         ar_pred: dict, prophet_pred: dict, tirex_pred: dict,
                         ensemble_pred: dict):
    """Score all six strategies against the test data and print the results.

    For each watch in ``test_df`` a table of RMSE, MAE, MAPE and
    within-10% accuracy is printed, followed by a table for all watches
    pooled together and a verdict based on how many pooled metrics each
    strategy wins. Ties within a metric go to the strategy listed first.

    Returns a 7-tuple: the per-watch metrics dict
    ``{watch_id: {strategy_key: metrics}}`` followed by the pooled metrics
    of the naive, trend, AR, Prophet, TiRex and ensemble strategies.
    """
    watch_names = {1: 'Luxury Classic', 2: 'Sport Pro', 3: 'Casual Style'}
    metric_names = ['RMSE', 'MAE', 'MAPE', 'Accuracy_10%']
    month_span = range(1, 13)

    # (key in the results dict, label in the Winner column, forecasts)
    strategies = [
        ('naive', 'Naive', naive_pred),
        ('trend', 'Trend', trend_pred),
        ('ar', 'AR', ar_pred),
        ('prophet', 'Prophet', prophet_pred),
        ('tirex', 'TiRex', tirex_pred),
        ('ensemble', 'Ensemble', ensemble_pred),
    ]

    banner = "=" * 145
    column_titles = ['Naive', 'Trend', 'AR(12)', 'Prophet', 'TiRex', 'Ensemble', 'Winner']
    header_line = f"  {'Metric':<15}" + "".join(f" | {title:>10}" for title in column_titles)
    rule_line = "  " + "-" * 15 + ("-+-" + "-" * 10) * len(column_titles)

    def print_metric_row(metric, metrics_by_key):
        """Print one table row and return the label of the winning strategy."""
        scored = [(label, metrics_by_key[key][metric]) for key, label, _ in strategies]
        if metric == 'Accuracy_10%':
            # Accuracy is the only metric where higher is better.
            winner = max(scored, key=lambda pair: pair[1])[0]
            cells = " | ".join(f"{value:9.1f}%" for _, value in scored)
        else:
            winner = min(scored, key=lambda pair: pair[1])[0]
            cells = " | ".join(f"{value:10.2f}" for _, value in scored)
        print(f"  {metric:<15} | {cells} | {winner:>10}")
        return winner

    print("\n" + banner)
    print("FORECASTING STRATEGY COMPARISON")
    print(banner)

    all_results = {}
    pooled_actual = []
    pooled_forecasts = {key: [] for key, _, _ in strategies}

    for watch_id in sorted(test_df['watch_id'].unique()):
        watch_test = test_df[test_df['watch_id'] == watch_id].sort_values('month')
        actual = watch_test['demand'].values

        forecasts = {
            key: np.array([source[watch_id][m] for m in month_span])
            for key, _, source in strategies
        }

        # Accumulate for the pooled (all watches) comparison below.
        pooled_actual.extend(actual)
        for key, values in forecasts.items():
            pooled_forecasts[key].extend(values)

        all_results[watch_id] = {
            key: calculate_metrics(actual, values)
            for key, values in forecasts.items()
        }

        print(f"\n{watch_names[watch_id]}:")
        print(header_line)
        print(rule_line)
        for metric in metric_names:
            print_metric_row(metric, all_results[watch_id])

    # Overall comparison
    pooled_actual = np.array(pooled_actual)
    pooled_arrays = {key: np.array(values) for key, values in pooled_forecasts.items()}
    overall = {
        key: calculate_metrics(pooled_actual, values)
        for key, values in pooled_arrays.items()
    }

    print("\n" + banner)
    print("OVERALL PERFORMANCE (All Watches Combined)")
    print(banner)
    print(header_line)
    print(rule_line)

    wins = {label: 0 for _, label, _ in strategies}
    for metric in metric_names:
        wins[print_metric_row(metric, overall)] += 1

    print(banner)

    # Summary verdict
    print("\nVERDICT:")
    max_wins = max(wins.values())
    winners = [label for label, count in wins.items() if count == max_wins]

    if len(winners) != 1:
        print(f"  ≈ Tie between: {', '.join(winners)} ({max_wins}/4 metrics each)")
    else:
        champion = winners[0]
        remarks = {
            'Ensemble': "    Averaging multiple models provides best robustness",
            'TiRex': "    Zero-shot neural forecasting with xLSTM delivers best results",
            'Prophet': "    Prophet's seasonality modeling delivers best results",
            'Trend': "    Simple year-over-year trends work surprisingly well",
            'AR': "    Autoregressive patterns capture the dynamics best",
        }
        print(f"{champion} wins ({wins[champion]}/4 metrics)")
        print(remarks.get(champion, "    Sometimes the simplest approach wins"))

    tally = "".join(f"{label}={wins[label]}  " for _, label, _ in strategies)
    print("\n  Win Count: " + tally)

    return (all_results, overall['naive'], overall['trend'], overall['ar'],
            overall['prophet'], overall['tirex'], overall['ensemble'])


def plot_comparison(test_df: pd.DataFrame, naive_pred: dict, trend_pred: dict,
                   ar_pred: dict, prophet_pred: dict, tirex_pred: dict,
                   ensemble_pred: dict):
    """Plot actual vs. forecast demand for every strategy, one panel per watch.

    Each panel title lists the RMSE of every strategy and names the one
    with the lowest RMSE. The figure is written to
    'forecast_comparison.png' and then shown.
    """
    watch_names = {1: 'Luxury Classic', 2: 'Sport Pro', 3: 'Casual Style'}
    months = range(1, 13)

    # (name used in the title, legend label, line format, colour, forecasts)
    dashed_series = [
        ('Naive', 'Naive', 's--', 'C0', naive_pred),
        ('Trend', 'Trend', '^--', 'C1', trend_pred),
        ('AR', 'AR(12)', 'd--', 'C2', ar_pred),
        ('Prophet', 'Prophet', 'v--', 'C3', prophet_pred),
        ('TiRex', 'TiRex', 'p--', 'C4', tirex_pred),
    ]

    fig, axes = plt.subplots(3, 1, figsize=(18, 14))

    for idx, watch_id in enumerate(sorted(test_df['watch_id'].unique())):
        ax = axes[idx]
        watch_test = test_df[test_df['watch_id'] == watch_id].sort_values('month')
        actual = watch_test['demand'].values

        curves = {
            name: np.array([source[watch_id][m] for m in months])
            for name, _, _, _, source in dashed_series
        }
        curves['Ensemble'] = np.array([ensemble_pred[watch_id][m] for m in months])

        # Ensemble goes in the background, the actual series on top.
        ax.plot(months, curves['Ensemble'], '-', label='Ensemble (Mean All)',
                linewidth=3, color='gray', alpha=0.5, zorder=2)
        for name, legend_label, line_format, colour, _ in dashed_series:
            ax.plot(months, curves[name], line_format, label=legend_label,
                    linewidth=1.8, markersize=6, alpha=0.7, color=colour, zorder=3)
        ax.plot(months, actual, 'o-', label='Actual',
                linewidth=3, markersize=10, color='black', zorder=5)

        # RMSE per strategy, in the order shown in the title.
        scores = [
            (name, np.sqrt(mean_squared_error(actual, values)))
            for name, values in curves.items()
        ]
        lowest = min(score for _, score in scores)
        # First strategy matching the minimum wins; Ensemble is the fallback.
        winner = next((name for name, score in scores[:-1] if score == lowest),
                      'Ensemble')

        summary = ', '.join(f'{name}={score:.1f}' for name, score in scores)
        ax.set_title(f'{watch_names[watch_id]} - Year 11 Forecast\n'
                     f'RMSE: {summary} (Winner: {winner})',
                     fontsize=10, fontweight='bold')
        ax.set_xlabel('Month', fontsize=9)
        ax.set_ylabel('Demand (units)', fontsize=9)
        ax.legend(loc='best', fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(months)

    plt.tight_layout()
    plt.savefig('forecast_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\nComparison plot saved as 'forecast_comparison.png'")


def plot_error_distribution(test_df: pd.DataFrame, naive_pred: dict, trend_pred: dict,
                           ar_pred: dict, prophet_pred: dict, tirex_pred: dict,
                           ensemble_pred: dict):
    """Draw grouped bars of the monthly forecast error for every strategy.

    One panel per watch; the error is forecast minus actual. The figure is
    written to 'error_distribution.png' and then shown.
    """
    watch_names = {1: 'Luxury Classic', 2: 'Sport Pro', 3: 'Casual Style'}

    # (legend label, colour, bar offset in units of the bar width, forecasts)
    bar_specs = [
        ('Naive', 'C0', -2.5, naive_pred),
        ('Trend', 'C1', -1.5, trend_pred),
        ('AR', 'C2', -0.5, ar_pred),
        ('Prophet', 'C3', 0.5, prophet_pred),
        ('TiRex', 'C4', 1.5, tirex_pred),
        ('Ensemble', 'gray', 2.5, ensemble_pred),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(20, 5))

    for idx, watch_id in enumerate(sorted(test_df['watch_id'].unique())):
        ax = axes[idx]
        watch_test = test_df[test_df['watch_id'] == watch_id].sort_values('month')
        actual = watch_test['demand'].values

        forecasts = [
            np.array([source[watch_id][m] for m in range(1, 13)])
            for _, _, _, source in bar_specs
        ]
        # Positive error means the strategy over-forecast.
        errors = [forecast - actual for forecast in forecasts]

        x = np.arange(12)
        width = 0.14

        for (label, colour, offset, _), error in zip(bar_specs, errors):
            ax.bar(x + offset * width, error, width, label=label, alpha=0.7, color=colour)
        ax.axhline(0, color='black', linestyle='-', linewidth=0.8)

        ax.set_title(f'{watch_names[watch_id]}', fontsize=11, fontweight='bold')
        ax.set_xlabel('Month', fontsize=9)
        ax.set_ylabel('Prediction Error (units)', fontsize=9)
        ax.legend(fontsize=7)
        ax.grid(True, alpha=0.3, axis='y')
        ax.set_xticks(x)
        ax.set_xticklabels(range(1, 13))

    plt.tight_layout()
    plt.savefig('error_distribution.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("Error distribution saved as 'error_distribution.png'")


def main():
    """Run the full forecasting comparison end to end.

    Loads the training and test data, loads the TiRex model, generates the
    six forecasts, prints the evaluation tables and produces the two
    figures.

    Returns:
        A 7-tuple of the naive, trend, AR, Prophet, TiRex and ensemble
        predictions followed by the per-watch metrics dict.
    """
    print("=" * 145)
    print("FORECASTING STRATEGY COMPARISON")
    print("Strategy 1 (Naive):    Use Year 10 demand as-is")
    print("Strategy 2 (Trend):    Year 10 + average monthly year-over-year growth")
    print("Strategy 3 (AR12):     Autoregressive model using past 12 months")
    print("Strategy 4 (Prophet):  Facebook's forecasting with automatic seasonality")
    print("Strategy 5 (TiRex):    Zero-shot neural forecasting (xLSTM-based)")
    print("Strategy 6 (Ensemble): Simple average of all five approaches")
    print("=" * 145)

    # Load data
    print("\n1. Loading data...")
    train_df, test_df = load_data(
        '../data/supply_chain_data_training.json',
        '../data/supply_chain_data_test.json'
    )
    print(f"   Training: {len(train_df)} records (10 years, 3 watches)")
    print(f"   Test:     {len(test_df)} records (year 11, 3 watches)")

    # Load TiRex model
    print("\n2. Loading TiRex model...")
    print("   Note: First time may take a while to download model (~35M parameters)")
    # Load the model exactly once, with the backend named explicitly. A
    # second load without a backend would only be thrown away and would
    # depend on the library's default backend being usable on this machine.
    tirex_model = load_model("NX-AI/TiRex", backend="torch")
    print("   TiRex model loaded successfully")

    # Generate predictions
    print("\n3. Generating forecasts...")
    print("   - Naive forecast (Year 10 as-is)")
    naive_predictions = naive_forecast(train_df)

    print("   - Trend-based forecast (Year 10 + trends)")
    trends = calculate_monthly_trends(train_df)
    trend_predictions = trend_forecast(train_df, trends)

    print("   - Autoregressive forecast (AR with 12 lags)")
    ar_predictions, ar_models = autoregressor_forecast(train_df, lags=12)

    print("   - Prophet forecast (with seasonality)")
    prophet_predictions, prophet_models = prophet_forecast(train_df)

    print("   - TiRex forecast (zero-shot neural)")
    tirex_predictions = tirex_forecast(train_df, tirex_model)

    print("   - Ensemble forecast (averaging all methods)")
    ensemble_predictions = ensemble_forecast(naive_predictions, trend_predictions,
                                            ar_predictions, prophet_predictions,
                                            tirex_predictions)

    # Compare
    print("\n4. Evaluating and comparing strategies...")
    # Only the per-watch results are returned; the pooled metrics are
    # collected into overall_metrics and not used further here.
    results, *overall_metrics = evaluate_and_compare(
        test_df, naive_predictions, trend_predictions, ar_predictions,
        prophet_predictions, tirex_predictions, ensemble_predictions
    )

    # Visualize
    print("\n5. Creating visualizations...")
    plot_comparison(test_df, naive_predictions, trend_predictions,
                   ar_predictions, prophet_predictions, tirex_predictions,
                   ensemble_predictions)
    plot_error_distribution(test_df, naive_predictions, trend_predictions,
                           ar_predictions, prophet_predictions, tirex_predictions,
                           ensemble_predictions)

    print("\n" + "=" * 145)
    print("Analysis Complete!")
    print("=" * 145)

    return (naive_predictions, trend_predictions, ar_predictions,
            prophet_predictions, tirex_predictions, ensemble_predictions, results)


# Entry point: run the full comparison and bind the six forecast dicts and
# the per-watch metrics to top-level names so they stay available afterwards.
if __name__ == "__main__":
    (naive_pred, trend_pred, ar_pred, prophet_pred,
     tirex_pred, ensemble_pred, results) = main()
```

%% Output

    /home/thibault/horloml/aa_venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
      from .autonotebook import tqdm as notebook_tqdm
    Importing plotly failed. Interactive plots will not work.

    =================================================================================================================================================
    FORECASTING STRATEGY COMPARISON
    Strategy 1 (Naive):    Use Year 10 demand as-is
    Strategy 2 (Trend):    Year 10 + average monthly year-over-year growth
    Strategy 3 (AR12):     Autoregressive model using past 12 months
    Strategy 4 (Prophet):  Facebook's forecasting with automatic seasonality
    Strategy 5 (TiRex):    Zero-shot neural forecasting (xLSTM-based)
    Strategy 6 (Ensemble): Simple average of all five approaches
    =================================================================================================================================================
    
    1. Loading data...
       Training: 360 records (10 years, 3 watches)
       Test:     36 records (year 11, 3 watches)
    
    2. Loading TiRex model...
       Note: First time may take a while to download model (~35M parameters)

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Cell In[1], line 624
        618     return (naive_predictions, trend_predictions, ar_predictions,
        619             prophet_predictions, tirex_predictions, ensemble_predictions, results)
        622 if __name__ == "__main__":
        623     (naive_pred, trend_pred, ar_pred, prophet_pred,
    --> 624      tirex_pred, ensemble_pred, results) = main()
    Cell In[1], line 572, in main()
        570 print("\n2. Loading TiRex model...")
        571 print("   Note: First time may take a while to download model (~35M parameters)")
    --> 572 tirex_model = load_model("NX-AI/TiRex", backend="cpu")
        573 print("   TiRex model loaded successfully")
        575 # Generate predictions
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/base.py:130, in load_model(path, device, backend, compile, hf_kwargs, ckp_kwargs)
        127 if model_cls is None:
        128     raise ValueError(f"Invalid model id {model_id}")
    --> 130 return model_cls.from_pretrained(
        131     path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
        132 )
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/base.py:73, in PretrainedModel.from_pretrained(cls, path, backend, device, compile, hf_kwargs, ckp_kwargs)
         71 # load lightning checkpoint
         72 checkpoint = torch.load(checkpoint_path, map_location=device, **ckp_kwargs, weights_only=True)
    ---> 73 model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
         74 model.on_load_checkpoint(checkpoint)
         75 model.load_state_dict(checkpoint["state_dict"])
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/models/tirex.py:47, in TiRexZero.__init__(self, backend, model_config, train_ctx_len)
         40 block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
         41 self.input_patch_embedding = ResidualBlock(
         42     in_dim=self.config.input_patch_size * 2,
         43     h_dim=self.config.input_ff_dim,
         44     out_dim=block_config.embedding_dim,
         45 )
    ---> 47 self.blocks = nn.ModuleList([sLSTMBlock(block_config, backend) for i in range(num_blocks)])
         49 self.out_norm = RMSNorm(block_config.embedding_dim)
         51 self.output_patch_embedding = ResidualBlock(
         52     in_dim=block_config.embedding_dim,
         53     h_dim=self.config.input_ff_dim,
         54     out_dim=len(self.config.quantiles) * self.config.output_patch_size,
         55 )
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/models/slstm/block.py:17, in sLSTMBlock.__init__(self, config, backend)
         15 self.config = config
         16 self.norm_slstm = RMSNorm(config.embedding_dim)
    ---> 17 self.slstm_layer = sLSTMLayer(config, backend)
         18 self.norm_ffn = RMSNorm(config.embedding_dim)
         20 up_proj_dim = round_up_to_next_multiple_of(config.embedding_dim * config.ffn_proj_factor, 64)
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/models/slstm/layer.py:22, in sLSTMLayer.__init__(self, config, backend)
         19 self.zgate = LinearHeadwiseExpand(in_features, num_heads)
         20 self.ogate = LinearHeadwiseExpand(in_features, num_heads)
    ---> 22 self.slstm_cell = sLSTMCell(self.config, backend)
         23 self.group_norm = MultiHeadLayerNorm(ndim=in_features, num_heads=num_heads)
    File ~/horloml/aa_venv/lib/python3.12/site-packages/tirex/models/slstm/cell.py:31, in sLSTMCell.__init__(self, config, backend)
         29 def __init__(self, config: sLSTMBlockConfig, backend: Literal["torch", "cuda"]):
         30     super().__init__()
    ---> 31     assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
         32     self.config = config
         33     self.backend = backend
    AssertionError: Backend can either be torch or cuda, not cpu!