Spaces:
Running
Running
| """ | |
| Tests for RAG-based Incident Learning System | |
| Tests incident capture, root cause analysis, and regression test generation | |
| """ | |
| import pytest | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from agent.incident_learning import ( | |
| Incident, | |
| IncidentDatabase, | |
| RootCauseAnalyzer, | |
| RegressionTestGenerator, | |
| capture_deployment_failure, | |
| learn_from_incident | |
| ) | |
@pytest.fixture
def incident_db(tmp_path):
    """Provide an IncidentDatabase stored under the per-test temp directory.

    Registered as a pytest fixture so tests can request it by parameter
    name; without the decorator pytest reports "fixture 'incident_db' not
    found" for every test that asks for it.
    """
    return IncidentDatabase(db_path=tmp_path / "incidents")
@pytest.fixture
def sample_incident():
    """Provide a high-severity deployment-failure incident for tests.

    The network model deliberately contains two VLANs sharing ID 100, and
    the validation errors carry the matching "Duplicate VLAN ID" entry.
    Registered as a pytest fixture so tests can request it by parameter name.
    """
    return Incident(
        id="test-20251125-120000",
        timestamp=datetime.now().isoformat(),
        severity="high",
        category="deployment_failure",
        description="VLAN 100 configuration failed - duplicate VLAN ID",
        affected_devices=["leaf-01", "leaf-02"],
        network_model={
            "name": "test-network",
            "vlans": [
                {"id": 100, "name": "Users"},
                {"id": 100, "name": "Servers"},  # intentional duplicate ID
            ],
        },
        validation_errors=[
            {"error": "Duplicate VLAN ID 100", "type": "validation"},
        ],
    )
@pytest.fixture
def routing_incident():
    """Provide a critical routing-issue incident with a known root cause.

    Registered as a pytest fixture so tests can request it by parameter name.
    """
    return Incident(
        id="routing-20251125-130000",
        timestamp=datetime.now().isoformat(),
        severity="critical",
        category="routing_issue",
        description="Routing loop detected between spine-01 and spine-02",
        affected_devices=["spine-01", "spine-02"],
        root_cause="BGP AS-PATH loop due to incorrect route reflector config",
    )
class TestIncident:
    """Checks on the Incident record itself: construction and serialization."""

    def test_incident_creation(self, sample_incident):
        """Values passed to the constructor are readable back unchanged."""
        incident = sample_incident
        assert incident.id == "test-20251125-120000"
        assert incident.severity == "high"
        assert incident.category == "deployment_failure"
        assert len(incident.affected_devices) == 2

    def test_incident_to_dict(self, sample_incident):
        """to_dict() mirrors the core fields and carries the payload keys."""
        data = sample_incident.to_dict()

        # Fields that must round-trip with the same value.
        for field_name in ('id', 'severity', 'category', 'affected_devices'):
            assert data[field_name] == getattr(sample_incident, field_name)

        # Fields that only need to be present in the serialized form.
        for key in ('network_model', 'validation_errors'):
            assert key in data
class TestIncidentDatabase:
    """Exercise IncidentDatabase storage, lookup, filtering and search."""

    def test_add_incident(self, incident_db, sample_incident):
        """Adding an incident returns its ID and writes it to the JSON file."""
        incident_id = incident_db.add_incident(sample_incident)
        assert incident_id == sample_incident.id

        # The backing file must exist once something has been stored.
        assert incident_db.incidents_file.exists()

        # The stored JSON is expected to contain the incident ID as a key.
        with open(incident_db.incidents_file, encoding="utf-8") as f:
            data = json.load(f)
        assert incident_id in data

    def test_get_incident(self, incident_db, sample_incident):
        """A stored incident can be fetched back by its ID."""
        incident_db.add_incident(sample_incident)

        retrieved = incident_db.get_incident(sample_incident.id)
        assert retrieved is not None
        assert retrieved.id == sample_incident.id
        assert retrieved.severity == sample_incident.severity
        assert retrieved.description == sample_incident.description

    def test_get_nonexistent_incident(self, incident_db):
        """Looking up an unknown ID yields None rather than raising."""
        result = incident_db.get_incident("nonexistent-id")
        assert result is None

    def test_update_incident(self, incident_db, sample_incident):
        """update_incident() merges new field values into a stored incident."""
        incident_db.add_incident(sample_incident)

        # Record a resolution on the stored incident.
        incident_db.update_incident(sample_incident.id, {
            'resolution': 'Fixed duplicate VLAN IDs',
            'resolved_at': datetime.now().isoformat(),
        })

        updated = incident_db.get_incident(sample_incident.id)
        # Fail with a clear assertion instead of an AttributeError on None.
        assert updated is not None
        assert updated.resolution == 'Fixed duplicate VLAN IDs'
        assert updated.resolved_at is not None

    def test_get_all_incidents(self, incident_db, sample_incident, routing_incident):
        """get_all_incidents() returns every incident, newest first."""
        # Pin distinct timestamps: both fixtures stamp themselves with
        # datetime.now(), which can produce equal values on coarse clocks
        # and make the newest-first assertion depend on tie-breaking.
        sample_incident.timestamp = "2025-11-25T12:00:00"
        routing_incident.timestamp = "2025-11-25T13:00:00"

        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        incidents = incident_db.get_all_incidents()
        assert len(incidents) == 2
        # Should be sorted by timestamp (newest first)
        assert incidents[0].id == routing_incident.id

    def test_filter_by_severity(self, incident_db, sample_incident, routing_incident):
        """The severity filter returns only incidents of that severity."""
        incident_db.add_incident(sample_incident)   # high
        incident_db.add_incident(routing_incident)  # critical

        critical = incident_db.get_all_incidents(severity="critical")
        assert len(critical) == 1
        assert critical[0].severity == "critical"

        high = incident_db.get_all_incidents(severity="high")
        assert len(high) == 1
        assert high[0].severity == "high"

    def test_filter_by_category(self, incident_db, sample_incident, routing_incident):
        """The category filter returns only incidents of that category."""
        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        deploy_failures = incident_db.get_all_incidents(category="deployment_failure")
        assert len(deploy_failures) == 1
        assert deploy_failures[0].category == "deployment_failure"

        routing = incident_db.get_all_incidents(category="routing_issue")
        assert len(routing) == 1
        assert routing[0].category == "routing_issue"

    def test_keyword_search(self, incident_db, sample_incident, routing_incident):
        """search_similar() returns a list; any VLAN hits mention VLAN."""
        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        # Search for VLAN issues
        results = incident_db.search_similar("VLAN duplicate", n_results=5)
        # In mock mode without ChromaDB, keyword search may return empty if no exact match
        assert isinstance(results, list)
        # If anything came back, at least one hit must mention VLAN. The
        # case-insensitive test already covers the exact-case one.
        if results:
            assert any("vlan" in r.description.lower() for r in results)

        # Search for routing issues
        results = incident_db.search_similar("routing loop BGP", n_results=5)
        assert isinstance(results, list)
class TestRootCauseAnalyzer:
    """Checks for RootCauseAnalyzer: analysis shape, patterns, confidence."""

    def test_analyze_incident(self, incident_db, sample_incident):
        """analyze() returns a dict with the expected keys and sane values."""
        # Seed the database with one earlier, related incident for context.
        incident_db.add_incident(
            Incident(
                id="historical-1",
                timestamp="2025-11-20T10:00:00",
                severity="high",
                category="deployment_failure",
                description="Duplicate VLAN configuration on leaf-03",
                affected_devices=["leaf-03"],
                root_cause="Schema validation missed duplicate VLAN IDs",
            )
        )

        analysis = RootCauseAnalyzer(incident_db).analyze(sample_incident)

        expected_keys = (
            'suggested_root_cause',
            'similar_incidents',
            'patterns_found',
            'confidence',
        )
        for key in expected_keys:
            assert key in analysis

        # Similar incidents may be 0 without ChromaDB vector search
        assert isinstance(analysis['similar_incidents'], list)
        # Confidence must never be negative.
        assert analysis['confidence'] >= 0.0

    def test_extract_patterns(self, incident_db):
        """A root cause shared by several incidents surfaces as a pattern."""
        shared_cause = "Missing BGP neighbor authentication"

        # Store three incidents that differ only in device and ID.
        for n in range(3):
            incident_db.add_incident(
                Incident(
                    id=f"pattern-test-{n}",
                    timestamp=datetime.now().isoformat(),
                    severity="medium",
                    category="config_error",
                    description=f"BGP neighbor config error on device-{n}",
                    affected_devices=[f"device-{n}"],
                    root_cause=shared_cause,
                )
            )

        analyzer = RootCauseAnalyzer(incident_db)
        stored = incident_db.get_all_incidents()
        patterns = analyzer._extract_patterns(stored)

        assert len(patterns) > 0
        # The common root cause must appear in at least one pattern string.
        assert any(shared_cause in pattern for pattern in patterns)

    def test_confidence_calculation(self, incident_db):
        """Confidence grows from 0 with similar incidents, fixes and patterns."""
        analyzer = RootCauseAnalyzer(incident_db)

        # Nothing similar and no patterns: confidence is exactly zero.
        assert analyzer._calculate_confidence([], []) == 0.0

        # Three similar but unresolved incidents, still no patterns.
        similar = []
        for n in range(3):
            similar.append(
                Incident(
                    id=f"test-{n}",
                    timestamp=datetime.now().isoformat(),
                    severity="low",
                    category="config_error",
                    description="Test incident",
                )
            )
        partial = analyzer._calculate_confidence(similar, [])
        assert 0.0 < partial < 1.0

        # Mark every incident resolved and supply patterns: high confidence.
        for entry in similar:
            entry.resolution = "Fixed"
        full = analyzer._calculate_confidence(
            similar, ["Common pattern 1", "Common pattern 2"]
        )
        assert full >= 0.7
class TestRegressionTestGenerator:
    """Checks that RegressionTestGenerator emits usable test source code."""

    def test_generate_config_test(self, sample_incident):
        """Generated code names a test framework and embeds incident details."""
        generator = RegressionTestGenerator()
        test_code = generator.generate_test(sample_incident)

        assert test_code is not None
        assert "pyats" in test_code or "pytest" in test_code
        assert sample_incident.id in test_code
        assert sample_incident.description in test_code

    def test_generate_routing_test(self, routing_incident):
        """Code generated for a routing incident mentions routing and its ID."""
        generator = RegressionTestGenerator()
        test_code = generator.generate_test(routing_incident)

        assert test_code is not None
        assert "routing" in test_code.lower()
        assert routing_incident.id in test_code

    def test_generate_deployment_test(self):
        """Code generated for a deployment failure mentions deployment or validation."""
        incident = Incident(
            id="deploy-fail-1",
            timestamp=datetime.now().isoformat(),
            severity="high",
            category="deployment_failure",
            description="Deployment failed due to syntax error",
        )
        generator = RegressionTestGenerator()
        test_code = generator.generate_test(incident)

        assert test_code is not None
        assert "deployment" in test_code.lower() or "validation" in test_code.lower()

    def test_generated_test_is_valid_python(self, sample_incident):
        """Generated code must compile as Python source."""
        generator = RegressionTestGenerator()
        test_code = generator.generate_test(sample_incident)

        # Compile only (no execution). Report the SyntaxError itself so a
        # failure shows the offending line instead of a bare boolean.
        try:
            compile(test_code, '<string>', 'exec')
        except SyntaxError as exc:
            pytest.fail(f"Generated test code has syntax errors: {exc}")
class TestCaptureDeploymentFailure:
    """Checks for the capture_deployment_failure() helper."""

    def test_capture_failure(self, tmp_path, monkeypatch):
        """A captured failure is classified correctly and stored in the DB."""
        import agent.incident_learning

        # Build the temp-directory database with the real class before
        # the module attribute is replaced.
        db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Route every IncidentDatabase(...) call inside the module to the
        # temp database. monkeypatch restores the original at teardown, and
        # the stand-in accepts any arguments so it cannot raise TypeError if
        # the code under test passes a db_path.
        monkeypatch.setattr(
            agent.incident_learning,
            "IncidentDatabase",
            lambda *args, **kwargs: db,
        )

        incident = capture_deployment_failure(
            description="Test deployment failure",
            network_model={"name": "test-net", "devices": []},
            validation_errors=[{"error": "Test error"}],
            affected_devices=["device-1"],
        )

        assert incident.severity == "high"
        assert incident.category == "deployment_failure"
        assert len(incident.affected_devices) == 1

        # Verify it was stored
        retrieved = db.get_incident(incident.id)
        assert retrieved is not None
class TestLearnFromIncident:
    """Checks for the end-to-end learn_from_incident() workflow."""

    def test_learn_from_incident(self, tmp_path, sample_incident, monkeypatch):
        """Learning returns the expected keys and records a root cause."""
        import agent.incident_learning

        db = IncidentDatabase(db_path=tmp_path / "incidents")
        db.add_incident(sample_incident)

        # Route every IncidentDatabase(...) call inside the module to the
        # temp database. monkeypatch restores the original at teardown, and
        # the stand-in accepts any arguments so it cannot raise TypeError if
        # the code under test passes a db_path.
        monkeypatch.setattr(
            agent.incident_learning,
            "IncidentDatabase",
            lambda *args, **kwargs: db,
        )

        result = learn_from_incident(sample_incident)

        assert 'incident_id' in result
        assert 'root_cause' in result
        assert 'regression_test' in result
        assert 'confidence' in result

        # Check that incident was updated with learnings
        updated = db.get_incident(sample_incident.id)
        assert updated is not None
        assert updated.root_cause is not None

        # A regression test must have been produced.
        assert result['regression_test'] is not None
class TestPipelineIntegration:
    """Checks that OvergrowthPipeline is wired to the incident-learning parts."""

    def test_pipeline_has_incident_db(self):
        """A fresh pipeline exposes the incident-learning attributes."""
        from agent.pipeline_engine import OvergrowthPipeline

        pipeline = OvergrowthPipeline()

        assert hasattr(pipeline, 'incident_db')
        assert hasattr(pipeline, 'rca_analyzer')
        assert hasattr(pipeline, 'test_generator')

    def test_capture_on_validation_failure(self, tmp_path):
        """Preflight on an empty model leaves the incident DB queryable."""
        from agent.pipeline_engine import OvergrowthPipeline, NetworkModel, NetworkIntent

        pipeline = OvergrowthPipeline()
        # Override incident DB to use temp directory
        pipeline.incident_db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Create invalid network model
        intent = NetworkIntent(
            description="Test network with validation errors",
            business_requirements=["Invalid config"],
            constraints=[],
        )
        model = NetworkModel(
            name="invalid-network",
            version="1.0.0",
            intent=intent,
            devices=[],  # Empty devices = validation error
            vlans=[],
            subnets=[],
            routing={},
            services=[],
        )

        # Run preflight for its side effects only; the return value is not
        # inspected, so it is not bound to a name.
        pipeline.stage0_preflight(model)

        # Whether an incident is captured is not asserted here; the check
        # only requires that the database can still be listed.
        incidents = pipeline.incident_db.get_all_incidents()
        assert isinstance(incidents, list)

    def test_learn_from_recent_incidents(self, tmp_path):
        """learn_from_recent_incidents() reports summary counts."""
        from agent.pipeline_engine import OvergrowthPipeline

        pipeline = OvergrowthPipeline()
        pipeline.incident_db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Add some test incidents
        for i in range(3):
            incident = Incident(
                id=f"test-{i}",
                timestamp=datetime.now().isoformat(),
                severity="medium",
                category="config_error",
                description=f"Test incident {i}",
            )
            pipeline.incident_db.add_incident(incident)

        # Trigger learning
        learnings = pipeline.learn_from_recent_incidents(limit=5)

        assert 'total_incidents' in learnings
        assert 'unresolved' in learnings
        assert 'analyzed' in learnings
        assert learnings['total_incidents'] == 3
if __name__ == '__main__':
    # Propagate pytest's exit status so running this file as a script
    # exits non-zero when tests fail instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))