"""
Chronos 2 Model Service

Handles model loading, caching, and inference using Chronos2Pipeline.
"""

import logging
import time
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline, Chronos2Pipeline

from config.constants import CHRONOS2_MODEL, CONFIDENCE_LEVELS
from config.settings import CONFIG, DEVICE, MODEL_CONFIG

logger = logging.getLogger(__name__)


class ChronosModelService:
    """
    Service for managing the Chronos 2 model lifecycle and inference.

    Uses Chronos2Pipeline with its DataFrame-based API.
    """

    def __init__(self):
        self.model = None
        self.device = None
        self.model_variant = None
        self.is_loaded = False
        self.load_time = None
        self.is_chronos2 = False

    def _get_device(self) -> str:
        """Determine the best available device."""
        if DEVICE == 'cuda':
            if not torch.cuda.is_available():
                logger.warning("CUDA requested but not available, falling back to CPU")
                return 'cpu'
            return 'cuda'
        elif DEVICE == 'cpu':
            return 'cpu'
        else:
            # Any other setting: prefer CUDA when available
            return 'cuda' if torch.cuda.is_available() else 'cpu'

    def load_model(self) -> Dict[str, Any]:
        """
        Load the Chronos 2 model at startup.

        Returns:
            Dictionary with loading status and metadata
        """
        try:
            start_time = time.time()
            logger.info("Loading Chronos 2 model from Hugging Face (paper 2510.15821)")

            model_path = CHRONOS2_MODEL
            self.model_variant = 'chronos-2'

            self.device = self._get_device()
            logger.info(f"Using device: {self.device}")

            self.model = Chronos2Pipeline.from_pretrained(
                model_path,
                device_map=self.device,
                torch_dtype=torch.bfloat16 if self.device == 'cuda' else torch.float32,
            )
            self.is_chronos2 = True

            self.load_time = time.time() - start_time
            self.is_loaded = True

            logger.info(f"Model loaded successfully in {self.load_time:.2f}s")

            if MODEL_CONFIG['warmup_enabled']:
                self._warmup()

            return {
                'status': 'success',
                'model': 'chronos-2',
                'device': self.device,
                'load_time': self.load_time,
                'model_name': model_path
            }

        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            self.is_loaded = False
            return {
                'status': 'error',
                'error': str(e)
            }

    def _warmup(self):
        """Run a warmup prediction to initialize the model."""
        try:
            logger.info("Running warmup prediction")

            # Synthetic single-series frame in the same layout predict() expects
            warmup_data = pd.DataFrame({
                'id': ['warmup'] * MODEL_CONFIG['warmup_length'],
                'timestamp': pd.date_range('2020-01-01', periods=MODEL_CONFIG['warmup_length'], freq='D'),
                'target': np.random.randn(MODEL_CONFIG['warmup_length'])
            })

            self.predict(
                warmup_data,
                horizon=MODEL_CONFIG['warmup_horizon'],
                confidence_levels=[80]
            )
            logger.info("Warmup completed successfully")

        except Exception as e:
            logger.warning(f"Warmup failed: {str(e)}")

    def predict(
        self,
        data: pd.DataFrame,
        horizon: int,
        confidence_levels: Optional[List[int]] = None,
        future_df: Optional[pd.DataFrame] = None
    ) -> Dict[str, Any]:
        """
        Generate forecasts using the Chronos 2 model with the DataFrame API.

        Args:
            data: DataFrame with columns ['id', 'timestamp', 'target'].
                Can also include covariate columns for multivariate forecasting.
            horizon: Number of periods to forecast
            confidence_levels: List of confidence levels (e.g., [80, 90, 95])
            future_df: Optional DataFrame with future covariate values

        Returns:
            Dictionary with predictions and metadata
        """
        logger.info("=" * 80)
        logger.info("MODEL SERVICE: predict() - ENTRY")
        logger.info(f"Data shape: {data.shape}")
        logger.info(f"Data columns: {data.columns.tolist()}")
        logger.info(f"Horizon: {horizon}")
        logger.info(f"Confidence levels: {confidence_levels}")
        logger.info(f"Is loaded: {self.is_loaded}")
        logger.info("=" * 80)

        if not self.is_loaded:
            logger.error("Model not loaded!")
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            start_time = time.time()
            logger.info("Starting prediction...")

            if confidence_levels is None:
                confidence_levels = CONFIDENCE_LEVELS

            # Map each confidence level to a symmetric pair of quantiles,
            # e.g. 80% -> (0.10, 0.90), 95% -> (0.025, 0.975)
            quantile_levels = []
            for cl in sorted(confidence_levels):
                lower = (100 - cl) / 200
                upper = 1 - lower
                quantile_levels.extend([lower, upper])

            # Always include the median, which serves as the point forecast
            quantile_levels.append(0.5)
            quantile_levels = sorted(set(quantile_levels))

            logger.info(f"Generating forecast for horizon={horizon}, quantiles={quantile_levels}")

            required_cols = ['id', 'timestamp', 'target']
            logger.info(f"Checking for required columns: {required_cols}")
            if not all(col in data.columns for col in required_cols):
                error_msg = f"Data must contain columns: {required_cols}, but got: {data.columns.tolist()}"
                logger.error(error_msg)
                raise ValueError(error_msg)
            logger.info("All required columns present")

            if self.is_chronos2:
                logger.info("Using Chronos2Pipeline.predict_df() method")
                logger.info(f"Calling predict_df with prediction_length={horizon}, quantile_levels={quantile_levels}")

                pred_df = self.model.predict_df(
                    df=data,
                    future_df=future_df,
                    prediction_length=horizon,
                    quantile_levels=quantile_levels,
                    id_column='id',
                    timestamp_column='timestamp',
                    target='target'
                )
                logger.info(f"predict_df completed - result shape: {pred_df.shape}")
            else:
                # Legacy single-series fallback using the tensor-based ChronosPipeline API
                context_tensor = torch.tensor(data['target'].values, dtype=torch.float32).unsqueeze(0)

                forecast_tensors = self.model.predict(
                    context=context_tensor,
                    prediction_length=horizon,
                    num_samples=20,
                    limit_prediction_length=False
                )

                # Collapse the sample dimension into the requested quantiles
                quantiles_np = np.quantile(
                    forecast_tensors.squeeze(0).cpu().numpy(),
                    q=quantile_levels,
                    axis=0
                )

                # Build future timestamps from the inferred (or daily) frequency
                last_timestamp = pd.to_datetime(data['timestamp'].iloc[-1])
                freq = pd.infer_freq(pd.to_datetime(data['timestamp']))
                if freq is None:
                    freq = 'D'

                future_timestamps = pd.date_range(
                    start=last_timestamp,
                    periods=horizon + 1,
                    freq=freq
                )[1:]

                pred_df = pd.DataFrame({
                    'id': [data['id'].iloc[0]] * horizon,
                    'timestamp': future_timestamps
                })

                # Name the median column '0.5' and the remaining quantile columns with
                # two decimals ('0.10', '0.90', ...) so the extraction below finds them
                for i, q in enumerate(quantile_levels):
                    col = '0.5' if q == 0.5 else f'{q:.2f}'
                    pred_df[col] = quantiles_np[i, :]

            # Extract the first series (the service currently returns a single forecast)
            series_ids = pred_df['id'].unique()
            if len(series_ids) > 0:
                series_pred = pred_df[pred_df['id'] == series_ids[0]].copy()
            else:
                series_pred = pred_df.copy()

            # Median quantile is the point forecast
            forecast_df = pd.DataFrame({
                'ds': series_pred['timestamp'],
                'forecast': series_pred['0.5']
            })

            # Attach prediction-interval bounds for each requested confidence level
            for cl in confidence_levels:
                lower = (100 - cl) / 200
                upper = 1 - lower

                lower_col = f'{lower:.2f}'
                upper_col = f'{upper:.2f}'

                if lower_col in series_pred.columns:
                    forecast_df[f'lower_{cl}'] = series_pred[lower_col].values
                if upper_col in series_pred.columns:
                    forecast_df[f'upper_{cl}'] = series_pred[upper_col].values

            inference_time = time.time() - start_time

            logger.info(f"Forecast generated successfully in {inference_time:.2f}s")
            logger.info(f"Returning forecast DataFrame with {len(forecast_df)} rows")
            logger.info("MODEL SERVICE: predict() - EXIT (success)")
            logger.info("=" * 80)

            return {
                'status': 'success',
                'forecast': forecast_df,
                'inference_time': inference_time,
                'horizon': horizon,
                'confidence_levels': confidence_levels,
                'full_prediction': pred_df
            }

        except Exception as e:
            logger.error(f"EXCEPTION in predict(): {str(e)}", exc_info=True)
            logger.info("MODEL SERVICE: predict() - EXIT (exception)")
            logger.info("=" * 80)
            return {
                'status': 'error',
                'error': str(e)
            }
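
    # Covariate usage sketch (illustrative only, not exercised by this module):
    # predict() passes any extra columns in `data` straight through to predict_df,
    # and known-future covariate values go in `future_df`. The 'promo' column and
    # the variables below are hypothetical examples, not part of the service API.
    #
    #   history = pd.DataFrame({
    #       'id': ['store_1'] * 60,
    #       'timestamp': pd.date_range('2024-01-01', periods=60, freq='D'),
    #       'target': observed_sales,      # past values of the series
    #       'promo': past_promo_flags,     # covariate observed over the same window
    #   })
    #   future = pd.DataFrame({
    #       'id': ['store_1'] * 14,
    #       'timestamp': pd.date_range('2024-03-01', periods=14, freq='D'),
    #       'promo': planned_promo_flags,  # covariate values for the forecast window
    #   })
    #   result = model_service.predict(history, horizon=14, future_df=future)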

    def backtest(
        self,
        data: pd.DataFrame,
        test_size: int,
        forecast_horizon: int,
        confidence_levels: Optional[List[int]] = None
    ) -> Dict[str, Any]:
        """
        Perform backtesting on historical data to evaluate model performance.

        Args:
            data: DataFrame with columns ['id', 'timestamp', 'target']
            test_size: Number of periods held out for testing
            forecast_horizon: Used only to validate that enough training data is
                available; the backtest itself issues a single forecast covering
                the full test window (horizon=test_size)
            confidence_levels: List of confidence levels

        Returns:
            Dictionary with backtest results including predictions vs actuals
        """
        logger.info("=" * 80)
        logger.info("MODEL SERVICE: backtest() - ENTRY")
        logger.info(f"Data shape: {data.shape}")
        logger.info(f"Test size: {test_size}")
        logger.info(f"Forecast horizon: {forecast_horizon}")
        logger.info("=" * 80)

        if not self.is_loaded:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            start_time = time.time()

            # Hold out the last test_size points; keep the rest as training context
            train_size = len(data) - test_size
            if train_size < forecast_horizon * 2:
                raise ValueError(f"Insufficient training data. Need at least {forecast_horizon * 2} points.")

            train_data = data.iloc[:train_size].copy()
            test_data = data.iloc[train_size:].copy()

            logger.info(f"Train size: {len(train_data)}, Test size: {len(test_data)}")

            # Forecast the entire held-out window in one shot
            forecast_result = self.predict(
                data=train_data,
                horizon=test_size,
                confidence_levels=confidence_levels
            )

            if forecast_result['status'] == 'error':
                return forecast_result

            forecast_df = forecast_result['forecast']

            # Align predictions with the held-out actuals
            backtest_df = pd.DataFrame({
                'timestamp': test_data['timestamp'].values,
                'actual': test_data['target'].values,
                'predicted': forecast_df['forecast'].values[:len(test_data)]
            })

            # Carry over any prediction-interval columns produced by predict()
            for cl in (confidence_levels or []):
                lower_col = f'lower_{cl}'
                upper_col = f'upper_{cl}'
                if lower_col in forecast_df.columns:
                    backtest_df[lower_col] = forecast_df[lower_col].values[:len(test_data)]
                if upper_col in forecast_df.columns:
                    backtest_df[upper_col] = forecast_df[upper_col].values[:len(test_data)]

            actual = backtest_df['actual'].values
            predicted = backtest_df['predicted'].values

            # Drop rows where either side is NaN before computing metrics
            mask = ~(np.isnan(actual) | np.isnan(predicted))
            actual = actual[mask]
            predicted = predicted[mask]

            if len(actual) == 0:
                raise ValueError("No valid data points for metric calculation")

            mae = np.mean(np.abs(actual - predicted))
            rmse = np.sqrt(np.mean((actual - predicted) ** 2))
            mape = np.mean(np.abs((actual - predicted) / (actual + 1e-10))) * 100

            # Coefficient of determination (R^2)
            ss_res = np.sum((actual - predicted) ** 2)
            ss_tot = np.sum((actual - np.mean(actual)) ** 2)
            r2 = 1 - (ss_res / (ss_tot + 1e-10))

            metrics = {
                'MAE': float(mae),
                'RMSE': float(rmse),
                'MAPE': float(mape),
                'R2': float(r2)
            }

            inference_time = time.time() - start_time

            logger.info(f"Backtest completed in {inference_time:.2f}s")
            logger.info(f"Metrics: MAE={mae:.2f}, RMSE={rmse:.2f}, MAPE={mape:.2f}%, R2={r2:.4f}")
            logger.info("MODEL SERVICE: backtest() - EXIT (success)")
            logger.info("=" * 80)

            return {
                'status': 'success',
                'backtest_data': backtest_df,
                'metrics': metrics,
                'inference_time': inference_time,
                'train_size': train_size,
                'test_size': test_size
            }

        except Exception as e:
            logger.error(f"EXCEPTION in backtest(): {str(e)}", exc_info=True)
            logger.info("MODEL SERVICE: backtest() - EXIT (exception)")
            logger.info("=" * 80)
            return {
                'status': 'error',
                'error': str(e)
            }

    def get_status(self) -> Dict[str, Any]:
        """Get the current model status."""
        return {
            'is_loaded': self.is_loaded,
            'variant': self.model_variant,
            'device': self.device,
            'load_time': self.load_time
        }


# Module-level singleton instance
model_service = ChronosModelService()
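

if __name__ == "__main__":
    # Minimal usage sketch for manual smoke-testing. Illustrative only: it assumes the
    # Chronos 2 weights referenced by CHRONOS2_MODEL can be downloaded (or are cached)
    # and that the config modules are importable from the working directory.
    logging.basicConfig(level=logging.INFO)

    status = model_service.load_model()
    print(f"load_model() -> {status['status']}")

    if status['status'] == 'success':
        # Synthetic daily series in the ['id', 'timestamp', 'target'] layout predict() expects
        demo = pd.DataFrame({
            'id': ['demo'] * 120,
            'timestamp': pd.date_range('2024-01-01', periods=120, freq='D'),
            'target': np.sin(np.linspace(0, 12, 120)) * 10 + 100,
        })

        result = model_service.predict(demo, horizon=14, confidence_levels=[80, 95])
        if result['status'] == 'success':
            print(result['forecast'].head())
        print(model_service.get_status())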