Unverified Commit 968585cc authored by utorque's avatar utorque Committed by GitHub
Browse files

Merge pull request #10 from utorque/claude/tirex-neural-forecasting-01BcwXd3oLCprUTzmRvNMuV1

Implement TiRex zero-shot neural forecasting
parents 14635828 5463ff58
Loading
Loading
Loading
Loading
+79 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ from datetime import datetime
from io import BytesIO
import pandas as pd
import numpy as np
import torch

try:
    from openpyxl import Workbook
@@ -20,6 +21,12 @@ try:
except ImportError:
    EXCEL_AVAILABLE = False

try:
    from tirex import load_model, ForecastModel
    TIREX_AVAILABLE = True
except ImportError:
    TIREX_AVAILABLE = False

app = Flask(__name__)
app.secret_key = 'supply-chain-forecast-secret-key-2024'
app.config['SESSION_TYPE'] = 'filesystem'
@@ -29,6 +36,17 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
TRAINING_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_training.json')
TEST_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_test.json')

# TiRex model (loaded lazily on first request)
_tirex_model = None


def get_tirex_model():
    """Load TiRex model lazily (first time only)"""
    global _tirex_model
    if _tirex_model is None and TIREX_AVAILABLE:
        _tirex_model = load_model("NX-AI/TiRex", backend="torch")
    return _tirex_model


def load_training_data():
    """Load the 10-year training dataset"""
@@ -572,6 +590,67 @@ def results():
                          test_data=test_data)


@app.route('/api/tirex_forecast', methods=['GET'])
def tirex_forecast():
    """Generate AI predictions using TiRex model"""
    if not TIREX_AVAILABLE:
        return jsonify({'error': 'TiRex library not available. Please install: pip install tirex-ts torch'}), 500

    try:
        # Load TiRex model (lazy loading)
        model = get_tirex_model()
        if model is None:
            return jsonify({'error': 'Failed to load TiRex model'}), 500

        # Load training data
        training_data = load_training_data()

        # Convert to DataFrame format for processing
        train_rows = []
        for month_data in training_data['historical_data']:
            for watch_data in month_data['watches']:
                train_rows.append({
                    'year': month_data['year'],
                    'month': month_data['month'],
                    'watch_id': watch_data['watch_id'],
                    'demand': watch_data['demand']
                })

        train_df = pd.DataFrame(train_rows)

        # Generate predictions for each watch
        predictions = {}

        for watch_id in train_df['watch_id'].unique():
            watch_data = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month'])
            demands = watch_data['demand'].values

            # Convert to torch tensor (1 series, full history)
            context = torch.tensor(demands, dtype=torch.float32).unsqueeze(0)

            # Forecast 12 months ahead
            with torch.no_grad():
                quantiles, mean = model.forecast(context=context, prediction_length=12)

            # Extract mean predictions
            mean_preds = mean.squeeze().numpy()

            predictions[str(watch_id)] = {}
            for month_idx in range(12):
                pred = max(0, round(float(mean_preds[month_idx])))
                predictions[str(watch_id)][month_idx + 1] = pred

        return jsonify({
            'success': True,
            'predictions': predictions
        })

    except Exception as e:
        return jsonify({
            'error': f'Failed to generate predictions: {str(e)}'
        }), 500


@app.route('/reset')
def reset():
    """Clear session and start over"""
+3 −0
Original line number Diff line number Diff line
@@ -2,3 +2,6 @@ Flask==3.0.0
numpy==2.3.4
Flask-Session==0.5.0
openpyxl==3.1.2
pandas==2.2.0
torch==2.5.1
tirex-ts
+68 −0
Original line number Diff line number Diff line
@@ -84,6 +84,12 @@
                            <span>Fill with Last Year</span>
                        </button>
                    </div>
                    <div class="control">
                        <button type="button" class="button is-primary" onclick="fillWithAI()" id="aiButton">
                            <span class="icon"><i class="fas fa-brain"></i></span>
                            <span>Fill with AI Model</span>
                        </button>
                    </div>
                </div>

                <hr>
@@ -98,6 +104,14 @@
                </div>
            </div>
        </form>

        <!-- About TiRex AI Model Link -->
        <div class="has-text-centered" style="margin-top: 2rem; margin-bottom: 1rem;">
            <a href="https://github.com/NX-AI/tirex" target="_blank" class="button is-small is-light">
                <span class="icon"><i class="fas fa-info-circle"></i></span>
                <span>About TiRex AI Model</span>
            </a>
        </div>
    </div>
</section>

@@ -185,6 +199,60 @@ function fillFromYear10() {
    });
}

async function fillWithAI() {
    const button = document.getElementById('aiButton');
    const originalText = button.innerHTML;

    // Disable button and show loading state
    button.disabled = true;
    button.innerHTML = '<span class="icon"><i class="fas fa-spinner fa-pulse"></i></span><span>Loading AI Predictions...</span>';

    try {
        const response = await fetch('/api/tirex_forecast');
        const data = await response.json();

        if (!response.ok) {
            throw new Error(data.error || 'Failed to fetch predictions');
        }

        if (data.success && data.predictions) {
            // Fill in the predictions
            watches.forEach(watch => {
                const inputs = getInputs(watch.id);
                const watchPredictions = data.predictions[watch.id.toString()];

                if (watchPredictions) {
                    for (let month = 1; month <= 12; month++) {
                        if (inputs[month - 1] && watchPredictions[month] !== undefined) {
                            inputs[month - 1].value = watchPredictions[month];
                        }
                    }
                }
            });

            // Show success message
            button.innerHTML = '<span class="icon"><i class="fas fa-check"></i></span><span>AI Predictions Loaded!</span>';
            button.classList.remove('is-primary');
            button.classList.add('is-success');

            setTimeout(() => {
                button.innerHTML = originalText;
                button.classList.remove('is-success');
                button.classList.add('is-primary');
                button.disabled = false;
            }, 2000);
        } else {
            throw new Error('Invalid response format');
        }
    } catch (error) {
        console.error('Error fetching AI predictions:', error);
        alert('Failed to load AI predictions: ' + error.message + '\n\nPlease ensure TiRex is installed: pip install tirex torch');

        button.innerHTML = originalText;
        button.disabled = false;
    }
}

// Add paste functionality for each watch column
watches.forEach((watch, watchIdx) => {
    const columnHeader = document.querySelector(`.watch-column-${watch.id}`);