# overgrowth/test_ray_executor.py
"""
Tests for the Ray distributed execution engine: parallel config generation,
Batfish analysis, deployments, progress tracking, error handling, and
staggered rollouts.
"""
import pytest
import time
from typing import Dict, Any, List
from agent.ray_executor import (
RayExecutor,
TaskStatus,
TaskResult,
ExecutionProgress,
ProgressTracker
)
# Mock functions for testing
def mock_config_template(device_data: Dict[str, Any]) -> str:
"""Mock config generation function"""
hostname = device_data.get('hostname', 'unknown')
role = device_data.get('role', 'leaf')
return f"""
hostname {hostname}
!
interface Ethernet1
description {role} uplink
!
"""
def mock_config_template_with_error(device_data: Dict[str, Any]) -> str:
"""Mock config generation that fails for certain devices"""
if 'error' in device_data.get('hostname', ''):
raise ValueError("Simulated config generation error")
return mock_config_template(device_data)
class MockBatfishClient:
"""Mock Batfish client for testing"""
def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]:
"""Mock analysis"""
return {
'issues': [],
'warnings': [],
'validated': True
}
class MockBatfishClientWithError:
"""Mock Batfish client that fails occasionally"""
def __init__(self):
self.call_count = 0
def analyze_configs(self, configs: Dict[str, str]) -> Dict[str, Any]:
"""Mock analysis that fails every 3rd call"""
self.call_count += 1
if self.call_count % 3 == 0:
raise Exception("Simulated Batfish error")
return {'issues': [], 'warnings': [], 'validated': True}
class MockGNS3Client:
"""Mock GNS3 client for testing"""
def apply_config(self, device_id: str, config: str) -> Dict[str, Any]:
"""Mock config deployment"""
time.sleep(0.1) # Simulate network delay
return {
'device_id': device_id,
'status': 'deployed',
'timestamp': time.time()
}
class MockGNS3ClientWithRetry:
"""Mock GNS3 client that requires retries"""
def __init__(self, fail_count: int = 2):
self.attempts = {}
self.fail_count = fail_count
def apply_config(self, device_id: str, config: str) -> Dict[str, Any]:
"""Mock deployment that succeeds after N failures"""
if device_id not in self.attempts:
self.attempts[device_id] = 0
self.attempts[device_id] += 1
if self.attempts[device_id] <= self.fail_count:
raise Exception(f"Simulated deployment error (attempt {self.attempts[device_id]})")
return {
'device_id': device_id,
'status': 'deployed',
'attempts': self.attempts[device_id]
}
@pytest.fixture
def executor():
"""Create Ray executor instance"""
executor = RayExecutor()
yield executor
executor.shutdown()
@pytest.fixture
def sample_devices():
"""Sample device data for testing"""
return [
{'device_id': 'leaf-1', 'hostname': 'leaf-1', 'role': 'leaf', 'mgmt_ip': '10.0.1.1'},
{'device_id': 'leaf-2', 'hostname': 'leaf-2', 'role': 'leaf', 'mgmt_ip': '10.0.1.2'},
{'device_id': 'spine-1', 'hostname': 'spine-1', 'role': 'spine', 'mgmt_ip': '10.0.2.1'},
{'device_id': 'spine-2', 'hostname': 'spine-2', 'role': 'spine', 'mgmt_ip': '10.0.2.2'},
{'device_id': 'border-1', 'hostname': 'border-1', 'role': 'border', 'mgmt_ip': '10.0.3.1'},
]
def test_execution_progress_tracking():
"""Test progress tracking calculations"""
progress = ExecutionProgress(total_devices=100)
# Initial state
assert progress.completion_percentage == 0.0
assert progress.success_rate == 0.0
# Simulate some completions
progress.completed = 50
progress.failed = 10
assert progress.completion_percentage == 50.0
assert progress.success_rate == pytest.approx(83.33, rel=0.1)
# Convert to dict
progress_dict = progress.to_dict()
assert progress_dict['total_devices'] == 100
assert progress_dict['completed'] == 50
assert progress_dict['failed'] == 10
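
# A minimal sketch of the progress formulas the assertions above imply. This is
# illustration derived from the expected values (50.0 and ~83.33), not a claim
# about how ExecutionProgress is actually implemented.
def _progress_math_sketch(completed: int, failed: int, total: int) -> Dict[str, float]:
    """Reference arithmetic: completion counts successes against all devices;
    success rate counts successes against all finished tasks."""
    return {
        'completion_percentage': completed / total * 100,        # 50 / 100 -> 50.0
        'success_rate': completed / (completed + failed) * 100,  # 50 / 60  -> ~83.33
    }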
def test_ray_initialization(executor):
"""Test Ray runtime initialization"""
executor.initialize()
assert executor.initialized is True
# Get cluster resources
resources = executor.get_cluster_resources()
assert 'available' in resources
assert 'total' in resources
assert resources['total'].get('CPU', 0) > 0
def test_parallel_config_generation(executor, sample_devices):
"""Test parallel config generation across devices"""
results, progress = executor.parallel_config_generation(
devices=sample_devices,
template_fn=mock_config_template,
batch_size=10
)
# Check all devices processed
assert len(results) == len(sample_devices)
# Check all succeeded
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == len(sample_devices)
# Check progress
assert progress['total_devices'] == len(sample_devices)
assert progress['completed'] == len(sample_devices)
assert progress['failed'] == 0
assert progress['completion_percentage'] == 100.0
# Check configs were generated
for result in results:
assert result.result is not None
assert 'hostname' in result.result
def test_parallel_config_generation_with_errors(executor):
"""Test parallel config generation with some failures"""
devices = [
{'device_id': 'good-1', 'hostname': 'good-1'},
{'device_id': 'error-1', 'hostname': 'error-1'}, # Will fail
{'device_id': 'good-2', 'hostname': 'good-2'},
]
results, progress = executor.parallel_config_generation(
devices=devices,
template_fn=mock_config_template_with_error,
batch_size=10
)
assert len(results) == 3
# Check success/failure counts
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
assert success_count == 2
assert failed_count == 1
# Check error message
failed_result = [r for r in results if r.status == TaskStatus.FAILED][0]
assert 'error' in failed_result.error.lower()
def test_parallel_batfish_analysis(executor, sample_devices):
"""Test parallel Batfish analysis"""
# Generate configs first
configs = {
device['device_id']: mock_config_template(device)
for device in sample_devices
}
batfish_client = MockBatfishClient()
results, progress = executor.parallel_batfish_analysis(
configs=configs,
batfish_client=batfish_client,
batch_size=10
)
assert len(results) == len(sample_devices)
# All should succeed with mock client
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == len(sample_devices)
# Check analysis results
for result in results:
assert result.result is not None
assert 'validated' in result.result
def test_parallel_batfish_analysis_with_errors(executor):
"""Test parallel Batfish analysis with failures"""
configs = {
'device-1': 'config 1',
'device-2': 'config 2',
'device-3': 'config 3',
}
batfish_client = MockBatfishClientWithError()
results, progress = executor.parallel_batfish_analysis(
configs=configs,
batfish_client=batfish_client,
batch_size=10
)
assert len(results) == 3
    # The failure pattern is nondeterministic: each Ray task receives its own
    # serialized copy of the mock client, so the shared call counter may never
    # reach 3 and all calls can succeed. Just check that every device got a result.
assert progress['total_devices'] == 3
def test_parallel_deployment(executor, sample_devices):
"""Test parallel deployment to devices"""
deployments = {
device['device_id']: mock_config_template(device)
for device in sample_devices
}
gns3_client = MockGNS3Client()
results, progress = executor.parallel_deployment(
deployments=deployments,
gns3_client=gns3_client,
batch_size=5
)
assert len(results) == len(sample_devices)
# All should succeed
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == len(sample_devices)
# Check deployment results
for result in results:
assert result.result is not None
assert result.result['status'] == 'deployed'
def test_parallel_deployment_with_retries(executor):
"""Test parallel deployment with automatic retries"""
deployments = {
'device-1': 'config 1',
'device-2': 'config 2',
}
# Client that fails twice then succeeds
gns3_client = MockGNS3ClientWithRetry(fail_count=2)
results, progress = executor.parallel_deployment(
deployments=deployments,
gns3_client=gns3_client,
batch_size=5,
max_retries=3
)
assert len(results) == 2
# Should succeed after retries
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == 2
# Check retry counts
for result in results:
assert result.retry_count >= 2
def test_parallel_deployment_max_retries_exceeded(executor):
"""Test parallel deployment when max retries exceeded"""
deployments = {'device-1': 'config 1'}
# Client that always fails
gns3_client = MockGNS3ClientWithRetry(fail_count=999)
results, progress = executor.parallel_deployment(
deployments=deployments,
gns3_client=gns3_client,
batch_size=5,
max_retries=2
)
assert len(results) == 1
assert results[0].status == TaskStatus.FAILED
assert 'retries' in results[0].error.lower()
def test_staggered_rollout_success(executor, sample_devices):
"""Test staggered rollout with all stages succeeding"""
deployments = {
device['device_id']: mock_config_template(device)
for device in sample_devices
}
gns3_client = MockGNS3Client()
# Use small stages for 5 devices
stages = [0.2, 0.6, 1.0] # 20%, 60%, 100%
results, progress = executor.staggered_rollout(
deployments=deployments,
gns3_client=gns3_client,
stages=stages
)
# All devices should be deployed
assert len(results) == len(sample_devices)
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == len(sample_devices)
def test_staggered_rollout_failure_stops_deployment(executor):
"""Test staggered rollout stops on high failure rate"""
# Create many devices to test staged rollout
devices = [
{'device_id': f'device-{i}', 'hostname': f'device-{i}'}
for i in range(20)
]
deployments = {
device['device_id']: mock_config_template(device)
for device in devices
}
# Client that always fails
gns3_client = MockGNS3ClientWithRetry(fail_count=999)
stages = [0.1, 0.5, 1.0] # 10%, 50%, 100%
results, progress = executor.staggered_rollout(
deployments=deployments,
gns3_client=gns3_client,
stages=stages,
validation_fn=None
)
# Should stop after first stage fails
    # First stage = 10% of 20 = 2 devices (stage-size arithmetic is sketched after this test)
assert len(results) <= 2
# All should have failed
failed_count = sum(1 for r in results if r.status == TaskStatus.FAILED)
assert failed_count == len(results)
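
# A sketch of the stage-size arithmetic the staggered-rollout tests assume:
# stages are cumulative fractions of the fleet, so 20 devices with stages
# [0.1, 0.5, 1.0] gives a first stage of int(20 * 0.1) = 2 devices, and 5
# devices with [0.2, 0.6, 1.0] gives int(5 * 0.2) = 1. The int() truncation
# here is an assumption for illustration; the executor's exact rounding may differ.
def _stage_device_counts(total_devices: int, stages: List[float]) -> List[int]:
    """Cumulative number of devices targeted by the end of each rollout stage."""
    return [int(total_devices * fraction) for fraction in stages]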
def test_staggered_rollout_with_validation(executor, sample_devices):
"""Test staggered rollout with validation function"""
deployments = {
device['device_id']: mock_config_template(device)
for device in sample_devices
}
gns3_client = MockGNS3Client()
validation_called = []
def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
"""Mock validation that tracks calls"""
validation_called.append(len(device_ids))
# All validations pass
return True
stages = [0.2, 0.6, 1.0]
results, progress = executor.staggered_rollout(
deployments=deployments,
gns3_client=gns3_client,
stages=stages,
validation_fn=validation_fn
)
# All devices deployed
assert len(results) == len(sample_devices)
# Validation called multiple times (once per stage)
assert len(validation_called) >= 2
def test_staggered_rollout_validation_failure_stops(executor, sample_devices):
"""Test staggered rollout stops when validation fails"""
deployments = {
device['device_id']: mock_config_template(device)
for device in sample_devices
}
gns3_client = MockGNS3Client()
def validation_fn(device_ids: List[str], results: List[TaskResult]) -> bool:
"""Validation that always fails"""
return False
stages = [0.2, 0.6, 1.0]
results, progress = executor.staggered_rollout(
deployments=deployments,
gns3_client=gns3_client,
stages=stages,
validation_fn=validation_fn
)
# Should only deploy first stage (20% of 5 = 1 device)
assert len(results) == 1
def test_task_result_serialization():
"""Test TaskResult can be serialized"""
result = TaskResult(
device_id='test-1',
status=TaskStatus.SUCCESS,
result={'config': 'test'},
duration_seconds=1.5
)
assert result.device_id == 'test-1'
assert result.status == TaskStatus.SUCCESS
assert result.duration_seconds == 1.5
assert result.retry_count == 0
def test_large_scale_config_generation(executor):
"""Test config generation scales to hundreds of devices"""
# Create 100 devices
devices = [
{'device_id': f'device-{i:03d}', 'hostname': f'device-{i:03d}', 'role': 'leaf'}
for i in range(100)
]
start_time = time.time()
results, progress = executor.parallel_config_generation(
devices=devices,
template_fn=mock_config_template,
batch_size=50
)
duration = time.time() - start_time
# All should succeed
assert len(results) == 100
success_count = sum(1 for r in results if r.status == TaskStatus.SUCCESS)
assert success_count == 100
    # Should complete reasonably quickly; the 10-second bound is generous and
    # mostly absorbs Ray task-scheduling overhead rather than template work
    assert duration < 10.0
print(f"\nGenerated 100 configs in {duration:.2f} seconds")
print(f"Average: {duration/100*1000:.1f}ms per device")
def test_progress_tracking_time_estimates():
"""Test progress tracking time estimation"""
progress = ExecutionProgress(total_devices=100)
# Simulate some work
time.sleep(0.1)
progress.completed = 25
    # Should have a time estimate (one plausible formula is sketched after this test)
eta = progress.estimated_time_remaining
assert eta is not None
assert eta > 0
# Complete more work
progress.completed = 50
eta2 = progress.estimated_time_remaining
# ETA should decrease
assert eta2 < eta
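
# One plausible ETA formula consistent with the assertions above (positive, and
# decreasing as more devices complete): extrapolate from the average per-device
# time so far. This is an assumption for illustration, not necessarily how
# ExecutionProgress computes estimated_time_remaining.
def _eta_sketch(elapsed_seconds: float, completed: int, total_devices: int) -> float:
    """Estimate remaining seconds from the average per-device duration so far."""
    return (elapsed_seconds / completed) * (total_devices - completed)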
def test_executor_multiple_operations(executor, sample_devices):
"""Test running multiple operations sequentially"""
# Config generation
results1, _ = executor.parallel_config_generation(
devices=sample_devices,
template_fn=mock_config_template
)
# Batfish analysis
configs = {r.device_id: r.result for r in results1 if r.status == TaskStatus.SUCCESS}
results2, _ = executor.parallel_batfish_analysis(
configs=configs,
batfish_client=MockBatfishClient()
)
# Deployment
results3, _ = executor.parallel_deployment(
deployments=configs,
gns3_client=MockGNS3Client()
)
# All operations should succeed
assert all(r.status == TaskStatus.SUCCESS for r in results1)
assert all(r.status == TaskStatus.SUCCESS for r in results2)
assert all(r.status == TaskStatus.SUCCESS for r in results3)