# LifeFlow-AI / src/optimization/solver/ortools_solver.py
# Last commit b7d08cf by Marco310: "buildup agent system"
"""
OR-Tools 求解器 - 封裝 OR-Tools API
完全保留原始 tsptw_solver_old.py 的 OR-Tools 設置邏輯
"""
from typing import List, Dict, Any, Tuple
from datetime import datetime, timedelta
import numpy as np
from ortools.constraint_solver import routing_enums_pb2
from ortools.constraint_solver import pywrapcp
from src.infra.logger import get_logger
from src.optimization.models.internal_models import _Task, _Graph
from src.optimization.graph.time_window_handler import TimeWindowHandler
# Module-level logger, named after this module via the project's get_logger helper.
logger = get_logger(__name__)
class ORToolsSolver:
    """
    OR-Tools solver: thin wrapper around the OR-Tools routing API.

    Responsibilities:
    - Create the RoutingModel and RoutingIndexManager
    - Add the time-dimension constraints
    - Add the priority (optional-visit) disjunctions
    - Run the search
    """

    def __init__(
        self,
        time_limit_seconds: int = 30,
        verbose: bool = False,
    ):
        """
        Args:
            time_limit_seconds: Wall-clock limit passed to the OR-Tools search.
            verbose: When True, log the start and end of each search.
        """
        self.time_limit_seconds = time_limit_seconds
        self.verbose = verbose
        self.tw_handler = TimeWindowHandler()

    def solve(
        self,
        graph: _Graph,
        tasks: List[_Task],
        start_time: datetime,
        deadline: datetime,
        max_wait_time_sec: int,
    ) -> Tuple[pywrapcp.RoutingModel, pywrapcp.RoutingIndexManager, pywrapcp.Assignment]:
        """
        Build and solve the TSPTW model for ``graph``.

        The model has a single vehicle whose route starts and ends at node 0
        (the depot). POI nodes are optional: dropping a task's candidate
        nodes costs a priority-based penalty (see _add_priority_disjunctions).

        Args:
            graph: Graph providing ``node_meta`` and ``duration_matrix``.
            tasks: Tasks referenced by ``node_meta[i]["task_idx"]``.
            start_time: Wall-clock time corresponding to cumul value 0.
            deadline: Wall-clock end of the planning horizon.
            max_wait_time_sec: Maximum slack (waiting time) allowed per node.

        Returns:
            (routing, manager, solution). ``solution`` is the value returned by
            ``RoutingModel.SolveWithParameters`` and is ``None`` when no
            solution was found.
        """
        num_nodes = len(graph.node_meta)

        # 1. Per-node service time (seconds).
        service_time = self._build_service_time_per_node(tasks, graph.node_meta)

        # 2. Index manager: one vehicle, depot at node 0.
        manager = pywrapcp.RoutingIndexManager(num_nodes, 1, 0)

        # 3. Routing model.
        routing = pywrapcp.RoutingModel(manager)

        # 4. Transit callback (travel + service); also used as the arc cost.
        transit_cb_index = self._register_transit_callback(
            routing, manager, graph.duration_matrix, service_time
        )
        routing.SetArcCostEvaluatorOfAllVehicles(transit_cb_index)

        # 5. Time dimension and per-node time windows.
        self._add_time_dimension(
            routing,
            manager,
            transit_cb_index,
            tasks,
            graph.node_meta,
            start_time,
            deadline,
            max_wait_time_sec,
        )

        # 6. Priority-based drop penalties.
        self._add_priority_disjunctions(routing, manager, tasks, graph.node_meta)

        # 7. Search parameters.
        search_parameters = self._create_search_parameters()

        # 8. Solve.
        if self.verbose:
            logger.info(
                "Starting OR-Tools search with time limit = %ds",
                self.time_limit_seconds,
            )
        solution = routing.SolveWithParameters(search_parameters)
        if self.verbose:
            logger.info("OR-Tools search completed")
        if solution is None:
            # Surface failures even when not verbose; callers only get None.
            logger.warning("OR-Tools found no solution (status=%s)", routing.status())

        return routing, manager, solution

    @staticmethod
    def _build_service_time_per_node(
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
    ) -> List[int]:
        """
        Return the service time (seconds) for every node.

        POI nodes take ``service_duration_sec`` from the task they belong to;
        every other node (e.g. the depot) gets 0.
        """
        return [
            tasks[meta["task_idx"]].service_duration_sec if meta["type"] == "poi" else 0
            for meta in node_meta
        ]

    @staticmethod
    def _register_transit_callback(
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        duration_matrix: np.ndarray,
        service_time: List[int],
    ) -> int:
        """
        Register the transit callback and return its index.

        The transit value of an arc from A to B is the travel time A→B plus
        the service time at A, truncated to int (OR-Tools requires integers).
        """

        def transit_callback(from_index: int, to_index: int) -> int:
            # Convert routing indices back to graph node ids.
            from_node = manager.IndexToNode(from_index)
            to_node = manager.IndexToNode(to_index)
            return int(duration_matrix[from_node, to_node] + service_time[from_node])

        return routing.RegisterTransitCallback(transit_callback)

    def _add_time_dimension(
        self,
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        transit_cb_index: int,
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
        start_time: datetime,
        deadline: datetime,
        max_wait_time_sec: int,
    ) -> pywrapcp.RoutingDimension:
        """
        Add the "Time" dimension and apply each POI's effective time window.

        Cumul values are seconds since ``start_time``. The horizon is
        ``deadline - start_time``; if ``deadline`` is None it defaults to
        three days after ``start_time``.

        Returns:
            The "Time" RoutingDimension.
        """
        if deadline is None:
            deadline = start_time + timedelta(days=3)
        horizon_sec = int((deadline - start_time).total_seconds())

        routing.AddDimension(
            transit_cb_index,
            max_wait_time_sec,  # slack: max waiting time at a node
            horizon_sec,  # capacity: max cumul value
            False,  # do not force the start cumul to zero
            "Time",
        )
        time_dimension = routing.GetDimensionOrDie("Time")

        # Depot start: allow departure anywhere in [0, horizon].
        time_dimension.CumulVar(routing.Start(0)).SetRange(0, horizon_sec)

        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue

            index = manager.NodeToIndex(node)
            task = tasks[meta["task_idx"]]

            # Intersect the task's window with the POI's own window.
            start_sec, end_sec = self.tw_handler.compute_effective_time_window(
                task.time_window, meta.get("poi_time_window"), start_time, horizon_sec
            )

            if start_sec > end_sec:
                # No intersection at all.
                logger.warning(
                    "Node(%s) has infeasible time window, forcing tiny 0 range.",
                    meta,
                )
                start_sec = end_sec = 0
                # A [0, 0] range alone does not rule the node out: if it can be
                # reached at cumul 0 (e.g. zero travel time from the depot),
                # the solver could still visit it at the plan start, outside
                # its real window. Force it inactive. Every POI node receives
                # a drop disjunction in _add_priority_disjunctions, so this
                # cannot make the model infeasible by itself.
                routing.solver().Add(routing.ActiveVar(index) == 0)

            time_dimension.CumulVar(index).SetRange(start_sec, end_sec)

        # Depot end: must return within the horizon.
        time_dimension.CumulVar(routing.End(0)).SetRange(0, horizon_sec)

        return time_dimension

    @staticmethod
    def _add_priority_disjunctions(
        routing: pywrapcp.RoutingModel,
        manager: pywrapcp.RoutingIndexManager,
        tasks: List[_Task],
        node_meta: List[Dict[str, Any]],
    ) -> None:
        """
        Make each task's candidate POI nodes optional, with a drop penalty.

        All POI nodes belonging to the same task form one disjunction (OR-Tools'
        default max cardinality of 1). The penalty for skipping the task depends
        on ``task.priority``:
            "HIGH" -> 10_000_000, "MEDIUM" -> 100_000, anything else -> 10_000.
        Penalties are in the same units as the arc cost (transit seconds).
        """
        # Group POI node ids by the task they serve.
        task_nodes: Dict[int, List[int]] = {i: [] for i in range(len(tasks))}
        for node in range(1, len(node_meta)):
            meta = node_meta[node]
            if meta["type"] != "poi":
                continue
            task_nodes[meta["task_idx"]].append(node)

        for task_idx, nodes in task_nodes.items():
            if not nodes:
                continue

            priority = tasks[task_idx].priority
            if priority == "HIGH":
                penalty = 10_000_000
            elif priority == "MEDIUM":
                penalty = 100_000
            else:
                penalty = 10_000

            routing_indices = [manager.NodeToIndex(n) for n in nodes]
            routing.AddDisjunction(routing_indices, penalty)

    def _create_search_parameters(self) -> Any:
        """
        Build the search parameters.

        Uses PATH_CHEAPEST_ARC for the first solution, then GUIDED_LOCAL_SEARCH
        bounded by ``self.time_limit_seconds``.

        Returns:
            A ``RoutingSearchParameters`` protobuf, as produced by
            ``pywrapcp.DefaultRoutingSearchParameters()``.
        """
        search_parameters = pywrapcp.DefaultRoutingSearchParameters()
        search_parameters.first_solution_strategy = (
            routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
        )
        search_parameters.local_search_metaheuristic = (
            routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH
        )
        search_parameters.time_limit.FromSeconds(self.time_limit_seconds)
        return search_parameters