overgrowth / test_incident_learning.py
Graham Paasch
feat: RAG-based incident learning system (Todo #5)
8e74f68
"""
Tests for RAG-based Incident Learning System
Tests incident capture, root cause analysis, and regression test generation
"""
import pytest
import json
from datetime import datetime
from pathlib import Path
from agent.incident_learning import (
Incident,
IncidentDatabase,
RootCauseAnalyzer,
RegressionTestGenerator,
capture_deployment_failure,
learn_from_incident
)
@pytest.fixture
def incident_db(tmp_path):
    """Provide an IncidentDatabase stored under pytest's per-test tmp_path."""
    storage_dir = tmp_path / "incidents"
    return IncidentDatabase(db_path=storage_dir)
@pytest.fixture
def sample_incident():
    """Build a high-severity deployment-failure incident with a duplicated VLAN ID."""
    duplicated_vlans = [
        {"id": 100, "name": "Users"},
        {"id": 100, "name": "Servers"},  # Duplicate!
    ]
    model = {"name": "test-network", "vlans": duplicated_vlans}
    errors = [{"error": "Duplicate VLAN ID 100", "type": "validation"}]
    return Incident(
        id="test-20251125-120000",
        timestamp=datetime.now().isoformat(),
        severity="high",
        category="deployment_failure",
        description="VLAN 100 configuration failed - duplicate VLAN ID",
        affected_devices=["leaf-01", "leaf-02"],
        network_model=model,
        validation_errors=errors,
    )
@pytest.fixture
def routing_incident():
    """Build a critical routing-issue incident that already carries a root cause."""
    fields = {
        "id": "routing-20251125-130000",
        "timestamp": datetime.now().isoformat(),
        "severity": "critical",
        "category": "routing_issue",
        "description": "Routing loop detected between spine-01 and spine-02",
        "affected_devices": ["spine-01", "spine-02"],
        "root_cause": "BGP AS-PATH loop due to incorrect route reflector config",
    }
    return Incident(**fields)
class TestIncident:
    """Checks on the Incident record itself."""

    def test_incident_creation(self, sample_incident):
        """Constructor arguments are readable back as attributes."""
        incident = sample_incident
        assert incident.id == "test-20251125-120000"
        assert incident.severity == "high"
        assert incident.category == "deployment_failure"
        assert len(incident.affected_devices) == 2

    def test_incident_to_dict(self, sample_incident):
        """to_dict() mirrors the core fields and exposes the model/error keys."""
        serialized = sample_incident.to_dict()

        # Fields whose serialized value must equal the attribute value.
        for field_name in ("id", "severity", "category", "affected_devices"):
            assert serialized[field_name] == getattr(sample_incident, field_name)

        # Fields that only need to be present in the output.
        for key in ("network_model", "validation_errors"):
            assert key in serialized
class TestIncidentDatabase:
    """Exercise IncidentDatabase storage, retrieval, filtering and search."""

    def test_add_incident(self, incident_db, sample_incident):
        """add_incident returns the incident ID and writes it to the incidents file."""
        incident_id = incident_db.add_incident(sample_incident)
        assert incident_id == sample_incident.id

        # Verify file was created
        assert incident_db.incidents_file.exists()

        # Verify data: the stored JSON is keyed by incident ID
        with open(incident_db.incidents_file) as f:
            data = json.load(f)
        assert incident_id in data

    def test_get_incident(self, incident_db, sample_incident):
        """A stored incident can be fetched back by its ID."""
        incident_db.add_incident(sample_incident)

        retrieved = incident_db.get_incident(sample_incident.id)

        assert retrieved is not None
        assert retrieved.id == sample_incident.id
        assert retrieved.severity == sample_incident.severity
        assert retrieved.description == sample_incident.description

    def test_get_nonexistent_incident(self, incident_db):
        """Looking up an unknown ID yields None."""
        result = incident_db.get_incident("nonexistent-id")
        assert result is None

    def test_update_incident(self, incident_db, sample_incident):
        """update_incident applies the given field values to the stored incident."""
        incident_db.add_incident(sample_incident)

        # Update with resolution
        incident_db.update_incident(sample_incident.id, {
            'resolution': 'Fixed duplicate VLAN IDs',
            'resolved_at': datetime.now().isoformat(),
        })

        # Verify update
        updated = incident_db.get_incident(sample_incident.id)
        assert updated.resolution == 'Fixed duplicate VLAN IDs'
        assert updated.resolved_at is not None

    def test_get_all_incidents(self, incident_db, sample_incident, routing_incident):
        """get_all_incidents returns every incident, newest first."""
        # Both fixtures stamp themselves with datetime.now(); on a coarse clock
        # the two values can tie, which would make the ordering assertion
        # depend on tie-breaking. Pin distinct timestamps so the expected
        # order is unambiguous.
        sample_incident.timestamp = "2025-11-25T12:00:00"
        routing_incident.timestamp = "2025-11-25T13:00:00"

        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        incidents = incident_db.get_all_incidents()

        assert len(incidents) == 2
        # Should be sorted by timestamp (newest first)
        assert incidents[0].id == routing_incident.id

    def test_filter_by_severity(self, incident_db, sample_incident, routing_incident):
        """The severity filter returns only incidents of that severity."""
        incident_db.add_incident(sample_incident)  # high
        incident_db.add_incident(routing_incident)  # critical

        critical = incident_db.get_all_incidents(severity="critical")
        assert len(critical) == 1
        assert critical[0].severity == "critical"

        high = incident_db.get_all_incidents(severity="high")
        assert len(high) == 1
        assert high[0].severity == "high"

    def test_filter_by_category(self, incident_db, sample_incident, routing_incident):
        """The category filter returns only incidents of that category."""
        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        deploy_failures = incident_db.get_all_incidents(category="deployment_failure")
        assert len(deploy_failures) == 1
        assert deploy_failures[0].category == "deployment_failure"

        routing = incident_db.get_all_incidents(category="routing_issue")
        assert len(routing) == 1
        assert routing[0].category == "routing_issue"

    def test_keyword_search(self, incident_db, sample_incident, routing_incident):
        """search_similar returns a list (fallback keyword mode)."""
        incident_db.add_incident(sample_incident)
        incident_db.add_incident(routing_incident)

        # Search for VLAN issues
        results = incident_db.search_similar("VLAN duplicate", n_results=5)
        # In mock mode without ChromaDB, keyword search may return empty if no exact match
        assert isinstance(results, list)
        # If results found, verify they're relevant. The case-insensitive
        # check already covers an upper-case "VLAN" match.
        if results:
            assert any("vlan" in r.description.lower() for r in results)

        # Search for routing issues
        results = incident_db.search_similar("routing loop BGP", n_results=5)
        assert isinstance(results, list)
class TestRootCauseAnalyzer:
    """Checks on RootCauseAnalyzer output, pattern extraction and confidence."""

    def test_analyze_incident(self, incident_db, sample_incident):
        """analyze() returns a dict carrying the expected analysis keys."""
        # Seed one earlier incident so the analyzer has history to draw on.
        earlier = Incident(
            id="historical-1",
            timestamp="2025-11-20T10:00:00",
            severity="high",
            category="deployment_failure",
            description="Duplicate VLAN configuration on leaf-03",
            affected_devices=["leaf-03"],
            root_cause="Schema validation missed duplicate VLAN IDs",
        )
        incident_db.add_incident(earlier)

        analysis = RootCauseAnalyzer(incident_db).analyze(sample_incident)

        expected_keys = (
            'suggested_root_cause',
            'similar_incidents',
            'patterns_found',
            'confidence',
        )
        for key in expected_keys:
            assert key in analysis

        # The list may be empty when ChromaDB vector search is unavailable.
        assert isinstance(analysis['similar_incidents'], list)
        # Confidence is never negative.
        assert analysis['confidence'] >= 0.0

    def test_extract_patterns(self, incident_db):
        """A root cause shared by several incidents shows up among the patterns."""
        # Three incidents that differ only in device index.
        for idx in range(3):
            incident_db.add_incident(
                Incident(
                    id=f"pattern-test-{idx}",
                    timestamp=datetime.now().isoformat(),
                    severity="medium",
                    category="config_error",
                    description=f"BGP neighbor config error on device-{idx}",
                    affected_devices=[f"device-{idx}"],
                    root_cause="Missing BGP neighbor authentication",
                )
            )

        analyzer = RootCauseAnalyzer(incident_db)
        stored = incident_db.get_all_incidents()
        patterns = analyzer._extract_patterns(stored)

        assert len(patterns) > 0
        # The common root cause must appear in at least one pattern.
        assert any("Missing BGP neighbor authentication" in p for p in patterns)

    def test_confidence_calculation(self, incident_db):
        """Confidence is 0 with no evidence and grows with resolutions and patterns."""
        analyzer = RootCauseAnalyzer(incident_db)

        # Low confidence: no similar incidents
        assert analyzer._calculate_confidence([], []) == 0.0

        # Medium confidence: some similar, no patterns
        similar = []
        for idx in range(3):
            similar.append(
                Incident(
                    id=f"test-{idx}",
                    timestamp=datetime.now().isoformat(),
                    severity="low",
                    category="config_error",
                    description="Test incident",
                )
            )
        medium = analyzer._calculate_confidence(similar, [])
        assert 0.0 < medium < 1.0

        # High confidence: similar incidents + resolved + patterns
        for resolved in similar:
            resolved.resolution = "Fixed"
        common = ["Common pattern 1", "Common pattern 2"]
        high = analyzer._calculate_confidence(similar, common)
        assert high >= 0.7
class TestRegressionTestGenerator:
    """Checks on the source text produced by RegressionTestGenerator."""

    def test_generate_config_test(self, sample_incident):
        """Generated code names a test framework and embeds the incident details."""
        source = RegressionTestGenerator().generate_test(sample_incident)

        assert source is not None
        assert any(framework in source for framework in ("pyats", "pytest"))
        assert sample_incident.id in source
        assert sample_incident.description in source

    def test_generate_routing_test(self, routing_incident):
        """Code generated for a routing incident mentions routing and the incident ID."""
        source = RegressionTestGenerator().generate_test(routing_incident)

        assert source is not None
        assert "routing" in source.lower()
        assert routing_incident.id in source

    def test_generate_deployment_test(self):
        """Code generated for a deployment failure mentions deployment or validation."""
        failure = Incident(
            id="deploy-fail-1",
            timestamp=datetime.now().isoformat(),
            severity="high",
            category="deployment_failure",
            description="Deployment failed due to syntax error",
        )

        source = RegressionTestGenerator().generate_test(failure)

        assert source is not None
        lowered = source.lower()
        assert "deployment" in lowered or "validation" in lowered

    def test_generated_test_is_valid_python(self, sample_incident):
        """The generated source compiles without a SyntaxError."""
        source = RegressionTestGenerator().generate_test(sample_incident)

        # Compile only; the generated code is never executed here.
        try:
            compile(source, '<string>', 'exec')
        except SyntaxError:
            compiles = False
        else:
            compiles = True

        assert compiles, "Generated test code has syntax errors"
class TestCaptureDeploymentFailure:
    """Test deployment failure capture"""

    def test_capture_failure(self, tmp_path, monkeypatch):
        """capture_deployment_failure returns a high-severity incident and stores it."""
        import agent.incident_learning

        # Create temp database
        db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Redirect the module-level IncidentDatabase name to a factory that
        # hands back the temp database. monkeypatch restores the original
        # attribute at teardown, so no manual try/finally is needed.
        # NOTE(review): the zero-argument factory assumes the code under test
        # calls IncidentDatabase() without arguments - confirm.
        monkeypatch.setattr(agent.incident_learning, "IncidentDatabase", lambda: db)

        incident = capture_deployment_failure(
            description="Test deployment failure",
            network_model={"name": "test-net", "devices": []},
            validation_errors=[{"error": "Test error"}],
            affected_devices=["device-1"],
        )

        assert incident.severity == "high"
        assert incident.category == "deployment_failure"
        assert len(incident.affected_devices) == 1

        # Verify it was stored
        retrieved = db.get_incident(incident.id)
        assert retrieved is not None
class TestLearnFromIncident:
    """Test complete learning workflow"""

    def test_learn_from_incident(self, tmp_path, sample_incident, monkeypatch):
        """learn_from_incident reports its findings and records a root cause."""
        import agent.incident_learning

        db = IncidentDatabase(db_path=tmp_path / "incidents")
        db.add_incident(sample_incident)

        # Redirect the module-level IncidentDatabase name to a factory that
        # hands back the temp database. monkeypatch restores the original
        # attribute at teardown, so no manual try/finally is needed.
        # NOTE(review): the zero-argument factory assumes the code under test
        # calls IncidentDatabase() without arguments - confirm.
        monkeypatch.setattr(agent.incident_learning, "IncidentDatabase", lambda: db)

        result = learn_from_incident(sample_incident)

        assert 'incident_id' in result
        assert 'root_cause' in result
        assert 'regression_test' in result
        assert 'confidence' in result

        # Check that incident was updated with learnings
        updated = db.get_incident(sample_incident.id)
        assert updated.root_cause is not None

        # Check that test file would be created
        assert result['regression_test'] is not None
class TestPipelineIntegration:
    """Checks that OvergrowthPipeline is wired to the incident-learning pieces."""

    def test_pipeline_has_incident_db(self):
        """A fresh pipeline exposes the incident-learning attributes."""
        from agent.pipeline_engine import OvergrowthPipeline

        pipeline = OvergrowthPipeline()

        for attribute in ('incident_db', 'rca_analyzer', 'test_generator'):
            assert hasattr(pipeline, attribute)

    def test_capture_on_validation_failure(self, tmp_path):
        """Preflight on an empty model leaves the incident list retrievable."""
        from agent.pipeline_engine import OvergrowthPipeline, NetworkModel, NetworkIntent

        pipeline = OvergrowthPipeline()
        # Point the pipeline at a throwaway database under tmp_path.
        pipeline.incident_db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Model intended to be invalid.
        bad_intent = NetworkIntent(
            description="Test network with validation errors",
            business_requirements=["Invalid config"],
            constraints=[],
        )
        bad_model = NetworkModel(
            name="invalid-network",
            version="1.0.0",
            intent=bad_intent,
            devices=[],  # Empty devices = validation error
            vlans=[],
            subnets=[],
            routing={},
            services=[],
        )

        # Run validation (should fail); the return value is not inspected.
        preflight_outcome = pipeline.stage0_preflight(bad_model)

        # An incident may or may not have been captured, so only the
        # container type is asserted.
        captured = pipeline.incident_db.get_all_incidents()
        assert isinstance(captured, list)

    def test_learn_from_recent_incidents(self, tmp_path):
        """learn_from_recent_incidents summarizes the stored incidents."""
        from agent.pipeline_engine import OvergrowthPipeline

        pipeline = OvergrowthPipeline()
        pipeline.incident_db = IncidentDatabase(db_path=tmp_path / "incidents")

        # Store three medium-severity config errors.
        for idx in range(3):
            pipeline.incident_db.add_incident(
                Incident(
                    id=f"test-{idx}",
                    timestamp=datetime.now().isoformat(),
                    severity="medium",
                    category="config_error",
                    description=f"Test incident {idx}",
                )
            )

        # Trigger learning
        learnings = pipeline.learn_from_recent_incidents(limit=5)

        for key in ('total_incidents', 'unresolved', 'analyzed'):
            assert key in learnings
        assert learnings['total_incidents'] == 3
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main(args=[__file__, "-v"])