Spaces:
Running
Running
| """ | |
| OR-Tools 求解器 - 封裝 OR-Tools API | |
| 完全保留原始 tsptw_solver_old.py 的 OR-Tools 設置邏輯 | |
| """ | |
| from typing import List, Dict, Any, Tuple | |
| from datetime import datetime, timedelta | |
| import numpy as np | |
| from ortools.constraint_solver import routing_enums_pb2 | |
| from ortools.constraint_solver import pywrapcp | |
| from src.infra.logger import get_logger | |
| from src.optimization.models.internal_models import _Task, _Graph | |
| from src.optimization.graph.time_window_handler import TimeWindowHandler | |
# Module-level logger, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class ORToolsSolver:
    """
    OR-Tools solver: a thin wrapper around the OR-Tools routing API.

    Keeps the OR-Tools setup logic of the original ``tsptw_solver_old.py``.

    Responsibilities:
    - create the RoutingModel and RoutingIndexManager
    - add the time-dimension constraints
    - add the priority (disjunction) constraints
    - run the search
    """

    def __init__(
        self,
        time_limit_seconds: int = 30,
        verbose: bool = False,
    ) -> None:
        """
        Args:
            time_limit_seconds: Time limit handed to the OR-Tools search.
            verbose: If True, log when the search starts and when it finishes.
        """
        self.time_limit_seconds = time_limit_seconds
        self.verbose = verbose
        self.tw_handler = TimeWindowHandler()

    def solve(
        self,
        graph: _Graph,
        tasks: List[_Task],
        start_time: datetime,
        deadline: datetime,
        max_wait_time_sec: int,
    ) -> Tuple[
        pywrapcp.RoutingModel, pywrapcp.RoutingIndexManager, pywrapcp.Assignment
    ]:
        """
        Solve the TSPTW instance described by ``graph`` and ``tasks``.

        Mirrors the OR-Tools setup part of the original ``_solve_internal()``.

        Args:
            graph: Provides ``node_meta`` (one dict per node) and
                ``duration_matrix`` (indexed ``[from_node, to_node]``; assumed
                to be in seconds, like ``service_duration_sec``). Node 0 is
                the start/end depot of the single vehicle.
            tasks: Tasks referenced by ``node_meta[i]["task_idx"]`` on POI nodes.
            start_time: Origin of the time dimension; cumul values are seconds
                after this instant.
            deadline: End of the planning horizon. ``None`` is accepted and
                falls back to ``start_time + 3 days``.
            max_wait_time_sec: Maximum slack (waiting time) allowed at a node.

        Returns:
            ``(routing, manager, solution)``. ``solution`` is whatever
            ``RoutingModel.SolveWithParameters`` returned and may be ``None``
            when OR-Tools finds no solution.
        """
        num_nodes = len(graph.node_meta)

        # 1. Per-node service time.
        service_time = self._build_service_time_per_node(tasks, graph.node_meta)

        # 2. Index manager: one vehicle that starts and ends at node 0.
        manager = pywrapcp.RoutingIndexManager(num_nodes, 1, 0)

        # 3. Routing model.
        routing = pywrapcp.RoutingModel(manager)

        # 4. Transit callback (travel + service at the origin node). It is
        #    also the arc cost, so the objective is total travel + service
        #    time plus any drop penalties.
        transit_cb_index = self._register_transit_callback(
            routing, manager, graph.duration_matrix, service_time
        )
        routing.SetArcCostEvaluatorOfAllVehicles(transit_cb_index)

        # 5. Time dimension with per-node time windows.
        self._add_time_dimension(
            routing,
            manager,
            transit_cb_index,
            tasks,
            graph.node_meta,
            start_time,
            deadline,
            max_wait_time_sec,
        )

        # 6. Priority-weighted disjunctions (tasks may be skipped at a penalty).
        self._add_priority_disjunctions(routing, manager, tasks, graph.node_meta)

        # 7. Search parameters.
        search_parameters = self._create_search_parameters()

        # 8. Solve.
        if self.verbose:
            logger.info(
                "Starting OR-Tools search with time limit = %ds",
                self.time_limit_seconds,
            )
        solution = routing.SolveWithParameters(search_parameters)
        if self.verbose:
            logger.info("OR-Tools search completed")
        return routing, manager, solution

    @staticmethod
    def _build_service_time_per_node(
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
    ) -> List[int]:
        """
        Build the service time of every node.

        POI nodes take ``service_duration_sec`` from their task; every other
        node type gets 0.

        Mirrors the original ``_build_service_time_per_node()``.
        """
        service_time = [0] * len(node_meta)
        for node, meta in enumerate(node_meta):
            if meta["type"] == "poi":
                task_idx = meta["task_idx"]
                task = tasks[task_idx]
                service_time[node] = task.service_duration_sec
        return service_time

    @staticmethod
    def _register_transit_callback(
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        duration_matrix: np.ndarray,
        service_time: List[int],
    ) -> int:
        """
        Register the transit callback on ``routing`` and return its index.

        The callback returns the travel time from ``from_node`` to ``to_node``
        plus the service time spent at ``from_node``, truncated to ``int``.

        Mirrors the original ``_register_transit_callback()``.
        """

        def transit_callback(from_index: int, to_index: int) -> int:
            # OR-Tools passes routing indices; convert back to graph node ids.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            travel = duration_matrix[from_node, to_node]
            service = service_time[from_node]
            return int(travel + service)

        transit_cb_index = routing.RegisterTransitCallback(transit_callback)
        return transit_cb_index

    def _add_time_dimension(
        self,
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        transit_cb_index: int,
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
        start_time: datetime,
        deadline: datetime,
        max_wait_time_sec: int,
    ) -> pywrapcp.RoutingDimension:
        """
        Add the "Time" dimension and apply a time window to every POI node.

        Each POI window is computed by ``TimeWindowHandler`` from the task's
        window, the POI's ``poi_time_window`` (if any), ``start_time`` and the
        horizon. If ``deadline`` is None, the horizon is 3 days.

        Mirrors the original ``_add_time_dimension()``.

        Returns:
            The created time dimension.
        """
        if deadline is None:
            deadline = start_time + timedelta(days=3)
        horizon_sec = int((deadline - start_time).total_seconds())

        routing.AddDimension(
            transit_cb_index,
            max_wait_time_sec,  # slack_max: allowed waiting time at a node
            horizon_sec,  # capacity: upper bound on any cumul value
            False,  # do not force the start cumul to zero
            "Time",
        )
        time_dimension = routing.GetDimensionOrDie("Time")

        # Depot start: departure allowed anywhere within [0, horizon].
        start_index = routing.Start(0)
        time_dimension.CumulVar(start_index).SetRange(0, horizon_sec)

        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue
            index = manager.NodeToIndex(node)
            task_idx = meta["task_idx"]
            task = tasks[task_idx]
            poi_tw = meta.get("poi_time_window")
            task_tw = task.time_window

            # Effective window = overlap of the task and POI windows.
            start_sec, end_sec = self.tw_handler.compute_effective_time_window(
                task_tw, poi_tw, start_time, horizon_sec
            )
            if start_sec > end_sec:
                # No overlap at all -> pin the window to [0, 0] so the solver
                # avoids the node (it can still be skipped via its disjunction).
                # NOTE(review): a visit at cumul 0 remains feasible if the
                # transit from the depot is 0; routing.ActiveVar(index)
                # .SetValue(0) would force the drop explicitly.
                logger.warning(
                    "Node(%s) has infeasible time window, forcing tiny 0 range.",
                    meta,
                )
                start_sec = end_sec = 0
            time_dimension.CumulVar(index).SetRange(start_sec, end_sec)

        # Depot end: the vehicle must be back within the horizon.
        end_index = routing.End(0)
        time_dimension.CumulVar(end_index).SetRange(0, horizon_sec)
        return time_dimension

    @staticmethod
    def _add_priority_disjunctions(
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
    ) -> None:
        """
        Add one disjunction per task over that task's POI nodes.

        At most one of a task's nodes is visited (``AddDisjunction`` default
        cardinality). Skipping the task entirely costs a penalty that depends
        on its priority: HIGH 10,000,000, MEDIUM 100,000, anything else 10,000.
        Tasks without any POI node get no disjunction.

        Mirrors the original ``_add_priority_disjunctions()``.
        """
        # Group POI nodes by the task they belong to (node 0 is the depot).
        task_nodes: Dict[int, List[int]] = {i: [] for i in range(len(tasks))}
        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue
            task_idx = meta["task_idx"]
            task_nodes[task_idx].append(node)

        for task_idx, nodes in task_nodes.items():
            if not nodes:
                continue
            task = tasks[task_idx]
            priority = task.priority

            # Penalty for dropping the task, by priority.
            if priority == "HIGH":
                penalty = 10_000_000
            elif priority == "MEDIUM":
                penalty = 100_000
            else:
                penalty = 10_000

            routing_indices = [manager.NodeToIndex(n) for n in nodes]
            routing.AddDisjunction(routing_indices, penalty)

    def _create_search_parameters(self) -> Any:
        """
        Create the search parameters.

        Returns the protobuf message produced by
        ``pywrapcp.DefaultRoutingSearchParameters()``, configured with:
        PATH_CHEAPEST_ARC as the first-solution strategy, GUIDED_LOCAL_SEARCH
        as the metaheuristic, and ``self.time_limit_seconds`` as the time limit.
        """
        search_parameters = pywrapcp.DefaultRoutingSearchParameters()
        search_parameters.first_solution_strategy = (
            routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
        )
        search_parameters.local_search_metaheuristic = (
            routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
        )
        search_parameters.time_limit.FromSeconds(self.time_limit_seconds)
        return search_parameters