Spaces:
Running
Running
| """ | |
| Tests for Ray distributed execution engine. | |
| Tests parallel config generation, Batfish analysis, deployments, | |
| progress tracking, error handling, and staggered rollouts. | |
| """ | |
| import pytest | |
| import time | |
| from typing import Dict, Any, List | |
| from agent.ray_executor import ( | |
| RayExecutor, | |
| TaskStatus, | |
| TaskResult, | |
| ExecutionProgress, | |
| ProgressTracker | |
| ) | |
| # Mock functions for testing | |
def mock_config_template(device_data: Dict[str, Any]) -> str:
    """Render a minimal mock device configuration for tests.

    Pulls ``hostname`` and ``role`` from *device_data*, falling back to
    ``'unknown'`` / ``'leaf'`` when the keys are absent.
    """
    values = {
        'hostname': device_data.get('hostname', 'unknown'),
        'role': device_data.get('role', 'leaf'),
    }
    return """
hostname {hostname}
!
interface Ethernet1
  description {role} uplink
!
""".format(**values)
def mock_config_template_with_error(device_data: Dict[str, Any]) -> str:
    """Mock config generation that raises for hostnames containing 'error'.

    Delegates to :func:`mock_config_template` for every other device.
    """
    hostname = device_data.get('hostname', '')
    if 'error' not in hostname:
        return mock_config_template(device_data)
    raise ValueError("Simulated config generation error")
class MockBatfishClient:
    """Stand-in Batfish client whose analysis always validates cleanly."""

    def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]:
        """Return a fixed, issue-free analysis result for any configs."""
        report = {'issues': [], 'warnings': []}
        report['validated'] = True
        return report
class MockBatfishClientWithError:
    """Stand-in Batfish client that raises on every third analysis call."""

    def __init__(self):
        # Number of analyze_configs() invocations observed so far.
        self.call_count = 0

    def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]:
        """Return a clean result, but raise on calls 3, 6, 9, ..."""
        self.call_count = self.call_count + 1
        if not self.call_count % 3:
            raise Exception("Simulated Batfish error")
        return {'issues': [], 'warnings': [], 'validated': True}
class MockGNS3Client:
    """Stand-in GNS3 client that always reports a successful deployment."""

    def apply_config(self, device_id: str, config: str) -> Dict[str, Any]:
        """Pretend to push *config* to *device_id* after a short fake delay."""
        time.sleep(0.1)  # simulate network round-trip latency
        return {
            'status': 'deployed',
            'timestamp': time.time(),
            'device_id': device_id,
        }
class MockGNS3ClientWithRetry:
    """Stand-in GNS3 client that fails the first N attempts per device."""

    def __init__(self, fail_count: int = 2):
        # Per-device attempt counters, keyed by device_id.
        self.attempts = {}
        # How many attempts must fail before a device deploys successfully.
        self.fail_count = fail_count

    def apply_config(self, device_id: str, config: str) -> Dict[str, Any]:
        """Raise for the first ``fail_count`` attempts, then succeed."""
        attempt = self.attempts.get(device_id, 0) + 1
        self.attempts[device_id] = attempt
        if attempt > self.fail_count:
            return {
                'device_id': device_id,
                'status': 'deployed',
                'attempts': attempt,
            }
        raise Exception(f"Simulated deployment error (attempt {attempt})")
@pytest.fixture
def executor():
    """Yield a fresh RayExecutor, shutting it down after the test.

    Bug fix: this generator was missing the ``@pytest.fixture`` decorator,
    so every test declaring an ``executor`` parameter would fail with
    "fixture 'executor' not found" instead of receiving an instance.

    Yields:
        RayExecutor: the executor under test.
    """
    executor = RayExecutor()
    yield executor
    # Teardown: release the Ray runtime even when the test body fails.
    executor.shutdown()
@pytest.fixture
def sample_devices():
    """Return a small 5-device fabric: 2 leaves, 2 spines, 1 border.

    Bug fix: this helper was missing the ``@pytest.fixture`` decorator,
    so tests declaring a ``sample_devices`` parameter could not resolve it.

    Returns:
        list[dict]: device records with device_id, hostname, role, mgmt_ip.
    """
    return [
        {'device_id': 'leaf-1', 'hostname': 'leaf-1', 'role': 'leaf', 'mgmt_ip': '10.0.1.1'},
        {'device_id': 'leaf-2', 'hostname': 'leaf-2', 'role': 'leaf', 'mgmt_ip': '10.0.1.2'},
        {'device_id': 'spine-1', 'hostname': 'spine-1', 'role': 'spine', 'mgmt_ip': '10.0.2.1'},
        {'device_id': 'spine-2', 'hostname': 'spine-2', 'role': 'spine', 'mgmt_ip': '10.0.2.2'},
        {'device_id': 'border-1', 'hostname': 'border-1', 'role': 'border', 'mgmt_ip': '10.0.3.1'},
    ]
def test_execution_progress_tracking():
    """Test progress tracking calculations in ExecutionProgress."""
    progress = ExecutionProgress(total_devices=100)
    # Initial state: nothing completed yet, so both ratios are zero.
    assert progress.completion_percentage == 0.0
    assert progress.success_rate == 0.0
    # Simulate some completions
    progress.completed = 50
    progress.failed = 10
    # 50/100 done; 83.33 is consistent with completed/(completed+failed)
    # — confirm that formula against ExecutionProgress itself.
    assert progress.completion_percentage == 50.0
    assert progress.success_rate == pytest.approx(83.33, rel=0.1)
    # Convert to dict and check the counters survive serialization.
    progress_dict = progress.to_dict()
    assert progress_dict['total_devices'] == 100
    assert progress_dict['completed'] == 50
    assert progress_dict['failed'] == 10
def test_ray_initialization(executor):
    """Test Ray runtime initialization via the executor fixture."""
    executor.initialize()
    assert executor.initialized is True
    # Cluster resources must expose both 'available' and 'total' views.
    resources = executor.get_cluster_resources()
    assert 'available' in resources
    assert 'total' in resources
    # At least one CPU must be visible or nothing could run in parallel.
    assert resources['total'].get('CPU', 0) > 0
def test_parallel_config_generation(executor, sample_devices):
    """Test parallel config generation across devices (happy path)."""
    results, progress = executor.parallel_config_generation(
        devices=sample_devices,
        template_fn=mock_config_template,
        batch_size=10
    )
    # Check all devices processed — one TaskResult per input device.
    assert len(results) == len(sample_devices)
    # Check all succeeded
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)
    # Check progress reflects full completion with zero failures.
    assert progress['total_devices'] == len(sample_devices)
    assert progress['completed'] == len(sample_devices)
    assert progress['failed'] == 0
    assert progress['completion_percentage'] == 100.0
    # Check configs were generated — mock template always emits 'hostname'.
    for result in results:
        assert result.result is not None
        assert 'hostname' in result.result
def test_parallel_config_generation_with_errors(executor):
    """Test parallel config generation with some failures."""
    devices = [
        {'device_id': 'good-1', 'hostname': 'good-1'},
        {'device_id': 'error-1', 'hostname': 'error-1'},  # Will fail: template raises on 'error' hostnames
        {'device_id': 'good-2', 'hostname': 'good-2'},
    ]
    results, progress = executor.parallel_config_generation(
        devices=devices,
        template_fn=mock_config_template_with_error,
        batch_size=10
    )
    # A failing device must still produce a TaskResult, not be dropped.
    assert len(results) == 3
    # Check success/failure counts
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
    assert success_count == 2
    assert failed_count == 1
    # Check error message is propagated into the failed TaskResult.
    failed_result = [r for r in results if r.status == TaskStatus.FAILED][0]
    assert 'error' in failed_result.error.lower()
def test_parallel_batfish_analysis(executor, sample_devices):
    """Test parallel Batfish analysis of pre-generated configs."""
    # Generate configs first — analysis takes a device_id -> config map.
    configs = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }
    batfish_client = MockBatfishClient()
    results, progress = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=batfish_client,
        batch_size=10
    )
    assert len(results) == len(sample_devices)
    # All should succeed with mock client (it never raises).
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)
    # Check analysis results carry the mock client's payload.
    for result in results:
        assert result.result is not None
        assert 'validated' in result.result
def test_parallel_batfish_analysis_with_errors(executor):
    """Test parallel Batfish analysis with failures."""
    configs = {
        'device-1': 'config 1',
        'device-2': 'config 2',
        'device-3': 'config 3',
    }
    # Client raises on every 3rd call, but calls may land on any worker.
    batfish_client = MockBatfishClientWithError()
    results, progress = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=batfish_client,
        batch_size=10
    )
    assert len(results) == 3
    # Some should fail (but due to random execution order, may all succeed)
    # Just check that we got results for all devices
    assert progress['total_devices'] == 3
def test_parallel_deployment(executor, sample_devices):
    """Test parallel deployment to devices (happy path)."""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }
    gns3_client = MockGNS3Client()
    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5
    )
    assert len(results) == len(sample_devices)
    # All should succeed — the mock client never raises.
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)
    # Check deployment results carry the mock client's status payload.
    for result in results:
        assert result.result is not None
        assert result.result['status'] == 'deployed'
def test_parallel_deployment_with_retries(executor):
    """Test parallel deployment with automatic retries."""
    deployments = {
        'device-1': 'config 1',
        'device-2': 'config 2',
    }
    # Client that fails twice then succeeds — within the retry budget below.
    gns3_client = MockGNS3ClientWithRetry(fail_count=2)
    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5,
        max_retries=3
    )
    assert len(results) == 2
    # Should succeed after retries
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == 2
    # Check retry counts — each device needed at least the 2 forced failures.
    for result in results:
        assert result.retry_count >= 2
def test_parallel_deployment_max_retries_exceeded(executor):
    """Test parallel deployment when max retries exceeded."""
    deployments = {'device-1': 'config 1'}
    # Client that always fails — fail_count far exceeds the retry budget.
    gns3_client = MockGNS3ClientWithRetry(fail_count=999)
    results, progress = executor.parallel_deployment(
        deployments=deployments,
        gns3_client=gns3_client,
        batch_size=5,
        max_retries=2
    )
    assert len(results) == 1
    # Exhausted retries must surface as FAILED with 'retries' in the error.
    assert results[0].status == TaskStatus.FAILED
    assert 'retries' in results[0].error.lower()
def test_staggered_rollout_success(executor, sample_devices):
    """Test staggered rollout with all stages succeeding."""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }
    gns3_client = MockGNS3Client()
    # Use small stages for 5 devices — cumulative fractions of the fleet.
    stages = [0.2, 0.6, 1.0]  # 20%, 60%, 100%
    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages
    )
    # All devices should be deployed once the final stage reaches 100%.
    assert len(results) == len(sample_devices)
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == len(sample_devices)
def test_staggered_rollout_failure_stops_deployment(executor):
    """Test staggered rollout stops on high failure rate."""
    # Create many devices to test staged rollout
    devices = [
        {'device_id': f'device-{i}', 'hostname': f'device-{i}'}
        for i in range(20)
    ]
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in devices
    }
    # Client that always fails — every stage-1 deployment will error out.
    gns3_client = MockGNS3ClientWithRetry(fail_count=999)
    stages = [0.1, 0.5, 1.0]  # 10%, 50%, 100%
    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=None
    )
    # Should stop after first stage fails
    # First stage = 10% of 20 = 2 devices
    assert len(results) <= 2
    # All should have failed — rollout must not proceed to later stages.
    failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
    assert failed_count == len(results)
def test_staggered_rollout_with_validation(executor, sample_devices):
    """Test staggered rollout with a user-supplied validation function."""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }
    gns3_client = MockGNS3Client()
    # Closure-captured list records how many devices each validation saw.
    validation_called = []
    def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
        """Mock validation that tracks calls"""
        validation_called.append(len(device_ids))
        # All validations pass
        return True
    stages = [0.2, 0.6, 1.0]
    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=validation_fn
    )
    # All devices deployed
    assert len(results) == len(sample_devices)
    # Validation called multiple times (once per stage) — with 5 devices
    # some stages may collapse, so only require at least 2 calls.
    assert len(validation_called) >= 2
def test_staggered_rollout_validation_failure_stops(executor, sample_devices):
    """Test staggered rollout stops when validation fails."""
    deployments = {
        device['device_id']: mock_config_template(device)
        for device in sample_devices
    }
    gns3_client = MockGNS3Client()
    def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
        """Validation that always fails"""
        return False
    stages = [0.2, 0.6, 1.0]
    results, progress = executor.staggered_rollout(
        deployments=deployments,
        gns3_client=gns3_client,
        stages=stages,
        validation_fn=validation_fn
    )
    # Should only deploy first stage (20% of 5 = 1 device) — the failed
    # validation must prevent stages two and three from running.
    assert len(results) == 1
def test_task_result_serialization():
    """Test TaskResult construction and its default field values."""
    result = TaskResult(
        device_id='test-1',
        status=TaskStatus.SUCCESS,
        result={'config': 'test'},
        duration_seconds=1.5
    )
    assert result.device_id == 'test-1'
    assert result.status == TaskStatus.SUCCESS
    assert result.duration_seconds == 1.5
    # retry_count must default to 0 when not supplied.
    assert result.retry_count == 0
def test_large_scale_config_generation(executor):
    """Test config generation scales to hundreds of devices."""
    # Create 100 devices with zero-padded names for stable sorting.
    devices = [
        {'device_id': f'device-{i:03d}', 'hostname': f'device-{i:03d}', 'role': 'leaf'}
        for i in range(100)
    ]
    start_time = time.time()
    results, progress = executor.parallel_config_generation(
        devices=devices,
        template_fn=mock_config_template,
        batch_size=50
    )
    duration = time.time() - start_time
    # All should succeed
    assert len(results) == 100
    success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
    assert success_count == 100
    # Should complete reasonably quickly (parallel execution)
    # Serial execution would take much longer
    # NOTE(review): wall-clock bound is a coarse smoke check; it can flake
    # on a heavily loaded CI host.
    assert duration < 10.0  # Should be well under 10 seconds
    print(f"\nGenerated 100 configs in {duration:.2f} seconds")
    print(f"Average: {duration/100*1000:.1f}ms per device")
def test_progress_tracking_time_estimates():
    """Test progress tracking time estimation (ETA)."""
    progress = ExecutionProgress(total_devices=100)
    # Simulate some work — the sleep gives the tracker nonzero elapsed time.
    time.sleep(0.1)
    progress.completed = 25
    # Should have time estimate once some work has completed.
    eta = progress.estimated_time_remaining
    assert eta is not None
    assert eta > 0
    # Complete more work
    progress.completed = 50
    eta2 = progress.estimated_time_remaining
    # ETA should decrease as the completed fraction grows.
    assert eta2 < eta
def test_executor_multiple_operations(executor, sample_devices):
    """Test running multiple operations sequentially on one executor.

    Pipelines config generation -> Batfish analysis -> deployment, reusing
    each stage's output as the next stage's input.
    """
    # Config generation
    results1, _ = executor.parallel_config_generation(
        devices=sample_devices,
        template_fn=mock_config_template
    )
    # Batfish analysis — feed only the successfully generated configs.
    configs = {r.device_id: r.result for r in results1 if r.status == TaskStatus.SUCCESS}
    results2, _ = executor.parallel_batfish_analysis(
        configs=configs,
        batfish_client=MockBatfishClient()
    )
    # Deployment
    results3, _ = executor.parallel_deployment(
        deployments=configs,
        gns3_client=MockGNS3Client()
    )
    # All operations should succeed
    assert all(r.status == TaskStatus.SUCCESS for r in results1)
    assert all(r.status == TaskStatus.SUCCESS for r in results2)
    assert all(r.status == TaskStatus.SUCCESS for r in results3)