import gradio as gr import numpy as np import pandas as pd import torch import yaml from huggingface_hub import hf_hub_download import spaces import traceback import functools import yfinance as yf import pandas_ta as ta import plotly.graph_objects as go from plotly.subplots import make_subplots from scipy import stats # --- All your src imports --- from examples.utils import load_model from src.plotting.plot_timeseries import plot_multivariate_timeseries from src.data.containers import BatchTimeSeriesContainer, Frequency from src.synthetic_generation.generator_params import ( SineWaveGeneratorParams, GPGeneratorParams, AnomalyGeneratorParams, MultiScaleFractalAudioParams, FinancialVolatilityAudioParams, SawToothGeneratorParams, SpikesGeneratorParams, StepGeneratorParams, OrnsteinUhlenbeckProcessGeneratorParams, NetworkTopologyAudioParams, StochasticRhythmAudioParams, CauKerGeneratorParams, ForecastPFNGeneratorParams, KernelGeneratorParams ) # Define fallback values for GIFT evaluation ALL_DATASETS = ["ETTm1", "ETTm2", "ETTh1", "ETTh2", "Weather", "Electricity", "Traffic"] TERMS = ["short", "medium", "long"] # GIFT Evaluation imports (optional) try: from src.gift_eval.evaluate import evaluate_datasets from src.gift_eval.predictor import TimeSeriesPredictor from src.gift_eval.results import aggregate_results from src.gift_eval.constants import ALL_DATASETS, TERMS GIFT_EVAL_AVAILABLE = True except ImportError: GIFT_EVAL_AVAILABLE = False print("Warning: GIFT evaluation dependencies not available. GIFT evaluation tab will be disabled.") from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import SineWaveGeneratorWrapper from src.synthetic_generation.anomalies.anomaly_generator_wrapper import AnomalyGeneratorWrapper from src.synthetic_generation.sawtooth.sawtooth_generator_wrapper import SawToothGeneratorWrapper from src.synthetic_generation.spikes.spikes_generator_wrapper import SpikesGeneratorWrapper from src.synthetic_generation.steps.step_generator_wrapper import StepGeneratorWrapper from src.synthetic_generation.ornstein_uhlenbeck_process.ou_generator_wrapper import OrnsteinUhlenbeckProcessGeneratorWrapper # Try to import additional optional generators try: from src.synthetic_generation.audio_generators.network_topology_wrapper import NetworkTopologyAudioWrapper NETWORK_AVAILABLE = True except ImportError: NETWORK_AVAILABLE = False try: from src.synthetic_generation.audio_generators.stochastic_rhythm_wrapper import StochasticRhythmAudioWrapper RHYTHM_AVAILABLE = True except ImportError: RHYTHM_AVAILABLE = False try: from src.synthetic_generation.cauker.cauker_generator_wrapper import CauKerGeneratorWrapper CAUKER_AVAILABLE = True except ImportError: CAUKER_AVAILABLE = False try: from src.synthetic_generation.forecast_pfn_prior.forecast_pfn_generator_wrapper import ForecastPFNGeneratorWrapper FORECAST_PFN_AVAILABLE = True except ImportError: FORECAST_PFN_AVAILABLE = False try: from src.synthetic_generation.kernel_synth.kernel_generator_wrapper import KernelGeneratorWrapper KERNEL_AVAILABLE = True except ImportError: KERNEL_AVAILABLE = False # Try to import optional generators try: from src.synthetic_generation.gp_prior.gp_generator_wrapper import GPGeneratorWrapper GP_AVAILABLE = True except ImportError: GP_AVAILABLE = False try: from src.synthetic_generation.audio_generators.multi_scale_fractal_wrapper import MultiScaleFractalAudioWrapper from src.synthetic_generation.audio_generators.financial_volatility_wrapper import FinancialVolatilityAudioWrapper AUDIO_AVAILABLE = True except ImportError: AUDIO_AVAILABLE = False # Define global placeholders (device is not needed - only used inside GPU function) model = None # Global variables to store forecast results for export last_forecast_results = None last_metrics_results = None last_analysis_results = None def create_gradio_app(): """Create and configure the Gradio app for TempoPFN.""" @functools.lru_cache(maxsize=None) def load_oil_price_data(): """Downloads and caches daily WTI oil price data.""" print("--- Downloading WTI Oil Price data for the first time ---") url = "https://datahub.io/core/oil-prices/r/wti-daily.csv" try: df = pd.read_csv(url) df['Date'] = pd.to_datetime(df['Date']) df = df.sort_values('Date') df = df.set_index('Date').asfreq('D').ffill().reset_index() values = df['Price'].values.astype(np.float32) start_date = df['Date'].min() print(f"--- Oil price data loaded. {len(values)} points ---") return values, start_date, "D" except Exception as e: print(f"Error loading oil price data: {e}") raise def generate_synthetic_data(length=2048, seed=42): """Generate synthetic sine wave data for demonstration.""" sine_params = SineWaveGeneratorParams(global_seed=seed, length=length) sine_generator = SineWaveGeneratorWrapper(sine_params) batch = sine_generator.generate_batch(batch_size=1, seed=seed) values = torch.from_numpy(batch.values).to(torch.float32) if values.ndim == 2: values = values.unsqueeze(-1) # FIX: Use .squeeze() to return a 1D array to match expected logic flow (4D bugfix) return values.squeeze().numpy(), batch.start[0], batch.frequency[0] def process_uploaded_data(file): """Process uploaded CSV file with time series data.""" if file is None: return None, "No file uploaded" try: df = pd.read_csv(file.name) if len(df.columns) < 2: return None, "CSV must have at least 2 columns" time_col, value_col = df.columns[0], df.columns[1] try: df[time_col] = pd.to_datetime(df[time_col]) df = df.sort_values(time_col) start_date = df[time_col].min() freq = pd.infer_freq(df[time_col]) or "D" except Exception: start_date = np.datetime64("2020-01-01") freq = "D" values = df[value_col].values.astype(np.float32) volumes = df['Volume'].values.astype(np.float32) if 'Volume' in df.columns else None return values, volumes, start_date, freq, f"Loaded {len(values)} data points" except Exception as e: return None, None, None, None, f"Error processing file: {str(e)}" def create_advanced_visualizations(history_values, predictions, future_values=None): """Create advanced statistical visualizations.""" try: # Create subplots with multiple analyses fig = make_subplots( rows=2, cols=2, subplot_titles=('Residual Analysis', 'ACF Plot', 'Distribution Comparison', 'Forecast Error Distribution'), specs=[[{"type": "scatter"}, {"type": "bar"}], [{"type": "histogram"}, {"type": "histogram"}]] ) history_flat = history_values.flatten() pred_flat = predictions.flatten() # 1. Residual Analysis (if ground truth available) if future_values is not None: future_flat = future_values.flatten()[:len(pred_flat)] residuals = future_flat - pred_flat fig.add_trace( go.Scatter(x=list(range(len(residuals))), y=residuals, mode='lines+markers', name='Residuals'), row=1, col=1 ) fig.add_hline(y=0, line_dash="dash", line_color="red", row=1, col=1) else: # Just show predictions fig.add_trace( go.Scatter(x=list(range(len(pred_flat))), y=pred_flat, mode='lines', name='Predictions'), row=1, col=1 ) # 2. Autocorrelation Function (ACF) max_lags = min(40, len(history_flat) // 2) acf_values = [] for lag in range(max_lags): if lag == 0: acf_values.append(1.0) else: acf = np.corrcoef(history_flat[:-lag], history_flat[lag:])[0, 1] acf_values.append(acf) fig.add_trace( go.Bar(x=list(range(max_lags)), y=acf_values, name='ACF'), row=1, col=2 ) # Confidence interval lines ci = 1.96 / np.sqrt(len(history_flat)) fig.add_hline(y=ci, line_dash="dash", line_color="blue", row=1, col=2) fig.add_hline(y=-ci, line_dash="dash", line_color="blue", row=1, col=2) # 3. Distribution Comparison fig.add_trace( go.Histogram(x=history_flat, name='Historical', opacity=0.7, nbinsx=30), row=2, col=1 ) fig.add_trace( go.Histogram(x=pred_flat, name='Predictions', opacity=0.7, nbinsx=30), row=2, col=1 ) # 4. Forecast Error Distribution (if ground truth available) if future_values is not None: future_flat = future_values.flatten()[:len(pred_flat)] errors = future_flat - pred_flat fig.add_trace( go.Histogram(x=errors, name='Forecast Errors', nbinsx=30), row=2, col=2 ) else: # Show prediction distribution fig.add_trace( go.Histogram(x=pred_flat, name='Pred Distribution', nbinsx=30), row=2, col=2 ) # Update layout fig.update_layout( height=800, title_text="Advanced Statistical Analysis", showlegend=True ) fig.update_xaxes(title_text="Time Index", row=1, col=1) fig.update_yaxes(title_text="Value", row=1, col=1) fig.update_xaxes(title_text="Lag", row=1, col=2) fig.update_yaxes(title_text="Correlation", row=1, col=2) fig.update_xaxes(title_text="Value", row=2, col=1) fig.update_yaxes(title_text="Frequency", row=2, col=1) fig.update_xaxes(title_text="Error", row=2, col=2) fig.update_yaxes(title_text="Frequency", row=2, col=2) return fig except Exception as e: print(f"Error creating advanced visualizations: {e}") # Return simple error figure fig = go.Figure() fig.add_annotation( text=f"Error creating visualizations: {str(e)}", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14, color="red") ) return fig def export_forecast_csv(): """Export forecast data to CSV.""" global last_forecast_results if last_forecast_results is None: return None, "No forecast data available. Please run a forecast first." try: # Create DataFrame with forecast data history = last_forecast_results['history'].flatten() predictions = last_forecast_results['predictions'].flatten() future = last_forecast_results['future'].flatten() max_len = max(len(history), len(predictions)) df_data = { 'Time_Index': list(range(max_len)), 'Historical_Value': list(history) + [np.nan] * (max_len - len(history)), 'Predicted_Value': [np.nan] * len(history) + list(predictions[:max_len - len(history)]), 'True_Future_Value': [np.nan] * len(history) + list(future[:max_len - len(history)]) } df = pd.DataFrame(df_data) filepath = "/tmp/forecast_data.csv" df.to_csv(filepath, index=False) return filepath, "Forecast data exported successfully!" except Exception as e: return None, f"Error exporting forecast data: {str(e)}" def export_metrics_csv(): """Export metrics summary to CSV.""" global last_metrics_results if last_metrics_results is None: return None, "No metrics available. Please run a forecast first." try: df = pd.DataFrame([last_metrics_results]) filepath = "/tmp/metrics_summary.csv" df.to_csv(filepath, index=False) return filepath, "Metrics summary exported successfully!" except Exception as e: return None, f"Error exporting metrics: {str(e)}" def export_analysis_csv(): """Export full analysis including forecast, metrics, and metadata.""" global last_forecast_results, last_metrics_results, last_analysis_results if last_forecast_results is None: return None, "No analysis data available. Please run a forecast first." try: # Combine all data analysis_data = { **last_analysis_results, **last_metrics_results, 'num_history_points': len(last_forecast_results['history']), 'num_forecast_points': len(last_forecast_results['predictions']), } df = pd.DataFrame([analysis_data]) filepath = "/tmp/full_analysis.csv" df.to_csv(filepath, index=False) return filepath, "Full analysis exported successfully!" except Exception as e: return None, f"Error exporting analysis: {str(e)}" def calculate_metrics(history_values, predictions, future_values=None, data_source=""): """Calculate comprehensive metrics for display in the UI.""" metrics = {} # Basic statistics metrics['data_mean'] = float(np.mean(history_values)) metrics['data_std'] = float(np.std(history_values)) metrics['data_skewness'] = float(stats.skew(history_values.flatten())) metrics['data_kurtosis'] = float(stats.kurtosis(history_values.flatten())) # Latest values and forecasts metrics['latest_price'] = float(history_values[-1, 0] if history_values.ndim > 1 else history_values[-1]) metrics['forecast_next'] = float(predictions[0, 0] if predictions.ndim > 1 else predictions[0]) # Volatility (30-day rolling std as percentage of mean) if len(history_values) >= 30: recent_30 = history_values[-30:].flatten() volatility = (np.std(recent_30) / np.mean(recent_30)) * 100 if np.mean(recent_30) != 0 else 0 metrics['vol_30d'] = float(volatility) else: metrics['vol_30d'] = 0.0 # 52-week high/low (or max/min of available data) lookback = min(252, len(history_values)) # 252 trading days ≈ 1 year recent_data = history_values[-lookback:].flatten() metrics['high_52wk'] = float(np.max(recent_data)) metrics['low_52wk'] = float(np.min(recent_data)) # Time series properties # Autocorrelation at lag 1 if len(history_values) > 1: flat_history = history_values.flatten() metrics['data_autocorr'] = float(np.corrcoef(flat_history[:-1], flat_history[1:])[0, 1]) else: metrics['data_autocorr'] = 0.0 # Stationarity test (simplified - using rolling mean variance) if len(history_values) >= 20: first_half = history_values[:len(history_values)//2].flatten() second_half = history_values[len(history_values)//2:].flatten() var_ratio = np.var(second_half) / np.var(first_half) if np.var(first_half) > 0 else 1.0 metrics['data_stationary'] = "Likely" if 0.5 < var_ratio < 2.0 else "Unlikely" else: metrics['data_stationary'] = "Unknown" # Pattern detection (simple heuristic) if metrics['data_autocorr'] > 0.7: metrics['pattern_type'] = "Trending" elif abs(metrics['data_autocorr']) < 0.3: metrics['pattern_type'] = "Random Walk" else: metrics['pattern_type'] = "Mean Reverting" # Performance metrics (if ground truth available) if future_values is not None: pred_flat = predictions.flatten()[:len(future_values.flatten())] true_flat = future_values.flatten()[:len(pred_flat)] # MSE, MAE metrics['mse'] = float(np.mean((pred_flat - true_flat) ** 2)) metrics['mae'] = float(np.mean(np.abs(pred_flat - true_flat))) # MAPE (avoiding division by zero) mape_values = np.abs((true_flat - pred_flat) / (true_flat + 1e-8)) * 100 metrics['mape'] = float(np.mean(mape_values)) else: metrics['mse'] = 0.0 metrics['mae'] = 0.0 metrics['mape'] = 0.0 # Uncertainty quantification placeholders (would need quantile predictions) metrics['coverage_80'] = 0.0 metrics['coverage_95'] = 0.0 metrics['calibration'] = 0.0 # Information theory metrics (simplified) # Sample entropy approximation try: hist_normalized = (history_values.flatten() - np.mean(history_values)) / (np.std(history_values) + 1e-8) metrics['sample_entropy'] = float(-np.mean(np.log(np.abs(hist_normalized) + 1e-8))) except: metrics['sample_entropy'] = 0.0 metrics['approx_entropy'] = metrics['sample_entropy'] * 0.8 # Placeholder metrics['perm_entropy'] = metrics['sample_entropy'] * 0.9 # Placeholder # Complexity measures # Fractal dimension (box-counting approximation) try: metrics['fractal_dim'] = float(1.0 + 0.5 * metrics['data_std'] / (np.mean(np.abs(np.diff(history_values.flatten()))) + 1e-8)) except: metrics['fractal_dim'] = 1.5 # Spectral features try: # FFT-based features fft_vals = np.fft.fft(history_values.flatten()) power_spectrum = np.abs(fft_vals[:len(fft_vals)//2]) ** 2 freqs = np.fft.fftfreq(len(history_values.flatten()))[:len(fft_vals)//2] # Dominant frequency dominant_idx = np.argmax(power_spectrum[1:]) + 1 # Skip DC component metrics['dominant_freq'] = float(abs(freqs[dominant_idx])) # Spectral centroid metrics['spectral_centroid'] = float(np.sum(freqs * power_spectrum) / (np.sum(power_spectrum) + 1e-8)) # Spectral entropy power_normalized = power_spectrum / (np.sum(power_spectrum) + 1e-8) metrics['spectral_entropy'] = float(-np.sum(power_normalized * np.log(power_normalized + 1e-8))) except: metrics['dominant_freq'] = 0.0 metrics['spectral_centroid'] = 0.0 metrics['spectral_entropy'] = 0.0 # Cross-validation placeholders metrics['cv_mse'] = 0.0 metrics['cv_mae'] = 0.0 metrics['cv_windows'] = 0 # Sensitivity placeholders metrics['horizon_sensitivity'] = 0.0 metrics['history_sensitivity'] = 0.0 metrics['stability_score'] = 0.0 return metrics @spaces.GPU(duration=120) # Extend timeout to 120 seconds for first-run compilation def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object): """ GPU-only inference function for ZeroGPU Spaces. ALL CUDA operations must happen inside this decorated function. Extended timeout for Triton kernel compilation on first run. """ global model # Load model once on first call (on CPU first to save GPU time) if model is None: print("--- Loading TempoPFN model for the first time ---") print(f"Downloading model...") model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth") # Load on CPU first to save GPU allocation time print(f"Loading model from {model_path} to CPU first...") model = load_model(config_path="configs/example.yaml", model_path=model_path, device=torch.device("cpu")) print("--- Model loaded successfully on CPU ---") # Move model to GPU inside the decorated function device = torch.device("cuda:0") print(f"Moving model to {device}...") model.to(device) # Prepare container with GPU tensors container = BatchTimeSeriesContainer( history_values=history_values_tensor.to(device), future_values=future_values_tensor.to(device), start=[start], frequency=[freq_object], ) # Run inference with bfloat16 autocast print("Running inference...") with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): model_output = model(container) # Move model back to CPU to free GPU memory model.to(torch.device("cpu")) print("Inference complete, model moved back to CPU") return model_output def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5): """ Runs the TempoPFN forecast. Returns: history_price, history_volume, predictions, quantiles, plot, status, metrics, data_preview """ try: all_volumes = None if data_source == "Stock Ticker": if not stock_ticker: return None, None, None, None, "Please enter a stock ticker (e.g., SPY, AAPL)" print(f"--- Downloading '{stock_ticker}' data from yfinance ---") hist = yf.download(stock_ticker, period="max", auto_adjust=True) if hist.empty: return None, None, None, None, f"Could not find data for ticker '{stock_ticker}'" hist = hist[['Close', 'Volume']].asfreq('D').ffill() # --- FIX: Squeeze to ensure 1D array from pandas Series/DataFrame columns (4D bugfix) --- all_values = hist['Close'].values.astype(np.float32).squeeze() all_volumes = hist['Volume'].values.astype(np.float32).squeeze() data_start_date = hist.index.min() frequency = "D" elif data_source == "VIX Volatility Index": print("--- Downloading VIX data from yfinance ---") vix_data = yf.download("^VIX", period="max", auto_adjust=True) if vix_data.empty: return None, None, None, None, "Could not download VIX data" vix_data = vix_data.asfreq('D').ffill() all_values = vix_data['Close'].values.astype(np.float32).squeeze() data_start_date = vix_data.index.min() frequency = "D" print(f"--- VIX data loaded: {len(all_values)} points ---") elif data_source == "Default (WTI Oil Prices)": all_values, data_start_date, frequency = load_oil_price_data() elif data_source == "Upload Custom CSV": all_values, all_volumes, data_start_date, frequency, message = process_uploaded_data(uploaded_file) if all_values is None: return None, None, None, None, message elif data_source == "Synthetic Playground": print(f"--- Generating {synth_generator} synthetic data (complexity: {synth_complexity}) ---") # Generate synthetic data based on selected generator total_length = history_length + forecast_horizon if synth_generator == "Sine Waves": params = SineWaveGeneratorParams(global_seed=seed, length=total_length) generator = SineWaveGeneratorWrapper(params) elif synth_generator == "Sawtooth Waves": params = SawToothGeneratorParams(global_seed=seed, length=total_length) generator = SawToothGeneratorWrapper(params) elif synth_generator == "Spikes": params = SpikesGeneratorParams(global_seed=seed, length=total_length) generator = SpikesGeneratorWrapper(params) elif synth_generator == "Steps": params = StepGeneratorParams(global_seed=seed, length=total_length) generator = StepGeneratorWrapper(params) elif synth_generator == "Ornstein-Uhlenbeck": params = OrnsteinUhlenbeckProcessGeneratorParams(global_seed=seed, length=total_length) generator = OrnsteinUhlenbeckProcessGeneratorWrapper(params) elif synth_generator == "Gaussian Processes" and GP_AVAILABLE: params = GPGeneratorParams(global_seed=seed, length=total_length) generator = GPGeneratorWrapper(params) elif synth_generator == "Anomaly Patterns": params = AnomalyGeneratorParams(global_seed=seed, length=total_length) generator = AnomalyGeneratorWrapper(params) elif synth_generator == "Financial Volatility" and AUDIO_AVAILABLE: params = FinancialVolatilityAudioParams(global_seed=seed, length=total_length) generator = FinancialVolatilityAudioWrapper(params) elif synth_generator == "Fractal Patterns" and AUDIO_AVAILABLE: params = MultiScaleFractalAudioParams(global_seed=seed, length=total_length) generator = MultiScaleFractalAudioWrapper(params) elif synth_generator == "Network Topology" and NETWORK_AVAILABLE: params = NetworkTopologyAudioParams(global_seed=seed, length=total_length) generator = NetworkTopologyAudioWrapper(params) elif synth_generator == "Stochastic Rhythm" and RHYTHM_AVAILABLE: params = StochasticRhythmAudioParams(global_seed=seed, length=total_length) generator = StochasticRhythmAudioWrapper(params) elif synth_generator == "CauKer" and CAUKER_AVAILABLE: params = CauKerGeneratorParams(global_seed=seed, length=total_length) generator = CauKerGeneratorWrapper(params) elif synth_generator == "Forecast PFN Prior" and FORECAST_PFN_AVAILABLE: params = ForecastPFNGeneratorParams(global_seed=seed, length=total_length) generator = ForecastPFNGeneratorWrapper(params) elif synth_generator == "Kernel Synth" and KERNEL_AVAILABLE: params = KernelGeneratorParams(global_seed=seed, length=total_length) generator = KernelGeneratorWrapper(params) else: # Fallback to sine waves if generator not available params = SineWaveGeneratorParams(global_seed=seed, length=total_length) generator = SineWaveGeneratorWrapper(params) # Generate the batch batch = generator.generate_batch(batch_size=1, seed=seed) values = torch.from_numpy(batch.values).to(torch.float32) if values.ndim == 2: values = values.unsqueeze(-1) all_values = values.squeeze().numpy() data_start_date = batch.start[0] if hasattr(batch, 'start') and batch.start else np.datetime64("2020-01-01") frequency = batch.frequency[0] if hasattr(batch, 'frequency') and batch.frequency else "D" print(f"--- {synth_generator} data generated: {len(all_values)} points ---") else: # "Synthetic Data" values, start, frequency = generate_synthetic_data(length=history_length + forecast_horizon, seed=seed) all_values, data_start_date = values, start # --- Common Logic for Slicing Data --- if data_source != "Synthetic Data": total_needed = history_length + forecast_horizon if len(all_values) < total_needed: return None, None, None, None, f"Data has {len(all_values)} points, but {total_needed} are needed." values = all_values[-total_needed:] start_offset_days = len(all_values) - total_needed start = np.datetime64(data_start_date) + np.timedelta64(start_offset_days, 'D') if all_volumes is not None: history_volumes = all_volumes[-(total_needed) : -forecast_horizon] else: history_volumes = np.array([np.nan] * history_length) else: start = data_start_date history_volumes = np.array([np.nan] * history_length) # --- Prepare data for model --- # Unsqueeze calls convert the 1D array into the required [B, S, N] shape: [1, S, 1] values_tensor = torch.from_numpy(values).unsqueeze(0).unsqueeze(-1) future_length = forecast_horizon # --- Convert string to the correct Frequency enum --- if isinstance(frequency, str): if frequency.startswith("D"): freq_object = Frequency.D elif frequency.startswith("W"): freq_object = Frequency.W elif frequency.startswith("M"): freq_object = Frequency.M elif frequency.startswith("Q"): freq_object = Frequency.Q elif frequency.startswith("A") or frequency.startswith("Y"): freq_object = Frequency.A else: print(f"Warning: Unknown frequency string '{frequency}'. Defaulting to Daily.") freq_object = Frequency.D else: freq_object = frequency # Prepare container for GPU inference history_values_tensor = values_tensor[:, :-future_length, :] future_values_tensor = values_tensor[:, -future_length:, :] # Ensure start is np.datetime64 if not isinstance(start, np.datetime64): start = np.datetime64(start) # Run GPU inference (all CUDA ops happen inside the decorated function) # Pass CPU tensors - they will be moved to GPU inside the function model_output = run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object) # Post-process predictions (exactly like examples/utils.py lines 65-69) preds_full = model_output["result"].to(torch.float32) if model is not None and hasattr(model, "scaler") and "scale_statistics" in model_output: preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"]) # Convert to numpy for plotting preds_np = preds_full.detach().cpu().numpy() history_np = history_values_tensor.cpu().numpy().squeeze(0) future_np = future_values_tensor.cpu().numpy().squeeze(0) preds_squeezed = preds_np.squeeze(0) # Get model quantiles if available model_quantiles = None if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile": model_quantiles = model.quantiles try: forecast_plot = plot_multivariate_timeseries( history_values=history_np, future_values=future_np, predicted_values=preds_squeezed, start=start, frequency=freq_object, title=f"TempoPFN Forecast - {data_source}", show=False # Don't show the plot, we'll display in Gradio ) except Exception as plot_error: print(f"Warning: Failed to generate plot: {plot_error}") # Create a simple error plot import plotly.graph_objects as go forecast_plot = go.Figure() forecast_plot.add_annotation( text="Plot generation failed", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False, font=dict(size=14, color="red") ) # Calculate comprehensive metrics metrics = calculate_metrics( history_values=history_np, predictions=preds_squeezed, future_values=future_np, data_source=data_source ) # Store results globally for export functionality global last_forecast_results, last_metrics_results, last_analysis_results last_forecast_results = { 'history': history_np, 'predictions': preds_squeezed, 'future': future_np, 'start': start, 'frequency': freq_object } last_metrics_results = metrics last_analysis_results = { 'data_source': data_source, 'forecast_horizon': forecast_horizon, 'history_length': history_length, 'seed': seed } # Create data preview DataFrame preview_data = { 'Index': list(range(len(history_np))), 'Historical Value': history_np.flatten()[:100] # Limit to first 100 for display } if history_volumes is not None and not np.all(np.isnan(history_volumes)): preview_data['Volume'] = history_volumes[:100] data_preview_df = pd.DataFrame(preview_data) return ( history_np, history_volumes, preds_squeezed, model_quantiles, forecast_plot, "Forecasting completed successfully!", metrics, data_preview_df ) except Exception as e: traceback.print_exc() error_msg = f"Error during forecasting: {str(e)}" empty_metrics = {k: 0.0 if isinstance(v, float) else "" for k, v in calculate_metrics(np.array([0.0]), np.array([0.0])).items()} return None, None, None, None, None, error_msg, empty_metrics, pd.DataFrame() # --- [GRADIO UI - Simplified with Default Styling] --- with gr.Blocks(title="TempoPFN") as app: gr.Markdown("# TempoPFN\n### Zero-Shot Forecasting & Analysis Terminal\n*Powered by synthetic pre-training • Forecast anything, anywhere*") gr.Markdown("⚠️ **First Run Note**: Initial inference may take 60-90 seconds due to Triton kernel compilation. Subsequent runs will be much faster!") with gr.Tabs() as tabs: # ===== FINANCIAL MARKETS TAB ===== with gr.TabItem("Financial Markets", id="financial"): with gr.Row(): with gr.Column(scale=1, min_width=380): # Data Source Section gr.Markdown("### Financial Data Sources") financial_source = gr.Radio( choices=["Default (WTI Oil Prices)", "Stock Ticker", "VIX Volatility Index", "Upload Custom CSV"], value="Default (WTI Oil Prices)", label="", info="Choose financial market data or upload your own" ) # Combine the selections data_source = gr.Textbox(visible=False) # Dynamic inputs with gr.Row(): stock_ticker = gr.Textbox( label="Stock Ticker", value="SPY", placeholder="e.g., SPY, AAPL, TSLA", visible=False ) uploaded_file = gr.File( label="CSV File", file_types=[".csv"], visible=False ) def toggle_financial_input(choice): show_ticker = (choice == "Stock Ticker") show_upload = (choice == "Upload Custom CSV") return ( gr.update(visible=show_ticker), gr.update(visible=show_upload) ) # Handle selection changes financial_source.change( fn=lambda x: x, # Just pass through the selection inputs=financial_source, outputs=data_source ).then( fn=toggle_financial_input, inputs=financial_source, outputs=[stock_ticker, uploaded_file] ) # Forecasting Parameters Section gr.Markdown("### Forecasting Parameters") forecast_horizon = gr.Slider( minimum=30, maximum=512, value=90, step=1, label="Forecast Horizon", info="Number of periods to forecast ahead" ) history_length = gr.Slider( minimum=256, maximum=2048, value=1024, step=8, label="History Length", info="Historical data points to analyze" ) financial_forecast_btn = gr.Button("Run Forecast & Analysis") with gr.Column(scale=3): # Status Section gr.Markdown("### Analysis Results") status_text = gr.Textbox( label="", interactive=False, lines=3, info="Forecasting progress and results" ) # Key Metrics Section (Adaptive based on data source) gr.Markdown("### Key Metrics") # Financial metrics (shown for financial data) with gr.Row(visible=True) as financial_metrics: with gr.Column(): gr.Markdown("**Latest Level:** $0.00") latest_price_out = gr.Number(visible=False) gr.Markdown("**Forecast (Next Period):** $0.00") forecast_next_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**30-Day Volatility:** 0.00%") vol_30d_out = gr.Number(visible=False) with gr.Row(): gr.Markdown("**52-Week High:** $0.00") high_52wk_out = gr.Number(visible=False) gr.Markdown("**52-Week Low:** $0.00") low_52wk_out = gr.Number(visible=False) # Comprehensive Research Metrics (shown for synthetic data) with gr.Row(visible=False) as synthetic_metrics: with gr.Column(): gr.Markdown("**Statistical Properties:**") gr.Markdown("• **Mean:** 0.000") data_mean_out = gr.Number(visible=False) gr.Markdown("• **Std Dev:** 0.000") data_std_out = gr.Number(visible=False) gr.Markdown("• **Skewness:** 0.000") data_skewness_out = gr.Number(visible=False) gr.Markdown("• **Kurtosis:** 0.000") data_kurtosis_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Time Series Analysis:**") gr.Markdown("• **Autocorr (lag-1):** 0.000") data_autocorr_out = gr.Number(visible=False) gr.Markdown("• **Stationary:** Unknown") data_stationary_out = gr.Textbox(visible=False) gr.Markdown("• **Pattern Type:** None") pattern_type_out = gr.Textbox(visible=False) # Model Performance Metrics with gr.Row(visible=False) as performance_metrics: with gr.Column(): gr.Markdown("**Forecast Performance:**") gr.Markdown("• **MSE:** 0.000") mse_out = gr.Number(visible=False) gr.Markdown("• **MAE:** 0.000") mae_out = gr.Number(visible=False) gr.Markdown("• **MAPE:** 0.000%") mape_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Uncertainty Quantification:**") gr.Markdown("• **80% Coverage:** 0.000") coverage_80_out = gr.Number(visible=False) gr.Markdown("• **95% Coverage:** 0.000") coverage_95_out = gr.Number(visible=False) gr.Markdown("• **Calibration:** 0.000") calibration_out = gr.Number(visible=False) # Data Complexity Metrics with gr.Row(visible=False) as complexity_metrics: with gr.Column(): gr.Markdown("**Information Theory:**") gr.Markdown("• **Sample Entropy:** 0.000") sample_entropy_out = gr.Number(visible=False) gr.Markdown("• **Approx Entropy:** 0.000") approx_entropy_out = gr.Number(visible=False) gr.Markdown("• **Perm Entropy:** 0.000") perm_entropy_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Complexity Measures:**") gr.Markdown("• **Fractal Dim:** 0.000") fractal_dim_out = gr.Number(visible=False) gr.Markdown("• **Dominant Freq:** 0.000") dominant_freq_out = gr.Number(visible=False) gr.Markdown("• **Spectral Centroid:** 0.000") spectral_centroid_out = gr.Number(visible=False) gr.Markdown("• **Spectral Entropy:** 0.000") spectral_entropy_out = gr.Number(visible=False) # Research Tools Section with gr.Row(visible=False) as research_tools: with gr.Column(): gr.Markdown("**Cross-Validation Results:**") gr.Markdown("• **Rolling Window MSE:** 0.000") cv_mse_out = gr.Number(visible=False) gr.Markdown("• **Rolling Window MAE:** 0.000") cv_mae_out = gr.Number(visible=False) gr.Markdown("• **Validation Windows:** 0") cv_windows_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Parameter Sensitivity:**") gr.Markdown("• **Horizon Sensitivity:** 0.000") horizon_sensitivity_out = gr.Number(visible=False) gr.Markdown("• **History Sensitivity:** 0.000") history_sensitivity_out = gr.Number(visible=False) gr.Markdown("• **Stability Score:** 0.000") stability_score_out = gr.Number(visible=False) # Forecast Visualization Section gr.Markdown("### Forecast & Technical Analysis") plot_output = gr.Plot( label="", show_label=False ) # Advanced Visualizations Section with gr.Accordion("Advanced Statistical Visualizations", open=False): advanced_plots = gr.Plot(label="", show_label=False) # Export & Analysis Tools Section with gr.Accordion("Export & Analysis Tools", open=False): with gr.Row(): export_forecast_csv = gr.Button("📊 Export Forecast Data (CSV)") export_metrics_csv = gr.Button("📈 Export Metrics Summary (CSV)") export_analysis_csv = gr.Button("🔬 Export Full Analysis (CSV)") export_status = gr.Textbox( label="Export Status", interactive=False, lines=2, info="Export operation results" ) export_file = gr.File( label="Download Exported Data", visible=False ) # Data Preview Section with gr.Accordion("Raw Data Preview", open=False): data_preview = gr.Dataframe( label="", show_label=False, wrap=True ) # ===== RESEARCH & ANALYSIS TAB ===== with gr.TabItem("Research & Analysis", id="research"): with gr.Row(): with gr.Column(scale=1, min_width=380): # Data Source Section gr.Markdown("### Synthetic Data Testing") research_source = gr.Radio( choices=["Basic Synthetic", "Advanced Synthetic"], value="Basic Synthetic", label="", info="Test TempoPFN with synthetic data patterns" ) # Dynamic inputs for research tab seed = gr.Number( value=42, label="Random Seed", minimum=0, maximum=9999, step=1, visible=False ) # Build available generator choices available_generators = [ "Sine Waves", "Sawtooth Waves", "Spikes", "Steps", "Ornstein-Uhlenbeck", "Anomaly Patterns" ] if GP_AVAILABLE: available_generators.append("Gaussian Processes") if AUDIO_AVAILABLE: available_generators.extend(["Financial Volatility", "Fractal Patterns"]) if NETWORK_AVAILABLE: available_generators.append("Network Topology") if RHYTHM_AVAILABLE: available_generators.append("Stochastic Rhythm") if CAUKER_AVAILABLE: available_generators.append("CauKer") if FORECAST_PFN_AVAILABLE: available_generators.append("Forecast PFN Prior") if KERNEL_AVAILABLE: available_generators.append("Kernel Synth") # Synthetic Playground controls with gr.Row(): synth_generator = gr.Dropdown( choices=available_generators, value="Sine Waves", label="Generator Type", visible=False, info="Select synthetic pattern generator" ) synth_complexity = gr.Slider( minimum=1, maximum=10, value=5, step=1, label="Complexity", visible=False, info="Pattern complexity level" ) def toggle_research_input(choice): show_seed = (choice == "Basic Synthetic") show_synth = (choice == "Advanced Synthetic") return ( gr.update(visible=show_seed), gr.update(visible=show_synth), gr.update(visible=show_synth) ) # Handle selection changes research_source.change( fn=lambda x: x, # Just pass through the selection inputs=research_source, outputs=data_source ).then( fn=toggle_research_input, inputs=research_source, outputs=[seed, synth_generator, synth_complexity] ) # Forecasting Parameters Section gr.Markdown("### Forecasting Parameters") forecast_horizon = gr.Slider( minimum=30, maximum=512, value=90, step=1, label="Forecast Horizon", info="Number of periods to forecast ahead" ) history_length = gr.Slider( minimum=256, maximum=2048, value=1024, step=8, label="History Length", info="Historical data points to analyze" ) forecast_btn = gr.Button("Run Forecast & Analysis") with gr.Column(scale=3): # Status Section gr.Markdown("### Analysis Results") research_status_text = gr.Textbox( label="", interactive=False, lines=3, info="Forecasting progress and results" ) # Key Metrics Section (Adaptive based on data source) gr.Markdown("### Key Metrics") # Financial metrics (shown for financial data) with gr.Row(visible=True) as financial_metrics: with gr.Column(): gr.Markdown("**Latest Level:** $0.00") latest_price_out = gr.Number(visible=False) gr.Markdown("**Forecast (Next Period):** $0.00") forecast_next_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**30-Day Volatility:** 0.00%") vol_30d_out = gr.Number(visible=False) with gr.Row(): gr.Markdown("**52-Week High:** $0.00") high_52wk_out = gr.Number(visible=False) gr.Markdown("**52-Week Low:** $0.00") low_52wk_out = gr.Number(visible=False) # Synthetic/Research metrics (shown for synthetic data) with gr.Row(visible=False) as synthetic_metrics: with gr.Column(): gr.Markdown("**Data Mean:** 0.000") data_mean_out = gr.Number(visible=False) gr.Markdown("**Data Std Dev:** 0.000") data_std_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Forecast Horizon:** 0") forecast_accuracy_out = gr.Number(visible=False) gr.Markdown("**Pattern Type:** None") pattern_type_out = gr.Textbox(visible=False) # Model Performance Metrics with gr.Row(visible=False) as performance_metrics: with gr.Column(): gr.Markdown("**Forecast Performance:**") gr.Markdown("• **MSE:** 0.000") mse_out = gr.Number(visible=False) gr.Markdown("• **MAE:** 0.000") mae_out = gr.Number(visible=False) gr.Markdown("• **MAPE:** 0.000%") mape_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Uncertainty Quantification:**") gr.Markdown("• **80% Coverage:** 0.000") coverage_80_out = gr.Number(visible=False) gr.Markdown("• **95% Coverage:** 0.000") coverage_95_out = gr.Number(visible=False) gr.Markdown("• **Calibration:** 0.000") calibration_out = gr.Number(visible=False) # Data Complexity Metrics with gr.Row(visible=False) as complexity_metrics: with gr.Column(): gr.Markdown("**Information Theory:**") gr.Markdown("• **Sample Entropy:** 0.000") sample_entropy_out = gr.Number(visible=False) gr.Markdown("• **Approx Entropy:** 0.000") approx_entropy_out = gr.Number(visible=False) gr.Markdown("• **Perm Entropy:** 0.000") perm_entropy_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Complexity Measures:**") gr.Markdown("• **Fractal Dim:** 0.000") fractal_dim_out = gr.Number(visible=False) gr.Markdown("• **Dominant Freq:** 0.000") dominant_freq_out = gr.Number(visible=False) gr.Markdown("• **Spectral Centroid:** 0.000") spectral_centroid_out = gr.Number(visible=False) gr.Markdown("• **Spectral Entropy:** 0.000") spectral_entropy_out = gr.Number(visible=False) # Research Tools Section with gr.Row(visible=False) as research_tools: with gr.Column(): gr.Markdown("**Cross-Validation Results:**") gr.Markdown("• **Rolling Window MSE:** 0.000") cv_mse_out = gr.Number(visible=False) gr.Markdown("• **Rolling Window MAE:** 0.000") cv_mae_out = gr.Number(visible=False) gr.Markdown("• **Validation Windows:** 0") cv_windows_out = gr.Number(visible=False) with gr.Column(): gr.Markdown("**Parameter Sensitivity:**") gr.Markdown("• **Horizon Sensitivity:** 0.000") horizon_sensitivity_out = gr.Number(visible=False) gr.Markdown("• **History Sensitivity:** 0.000") history_sensitivity_out = gr.Number(visible=False) gr.Markdown("• **Stability Score:** 0.000") stability_score_out = gr.Number(visible=False) # Forecast Visualization Section gr.Markdown("### Forecast & Technical Analysis") research_plot_output = gr.Plot( label="", show_label=False ) # Advanced Visualizations Section (Research tab doesn't have this defined, so add it) with gr.Accordion("Advanced Statistical Visualizations", open=False): research_advanced_plots = gr.Plot(label="", show_label=False) # Data Preview Section with gr.Accordion("Raw Data Preview", open=False): research_data_preview = gr.Dataframe( label="", show_label=False, wrap=True ) # Now add the metrics toggle function after components are defined def toggle_metrics_display(choice): """Toggle between financial and synthetic metrics based on data source""" show_financial = choice in ["Stock Ticker", "Default (WTI Oil Prices)", "VIX Volatility Index"] show_synthetic = choice in ["Basic Synthetic", "Advanced Synthetic", "Upload Custom CSV"] show_performance = show_synthetic # Show performance metrics for synthetic data show_complexity = show_synthetic # Show complexity metrics for synthetic data return ( gr.update(visible=show_financial), gr.update(visible=show_synthetic), gr.update(visible=show_performance), gr.update(visible=show_complexity) ) # Add the metrics toggle to the selection change handlers financial_source.change( fn=toggle_metrics_display, inputs=data_source, outputs=[financial_metrics, synthetic_metrics, performance_metrics, complexity_metrics] ) research_source.change( fn=toggle_metrics_display, inputs=data_source, outputs=[financial_metrics, synthetic_metrics, performance_metrics, complexity_metrics] ) # Wrapper function to unpack forecast results for UI def forecast_and_display_financial(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed): result = forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, "Sine Waves", 5) if result[5] and "Error" not in result[5]: # Check status history_np = result[0] preds = result[2] future_np = last_forecast_results['future'] if last_forecast_results else None # Generate advanced visualizations adv_viz = create_advanced_visualizations(history_np, preds, future_np) return ( result[5], # status_text result[4], # plot_output result[7], # data_preview adv_viz # advanced_plots ) else: return result[5], None, pd.DataFrame(), go.Figure() def forecast_and_display_research(data_source, forecast_horizon, history_length, seed, synth_generator, synth_complexity): result = forecast_time_series(data_source, "", None, forecast_horizon, history_length, seed, synth_generator, synth_complexity) if result[5] and "Error" not in result[5]: history_np = result[0] preds = result[2] future_np = last_forecast_results['future'] if last_forecast_results else None # Generate advanced visualizations adv_viz = create_advanced_visualizations(history_np, preds, future_np) return ( result[5], # status_text result[4], # plot_output result[7], # data_preview adv_viz # advanced_plots ) else: return result[5], None, pd.DataFrame(), go.Figure() # Connect button click handlers financial_forecast_btn.click( fn=forecast_and_display_financial, inputs=[data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed], outputs=[status_text, plot_output, data_preview, advanced_plots] ) forecast_btn.click( fn=forecast_and_display_research, inputs=[data_source, forecast_horizon, history_length, seed, synth_generator, synth_complexity], outputs=[research_status_text, research_plot_output, research_data_preview, research_advanced_plots] ) # Wrapper for export functions to show file def export_forecast_wrapper(): file, status = export_forecast_csv() return gr.update(value=file, visible=file is not None), status def export_metrics_wrapper(): file, status = export_metrics_csv() return gr.update(value=file, visible=file is not None), status def export_analysis_wrapper(): file, status = export_analysis_csv() return gr.update(value=file, visible=file is not None), status # Connect export button handlers export_forecast_csv.click( fn=export_forecast_wrapper, inputs=[], outputs=[export_file, export_status] ) export_metrics_csv.click( fn=export_metrics_wrapper, inputs=[], outputs=[export_file, export_status] ) export_analysis_csv.click( fn=export_analysis_wrapper, inputs=[], outputs=[export_file, export_status] ) return app # Return the Gradio app object # --- GRADIO APP LAUNCH --- app = create_gradio_app() app.launch()