Loading forecast_app/app.py +79 −0 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ from datetime import datetime from io import BytesIO import pandas as pd import numpy as np import torch try: from openpyxl import Workbook Loading @@ -20,6 +21,12 @@ try: except ImportError: EXCEL_AVAILABLE = False try: from tirex import load_model, ForecastModel TIREX_AVAILABLE = True except ImportError: TIREX_AVAILABLE = False app = Flask(__name__) app.secret_key = 'supply-chain-forecast-secret-key-2024' app.config['SESSION_TYPE'] = 'filesystem' Loading @@ -29,6 +36,17 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') TRAINING_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_training.json') TEST_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_test.json') # TiRex model (loaded lazily on first request) _tirex_model = None def get_tirex_model(): """Load TiRex model lazily (first time only)""" global _tirex_model if _tirex_model is None and TIREX_AVAILABLE: _tirex_model = load_model("NX-AI/TiRex", backend="torch") return _tirex_model def load_training_data(): """Load the 10-year training dataset""" Loading Loading @@ -572,6 +590,67 @@ def results(): test_data=test_data) @app.route('/api/tirex_forecast', methods=['GET']) def tirex_forecast(): """Generate AI predictions using TiRex model""" if not TIREX_AVAILABLE: return jsonify({'error': 'TiRex library not available. Please install: pip install tirex-ts torch'}), 500 try: # Load TiRex model (lazy loading) model = get_tirex_model() if model is None: return jsonify({'error': 'Failed to load TiRex model'}), 500 # Load training data training_data = load_training_data() # Convert to DataFrame format for processing train_rows = [] for month_data in training_data['historical_data']: for watch_data in month_data['watches']: train_rows.append({ 'year': month_data['year'], 'month': month_data['month'], 'watch_id': watch_data['watch_id'], 'demand': watch_data['demand'] }) train_df = pd.DataFrame(train_rows) # Generate predictions for each watch predictions = {} for watch_id in train_df['watch_id'].unique(): watch_data = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month']) demands = watch_data['demand'].values # Convert to torch tensor (1 series, full history) context = torch.tensor(demands, dtype=torch.float32).unsqueeze(0) # Forecast 12 months ahead with torch.no_grad(): quantiles, mean = model.forecast(context=context, prediction_length=12) # Extract mean predictions mean_preds = mean.squeeze().numpy() predictions[str(watch_id)] = {} for month_idx in range(12): pred = max(0, round(float(mean_preds[month_idx]))) predictions[str(watch_id)][month_idx + 1] = pred return jsonify({ 'success': True, 'predictions': predictions }) except Exception as e: return jsonify({ 'error': f'Failed to generate predictions: {str(e)}' }), 500 @app.route('/reset') def reset(): """Clear session and start over""" Loading forecast_app/requirements.txt +3 −0 Original line number Diff line number Diff line Loading @@ -2,3 +2,6 @@ Flask==3.0.0 numpy==2.3.4 Flask-Session==0.5.0 openpyxl==3.1.2 pandas==2.2.0 torch==2.5.1 tirex-ts forecast_app/templates/predict.html +68 −0 Original line number Diff line number Diff line Loading @@ -84,6 +84,12 @@ <span>Fill with Last Year</span> </button> </div> <div class="control"> <button type="button" class="button is-primary" onclick="fillWithAI()" id="aiButton"> <span class="icon"><i class="fas fa-brain"></i></span> <span>Fill with AI Model</span> </button> </div> </div> <hr> Loading @@ -98,6 +104,14 @@ </div> </div> </form> <!-- About TiRex AI Model Link --> <div class="has-text-centered" style="margin-top: 2rem; margin-bottom: 1rem;"> <a href="https://github.com/NX-AI/tirex" target="_blank" class="button is-small is-light"> <span class="icon"><i class="fas fa-info-circle"></i></span> <span>About TiRex AI Model</span> </a> </div> </div> </section> Loading Loading @@ -185,6 +199,60 @@ function fillFromYear10() { }); } async function fillWithAI() { const button = document.getElementById('aiButton'); const originalText = button.innerHTML; // Disable button and show loading state button.disabled = true; button.innerHTML = '<span class="icon"><i class="fas fa-spinner fa-pulse"></i></span><span>Loading AI Predictions...</span>'; try { const response = await fetch('/api/tirex_forecast'); const data = await response.json(); if (!response.ok) { throw new Error(data.error || 'Failed to fetch predictions'); } if (data.success && data.predictions) { // Fill in the predictions watches.forEach(watch => { const inputs = getInputs(watch.id); const watchPredictions = data.predictions[watch.id.toString()]; if (watchPredictions) { for (let month = 1; month <= 12; month++) { if (inputs[month - 1] && watchPredictions[month] !== undefined) { inputs[month - 1].value = watchPredictions[month]; } } } }); // Show success message button.innerHTML = '<span class="icon"><i class="fas fa-check"></i></span><span>AI Predictions Loaded!</span>'; button.classList.remove('is-primary'); button.classList.add('is-success'); setTimeout(() => { button.innerHTML = originalText; button.classList.remove('is-success'); button.classList.add('is-primary'); button.disabled = false; }, 2000); } else { throw new Error('Invalid response format'); } } catch (error) { console.error('Error fetching AI predictions:', error); alert('Failed to load AI predictions: ' + error.message + '\n\nPlease ensure TiRex is installed: pip install tirex torch'); button.innerHTML = originalText; button.disabled = false; } } // Add paste functionality for each watch column watches.forEach((watch, watchIdx) => { const columnHeader = document.querySelector(`.watch-column-${watch.id}`); Loading Loading
forecast_app/app.py +79 −0 Original line number Diff line number Diff line Loading @@ -12,6 +12,7 @@ from datetime import datetime from io import BytesIO import pandas as pd import numpy as np import torch try: from openpyxl import Workbook Loading @@ -20,6 +21,12 @@ try: except ImportError: EXCEL_AVAILABLE = False try: from tirex import load_model, ForecastModel TIREX_AVAILABLE = True except ImportError: TIREX_AVAILABLE = False app = Flask(__name__) app.secret_key = 'supply-chain-forecast-secret-key-2024' app.config['SESSION_TYPE'] = 'filesystem' Loading @@ -29,6 +36,17 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') TRAINING_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_training.json') TEST_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_test.json') # TiRex model (loaded lazily on first request) _tirex_model = None def get_tirex_model(): """Load TiRex model lazily (first time only)""" global _tirex_model if _tirex_model is None and TIREX_AVAILABLE: _tirex_model = load_model("NX-AI/TiRex", backend="torch") return _tirex_model def load_training_data(): """Load the 10-year training dataset""" Loading Loading @@ -572,6 +590,67 @@ def results(): test_data=test_data) @app.route('/api/tirex_forecast', methods=['GET']) def tirex_forecast(): """Generate AI predictions using TiRex model""" if not TIREX_AVAILABLE: return jsonify({'error': 'TiRex library not available. Please install: pip install tirex-ts torch'}), 500 try: # Load TiRex model (lazy loading) model = get_tirex_model() if model is None: return jsonify({'error': 'Failed to load TiRex model'}), 500 # Load training data training_data = load_training_data() # Convert to DataFrame format for processing train_rows = [] for month_data in training_data['historical_data']: for watch_data in month_data['watches']: train_rows.append({ 'year': month_data['year'], 'month': month_data['month'], 'watch_id': watch_data['watch_id'], 'demand': watch_data['demand'] }) train_df = pd.DataFrame(train_rows) # Generate predictions for each watch predictions = {} for watch_id in train_df['watch_id'].unique(): watch_data = train_df[train_df['watch_id'] == watch_id].sort_values(['year', 'month']) demands = watch_data['demand'].values # Convert to torch tensor (1 series, full history) context = torch.tensor(demands, dtype=torch.float32).unsqueeze(0) # Forecast 12 months ahead with torch.no_grad(): quantiles, mean = model.forecast(context=context, prediction_length=12) # Extract mean predictions mean_preds = mean.squeeze().numpy() predictions[str(watch_id)] = {} for month_idx in range(12): pred = max(0, round(float(mean_preds[month_idx]))) predictions[str(watch_id)][month_idx + 1] = pred return jsonify({ 'success': True, 'predictions': predictions }) except Exception as e: return jsonify({ 'error': f'Failed to generate predictions: {str(e)}' }), 500 @app.route('/reset') def reset(): """Clear session and start over""" Loading
forecast_app/requirements.txt +3 −0 Original line number Diff line number Diff line Loading @@ -2,3 +2,6 @@ Flask==3.0.0 numpy==2.3.4 Flask-Session==0.5.0 openpyxl==3.1.2 pandas==2.2.0 torch==2.5.1 tirex-ts
forecast_app/templates/predict.html +68 −0 Original line number Diff line number Diff line Loading @@ -84,6 +84,12 @@ <span>Fill with Last Year</span> </button> </div> <div class="control"> <button type="button" class="button is-primary" onclick="fillWithAI()" id="aiButton"> <span class="icon"><i class="fas fa-brain"></i></span> <span>Fill with AI Model</span> </button> </div> </div> <hr> Loading @@ -98,6 +104,14 @@ </div> </div> </form> <!-- About TiRex AI Model Link --> <div class="has-text-centered" style="margin-top: 2rem; margin-bottom: 1rem;"> <a href="https://github.com/NX-AI/tirex" target="_blank" class="button is-small is-light"> <span class="icon"><i class="fas fa-info-circle"></i></span> <span>About TiRex AI Model</span> </a> </div> </div> </section> Loading Loading @@ -185,6 +199,60 @@ function fillFromYear10() { }); } async function fillWithAI() { const button = document.getElementById('aiButton'); const originalText = button.innerHTML; // Disable button and show loading state button.disabled = true; button.innerHTML = '<span class="icon"><i class="fas fa-spinner fa-pulse"></i></span><span>Loading AI Predictions...</span>'; try { const response = await fetch('/api/tirex_forecast'); const data = await response.json(); if (!response.ok) { throw new Error(data.error || 'Failed to fetch predictions'); } if (data.success && data.predictions) { // Fill in the predictions watches.forEach(watch => { const inputs = getInputs(watch.id); const watchPredictions = data.predictions[watch.id.toString()]; if (watchPredictions) { for (let month = 1; month <= 12; month++) { if (inputs[month - 1] && watchPredictions[month] !== undefined) { inputs[month - 1].value = watchPredictions[month]; } } } }); // Show success message button.innerHTML = '<span class="icon"><i class="fas fa-check"></i></span><span>AI Predictions Loaded!</span>'; button.classList.remove('is-primary'); button.classList.add('is-success'); setTimeout(() => { button.innerHTML = originalText; button.classList.remove('is-success'); button.classList.add('is-primary'); button.disabled = false; }, 2000); } else { throw new Error('Invalid response format'); } } catch (error) { console.error('Error fetching AI predictions:', error); alert('Failed to load AI predictions: ' + error.message + '\n\nPlease ensure TiRex is installed: pip install tirex torch'); button.innerHTML = originalText; button.disabled = false; } } // Add paste functionality for each watch column watches.forEach((watch, watchIdx) => { const columnHeader = document.querySelector(`.watch-column-${watch.id}`); Loading