Unverified Commit 6af7af3f authored by utorque's avatar utorque Committed by GitHub
Browse files

Merge pull request #12 from utorque/claude/add-disrupted-data-method-01HLMfF2mQgxyvpFv42KQ86t

Add method to generate disrupted test data
parents 75e02d06 6bbb63be
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
/HorloML_EDU_Main/web_app/flask_session
/supply_chain_sim.egg-info
/forecast_app/data
/data
/venv
/__pycache__
*.egg
+67 −10
Original line number Diff line number Diff line
@@ -34,13 +34,15 @@ app.config['SESSION_TYPE'] = 'filesystem'
# Load dataset
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
TRAINING_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_training.json')
TEST_DATA_PATH = os.path.join(DATA_DIR, 'sim2_supply_chain_data_test.json')
TEST_DATA_PATH_NORMAL = os.path.join(DATA_DIR, 'sim2_supply_chain_data_test.json')
TEST_DATA_PATH_EXTERNAL = os.path.join(DATA_DIR, 'sim2_data_test_external_events.json')

# TiRex model (loaded lazily on first request)
_tirex_model = None

# AI activation state (disabled by default)
_ai_enabled = False
# Teacher parameters
_ai_enabled = True
_simulation_type = 'normal'  # 'normal' or 'external_event'
AI_PASSWORD = os.environ.get('AI_PASSWORD', 'HorloML-AI')


@@ -59,8 +61,23 @@ def load_training_data():


def load_test_data():
    """Load the year 11 test data (ground truth)"""
    with open(TEST_DATA_PATH, 'r') as f:
    """Load the year 11 test data (ground truth) based on simulation type"""
    global _simulation_type

    if _simulation_type == 'external_event':
        # Load external events data
        if os.path.exists(TEST_DATA_PATH_EXTERNAL):
            with open(TEST_DATA_PATH_EXTERNAL, 'r') as f:
                data = json.load(f)
                # External events file has different structure: {'metadata': ..., 'test_data': [...]}
                return data.get('test_data', data)
        else:
            # Fallback to normal if external file doesn't exist
            with open(TEST_DATA_PATH_NORMAL, 'r') as f:
                return json.load(f)
    else:
        # Load normal test data
        with open(TEST_DATA_PATH_NORMAL, 'r') as f:
            return json.load(f)


@@ -407,7 +424,7 @@ def download_excel():
    headers = ['Year', 'Month', 'Date']
    for watch in watches:
        headers.extend([
            f"{watch['name']} - Demand",
            f"{watch['name']} - Units Sold",
            f"{watch['name']} - Revenue",
            f"{watch['name']} - Profit"
        ])
@@ -445,7 +462,7 @@ def download_excel():
    yearly_headers = ['Year']
    for watch in watches:
        yearly_headers.extend([
            f"{watch['name']} - Total Demand",
            f"{watch['name']} - Total Units Sold",
            f"{watch['name']} - Total Revenue",
            f"{watch['name']} - Total Profit"
        ])
@@ -594,16 +611,56 @@ def results():
                          test_data=test_data)


@app.route('/api/teacher_params', methods=['GET'])
def get_teacher_params():
    """Get current teacher parameters (AI status and simulation type)"""
    global _ai_enabled, _simulation_type
    return jsonify({
        'ai_enabled': _ai_enabled,
        'simulation_type': _simulation_type
    })


@app.route('/api/teacher_params', methods=['POST'])
def update_teacher_params():
    """Update teacher parameters with password verification"""
    global _ai_enabled, _simulation_type

    data = request.get_json()
    password = data.get('password', '')

    if password != AI_PASSWORD:
        return jsonify({'success': False, 'error': 'Invalid password'}), 401

    # Update parameters if provided
    if 'ai_enabled' in data:
        _ai_enabled = data['ai_enabled']

    if 'simulation_type' in data:
        sim_type = data['simulation_type']
        if sim_type in ['normal', 'external_event']:
            _simulation_type = sim_type
        else:
            return jsonify({'success': False, 'error': 'Invalid simulation type'}), 400

    return jsonify({
        'success': True,
        'ai_enabled': _ai_enabled,
        'simulation_type': _simulation_type
    })


# Legacy endpoints for backward compatibility
@app.route('/api/ai_status', methods=['GET'])
def ai_status():
    """Get current AI activation status"""
    """Get current AI activation status (legacy endpoint)"""
    global _ai_enabled
    return jsonify({'enabled': _ai_enabled})


@app.route('/api/toggle_ai', methods=['POST'])
def toggle_ai():
    """Toggle AI activation with password verification"""
    """Toggle AI activation with password verification (legacy endpoint)"""
    global _ai_enabled

    data = request.get_json()
+127 −1
Original line number Diff line number Diff line
@@ -389,7 +389,8 @@ class SmartSupplyChainDataGenerator:
                demand = purchases[watch_id]

                # Production strategy: produce based on demand + safety stock
                production = int(demand * 1.05) # TODO : here maybe do a random value between .95 and 1.05 for noise ?
                # production = int(demand * 1.05) # TODO : here maybe do a random value between .95 and 1.05 for noise ?
                production = int(demand * np.random.uniform(0.95, 1.05))

                # Calculate financials
                watch_data = self._calculate_costs_and_revenue(
@@ -439,6 +440,113 @@ class SmartSupplyChainDataGenerator:

        return dataset['historical_data'][start_idx:end_idx]

    def generate_changed_test_data(self, dataset: Dict, disruption_config: List[Dict],
                                   test_year: int = 11, filepath: str = 'data/data_test_external_events.json'):
        """
        Generate test data with external disruption events applied

        Args:
            dataset: Full dataset to extract test data from
            disruption_config: List of disruption events, each with:
                - month_to_change: Month number (1-12) to apply disruption
                - demand_ratio: Multiplier for demand (e.g., 0.5 halves demand, 2.0 doubles it)
                - watch_ids: Optional list of watch IDs to affect (default: all watches)
            test_year: Which year to use as test data (default: 11)
            filepath: Where to save the modified test data

        Returns:
            Modified test data with disruptions applied

        Example:
            disruption_config = [
                {'month_to_change': 3, 'demand_ratio': 0.5},  # Halve demand in March
                {'month_to_change': 6, 'demand_ratio': 0.3, 'watch_ids': [1, 2]}  # 70% drop in June for watches 1&2
            ]
        """
        # Get base test data
        test_data = self.get_test_data(dataset, test_year)

        # Create a deep copy to modify
        import copy
        modified_test_data = copy.deepcopy(test_data)

        # Track inventory for recalculating financials
        # Get ending inventory from last training month
        last_training_idx = (test_year - 1) * 12 - 1
        inventory = {}
        for watch in self.watches:
            watch_id = watch['id']
            # Find the watch data in the last training month
            last_month_data = dataset['historical_data'][last_training_idx]
            watch_data = next(w for w in last_month_data['watches'] if w['watch_id'] == watch_id)
            inventory[watch_id] = watch_data['inventory_end']

        # Apply disruptions month by month
        for month_idx, month_data in enumerate(modified_test_data):
            month_in_year = month_data['month']

            # Check if any disruptions apply to this month
            disruptions_this_month = [
                d for d in disruption_config
                if d['month_to_change'] == month_in_year
            ]

            # Process each watch
            for watch_idx, watch in enumerate(self.watches):
                watch_id = watch['id']
                watch_data = month_data['watches'][watch_idx]

                # Get current demand
                original_demand = watch_data['demand']
                modified_demand = original_demand

                # Apply all applicable disruptions for this watch
                for disruption in disruptions_this_month:
                    # Check if this disruption affects this watch
                    affected_watches = disruption.get('watch_ids', None)
                    if affected_watches is None or watch_id in affected_watches:
                        modified_demand = int(modified_demand * disruption['demand_ratio'])

                # If demand changed, recalculate all financials
                if modified_demand != original_demand:
                    production = watch_data['production']  # Keep same production decision

                    # Recalculate financials with new demand
                    new_financials = self._calculate_costs_and_revenue(
                        watch, modified_demand, production, inventory[watch_id]
                    )

                    # Update watch data with new values
                    watch_data.update(new_financials)
                    watch_data['watch_id'] = watch['id']
                    watch_data['watch_name'] = watch['name']
                    watch_data['original_demand'] = original_demand  # Keep track of original
                    watch_data['disruption_applied'] = True

                # Update inventory for next month
                inventory[watch_id] = watch_data['inventory_end']

        # Add metadata about disruptions
        output_data = {
            'metadata': {
                'generated_date': datetime.now().isoformat(),
                'generator_type': 'smart_customer_simulation_with_disruptions',
                'base_year': test_year,
                'disruptions': disruption_config,
                'watches': self.watches
            },
            'test_data': modified_test_data
        }

        # Save to file
        import os
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, 'w') as f:
            json.dump(output_data, f, indent=2)
        print(f"Modified test data with disruptions saved to {filepath}")

        return output_data


def main():
    """Generate and save the dataset"""
@@ -473,6 +581,22 @@ def main():
    with open('data/sim2_supply_chain_data_test.json', 'w') as f:
        json.dump(test_data, f, indent=2)

    # Generate test data with external disruptions (example)
    print("\n" + "-" * 60)
    print("Generating test data with external disruptions...")
    print("-" * 60)

    disruption_config = [
        {'month_to_change': 3, 'demand_ratio': 0.4},  # Halve demand in March (supply chain crisis)
        {'month_to_change': 4, 'demand_ratio': 0.4},  # 50% increase in July (recovery surge)
    ]

    disrupted_data = generator.generate_changed_test_data(
        full_dataset,
        disruption_config,
        filepath='data/sim2_data_test_external_events.json'
    )

    print("\n" + "=" * 60)
    print("Dataset Generation Complete!")
    print("=" * 60)
@@ -534,3 +658,5 @@ def main():

if __name__ == "__main__":
    main()

    
+208 −55

File changed.

Preview size limit exceeded, changes collapsed.

+3 −3
Original line number Diff line number Diff line
@@ -89,7 +89,7 @@
            {% for watch in watches %}
            <div class="column is-4">
                <div class="box">
                    <h4 class="title is-5">{{ watch.name }} - Monthly Demand</h4>
                    <h4 class="title is-5">{{ watch.name }} - Monthly Units</h4>
                    <div class="chart-container" style="height: 300px;">
                        <canvas id="chart_{{ watch.id }}"></canvas>
                    </div>
@@ -115,7 +115,7 @@
                            <th></th>
                            <th></th>
                            {% for watch in watches %}
                            <th class="has-text-right">Demand</th>
                            <th class="has-text-right">Units</th>
                            <th class="has-text-right">Revenue</th>
                            {% endfor %}
                        </tr>
@@ -168,7 +168,7 @@ watches.forEach(watch => {
        data: {
            labels: labels,
            datasets: [{
                label: 'Monthly Demand',
                label: 'Monthly Units Sold',
                data: watchData,
                borderColor: watch.id === 1 ? '#667eea' : watch.id === 2 ? '#f093fb' : '#4facfe',
                backgroundColor: watch.id === 1 ? '#667eea33' : watch.id === 2 ? '#f093fb33' : '#4facfe33',
Loading