""" Tests for Ray distributed execution engine. Tests parallel config generation, Batfish analysis, deployments, progress tracking, error handling, and staggered rollouts. """ import pytest import time from typing import Dict, Any, List from agent.ray_executor import ( RayExecutor, TaskStatus, TaskResult, ExecutionProgress, ProgressTracker ) # Mock functions for testing def mock_config_template(device_data: Dict[str, Any]) -> str: """Mock config generation function""" hostname = device_data.get('hostname', 'unknown') role = device_data.get('role', 'leaf') return f""" hostname {hostname} ! interface Ethernet1 description {role} uplink ! """ def mock_config_template_with_error(device_data: Dict[str, Any]) -> str: """Mock config generation that fails for certain devices""" if 'error' in device_data.get('hostname', ''): raise ValueError("Simulated config generation error") return mock_config_template(device_data) class MockBatfishClient: """Mock Batfish client for testing""" def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]: """Mock analysis""" return { 'issues': [], 'warnings': [], 'validated': True } class MockBatfishClientWithError: """Mock Batfish client that fails occasionally""" def __init__(self): self.call_count = 0 def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]: """Mock analysis that fails every 3rd call""" self.call_count += 1 if self.call_count % 3 == 0: raise Exception("Simulated Batfish error") return {'issues': [], 'warnings': [], 'validated': True} class MockGNS3Client: """Mock GNS3 client for testing""" def apply_config(self, device_id: str, config: str) -> Dict[str, Any]: """Mock config deployment""" time.sleep(0.1) # Simulate network delay return { 'device_id': device_id, 'status': 'deployed', 'timestamp': time.time() } class MockGNS3ClientWithRetry: """Mock GNS3 client that requires retries""" def __init__(self, fail_count: int = 2): self.attempts = {} self.fail_count = fail_count def apply_config(self, device_id: str, config: str) -> Dict[str, Any]: """Mock deployment that succeeds after N failures""" if device_id not in self.attempts: self.attempts[device_id] = 0 self.attempts[device_id] += 1 if self.attempts[device_id] <= self.fail_count: raise Exception(f"Simulated deployment error (attempt {self.attempts[device_id]})") return { 'device_id': device_id, 'status': 'deployed', 'attempts': self.attempts[device_id] } @pytest.fixture def executor(): """Create Ray executor instance""" executor = RayExecutor() yield executor executor.shutdown() @pytest.fixture def sample_devices(): """Sample device data for testing""" return [ {'device_id': 'leaf-1', 'hostname': 'leaf-1', 'role': 'leaf', 'mgmt_ip': '10.0.1.1'}, {'device_id': 'leaf-2', 'hostname': 'leaf-2', 'role': 'leaf', 'mgmt_ip': '10.0.1.2'}, {'device_id': 'spine-1', 'hostname': 'spine-1', 'role': 'spine', 'mgmt_ip': '10.0.2.1'}, {'device_id': 'spine-2', 'hostname': 'spine-2', 'role': 'spine', 'mgmt_ip': '10.0.2.2'}, {'device_id': 'border-1', 'hostname': 'border-1', 'role': 'border', 'mgmt_ip': '10.0.3.1'}, ] def test_execution_progress_tracking(): """Test progress tracking calculations""" progress = ExecutionProgress(total_devices=100) # Initial state assert progress.completion_percentage == 0.0 assert progress.success_rate == 0.0 # Simulate some completions progress.completed = 50 progress.failed = 10 assert progress.completion_percentage == 50.0 assert progress.success_rate == pytest.approx(83.33, rel=0.1) # Convert to dict progress_dict = progress.to_dict() assert 
def test_execution_progress_tracking():
    """Test progress tracking calculations"""
    progress = ExecutionProgress(total_devices=100)

    # Initial state
    assert progress.completion_percentage == 0.0
    assert progress.success_rate == 0.0

    # Simulate some completions
    progress.completed = 50
    progress.failed = 10

    assert progress.completion_percentage == 50.0
    assert progress.success_rate == pytest.approx(83.33, rel=0.1)

    # Convert to dict
    progress_dict = progress.to_dict()
    assert progress_dict['total_devices'] == 100
    assert progress_dict['completed'] == 50
    assert progress_dict['failed'] == 10


def test_ray_initialization(executor):
    """Test Ray runtime initialization"""
    executor.initialize()
    assert executor.initialized is True

    # Get cluster resources
    resources = executor.get_cluster_resources()
    assert 'available' in resources
    assert 'total' in resources
    assert resources['total'].get('CPU', 0) > 0


def test_parallel_config_generation(executor, sample_devices):
    """Test parallel config generation across devices"""
    results, progress = executor.parallel_config_generation(
        devices=sample_devices,
        template_fn=mock_config_template,
        batch_size=10
    )

    # Check all devices processed
    assert len(results) == len(sample_devices)

    # Check all succeeded
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)

    # Check progress
    assert progress['total_devices'] == len(sample_devices)
    assert progress['completed'] == len(sample_devices)
    assert progress['failed'] == 0
    assert progress['completion_percentage'] == 100.0

    # Check configs were generated
    for result in results:
        assert result.result is not None
        assert 'hostname' in result.result


def test_parallel_config_generation_with_errors(executor):
    """Test parallel config generation with some failures"""
    devices = [
        {'device_id': 'good-1', 'hostname': 'good-1'},
        {'device_id': 'error-1', 'hostname': 'error-1'},  # Will fail
        {'device_id': 'good-2', 'hostname': 'good-2'},
    ]

    results, progress = executor.parallel_config_generation(
        devices=devices,
        template_fn=mock_config_template_with_error,
        batch_size=10
    )

    assert len(results) == 3

    # Check success/failure counts
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
    assert success_count == 2
    assert failed_count == 1

    # Check error message
    failed_result = [r for r in results if r.status == TaskStatus.FAILED][0]
    assert 'error' in failed_result.error.lower()


def test_parallel_batfish_analysis(executor, sample_devices):
    """Test parallel Batfish analysis"""
    # Generate configs first
    configs = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }

    batfish_client = MockBatfishClient()

    results, progress = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=batfish_client,
        batch_size=10
    )

    assert len(results) == len(sample_devices)

    # All should succeed with mock client
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)

    # Check analysis results
    for result in results:
        assert result.result is not None
        assert 'validated' in result.result


def test_parallel_batfish_analysis_with_errors(executor):
    """Test parallel Batfish analysis with failures"""
    configs = {
        'device-1': 'config 1',
        'device-2': 'config 2',
        'device-3': 'config 3',
    }

    batfish_client = MockBatfishClientWithError()

    results, progress = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=batfish_client,
        batch_size=10
    )

    assert len(results) == 3

    # Failures are not guaranteed here: the stateful mock may be copied into
    # each Ray task, so its call counter may never reach a failing value.
    # Just check that we got results for all devices.
    assert progress['total_devices'] == 3

def test_parallel_deployment(executor, sample_devices):
    """Test parallel deployment to devices"""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }

    gns3_client = MockGNS3Client()

    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5
    )

    assert len(results) == len(sample_devices)

    # All should succeed
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)

    # Check deployment results
    for result in results:
        assert result.result is not None
        assert result.result['status'] == 'deployed'


def test_parallel_deployment_with_retries(executor):
    """Test parallel deployment with automatic retries"""
    deployments = {
        'device-1': 'config 1',
        'device-2': 'config 2',
    }

    # Client that fails twice then succeeds
    gns3_client = MockGNS3ClientWithRetry(fail_count=2)

    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5,
        max_retries=3
    )

    assert len(results) == 2

    # Should succeed after retries
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == 2

    # Check retry counts
    for result in results:
        assert result.retry_count >= 2


def test_parallel_deployment_max_retries_exceeded(executor):
    """Test parallel deployment when max retries exceeded"""
    deployments = {'device-1': 'config 1'}

    # Client that always fails
    gns3_client = MockGNS3ClientWithRetry(fail_count=999)

    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5,
        max_retries=2
    )

    assert len(results) == 1
    assert results[0].status == TaskStatus.FAILED
    assert 'retries' in results[0].error.lower()


def test_staggered_rollout_success(executor, sample_devices):
    """Test staggered rollout with all stages succeeding"""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }

    gns3_client = MockGNS3Client()

    # Use small stages for 5 devices
    stages = [0.2, 0.6, 1.0]  # 20%, 60%, 100%

    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages
    )

    # All devices should be deployed
    assert len(results) == len(sample_devices)
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)


def test_staggered_rollout_failure_stops_deployment(executor):
    """Test staggered rollout stops on high failure rate"""
    # Create many devices to test staged rollout
    devices = [
        {'device_id': f'device-{i}', 'hostname': f'device-{i}'}
        for i in range(20)
    ]

    deployments = {
        device['device_id']: mock_config_template(device)
        for device in devices
    }

    # Client that always fails
    gns3_client = MockGNS3ClientWithRetry(fail_count=999)

    stages = [0.1, 0.5, 1.0]  # 10%, 50%, 100%

    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=None
    )

    # Should stop after the first stage fails
    # First stage = 10% of 20 = 2 devices
    assert len(results) <= 2

    # All should have failed
    failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
    assert failed_count == len(results)

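# The validation tests below use trivial pass/fail callbacks. A more realistic
# callback might gate each stage on its success rate; the sketch below assumes
# a 90% threshold, which is an illustrative choice rather than part of the
# executor API.
def success_rate_validation(device_ids: List[str], results: List[TaskResult]) -> bool:
    """Example stage validator: continue only if at least 90% of tasks succeeded"""
    if not results:
        return False
    successes = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    return successes / len(results) >= 0.9
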
def test_staggered_rollout_with_validation(executor, sample_devices):
    """Test staggered rollout with validation function"""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }

    gns3_client = MockGNS3Client()

    validation_called = []

    def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
        """Mock validation that tracks calls"""
        validation_called.append(len(device_ids))
        # All validations pass
        return True

    stages = [0.2, 0.6, 1.0]

    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=validation_fn
    )

    # All devices deployed
    assert len(results) == len(sample_devices)

    # Validation called multiple times (once per stage)
    assert len(validation_called) >= 2


def test_staggered_rollout_validation_failure_stops(executor, sample_devices):
    """Test staggered rollout stops when validation fails"""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }

    gns3_client = MockGNS3Client()

    def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
        """Validation that always fails"""
        return False

    stages = [0.2, 0.6, 1.0]

    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=validation_fn
    )

    # Should only deploy the first stage (20% of 5 = 1 device)
    assert len(results) == 1


def test_task_result_serialization():
    """Test TaskResult can be serialized"""
    result = TaskResult(
        device_id='test-1',
        status=TaskStatus.SUCCESS,
        result={'config': 'test'},
        duration_seconds=1.5
    )

    assert result.device_id == 'test-1'
    assert result.status == TaskStatus.SUCCESS
    assert result.duration_seconds == 1.5
    assert result.retry_count == 0


def test_large_scale_config_generation(executor):
    """Test config generation scales to hundreds of devices"""
    # Create 100 devices
    devices = [
        {'device_id': f'device-{i:03d}', 'hostname': f'device-{i:03d}', 'role': 'leaf'}
        for i in range(100)
    ]

    start_time = time.time()

    results, progress = executor.parallel_config_generation(
        devices=devices,
        template_fn=mock_config_template,
        batch_size=50
    )

    duration = time.time() - start_time

    # All should succeed
    assert len(results) == 100
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == 100

    # Should complete reasonably quickly (parallel execution);
    # serial execution would take much longer
    assert duration < 10.0  # Should be well under 10 seconds

    print(f"\nGenerated 100 configs in {duration:.2f} seconds")
    print(f"Average: {duration/100*1000:.1f}ms per device")


def test_progress_tracking_time_estimates():
    """Test progress tracking time estimation"""
    progress = ExecutionProgress(total_devices=100)

    # Simulate some work
    time.sleep(0.1)
    progress.completed = 25

    # Should have a time estimate
    eta = progress.estimated_time_remaining
    assert eta is not None
    assert eta > 0

    # Complete more work
    progress.completed = 50
    eta2 = progress.estimated_time_remaining

    # ETA should decrease
    assert eta2 < eta


def test_executor_multiple_operations(executor, sample_devices):
    """Test running multiple operations sequentially"""
    # Config generation
    results1, _ = executor.parallel_config_generation(
        devices=sample_devices,
        template_fn=mock_config_template
    )

    # Batfish analysis
    configs = {r.device_id: r.result for r in results1 if r.status == TaskStatus.SUCCESS}
    results2, _ = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=MockBatfishClient()
    )

    # Deployment
    results3, _ = executor.parallel_deployment(
        deployments=configs,
        gns3_client=MockGNS3Client()
    )

    # All operations should succeed
    assert all(r.status == TaskStatus.SUCCESS for r in results1)
    assert all(r.status == TaskStatus.SUCCESS for r in results2)
    assert all(r.status == TaskStatus.SUCCESS for r in results3)
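

# Optional convenience entry point, assuming it is acceptable for this module
# to invoke pytest on itself when run directly as a script.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])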