# Quick Start: Univariate Quantile Forecasting (CUDA, bfloat16)

This notebook demonstrates how to:
- Generate synthetic sine wave time series data
- Pack data into a `BatchTimeSeriesContainer`
- Load a pretrained model (from Hugging Face)
- Run inference with bfloat16 on CUDA
- Visualize predictions

## 1) Setup

In [None]:
import sys
from pathlib import Path

import numpy as np
import torch
from huggingface_hub import hf_hub_download

# Ensure CUDA is available
if not torch.cuda.is_available():
 raise RuntimeError("CUDA is required to run this demo. No CUDA device detected.")

device = torch.device("cuda:0")

# Resolve repository root to be robust to running from subdirectories (e.g., examples/)
repo_root = Path.cwd()
if not (repo_root / "configs").exists():
 repo_root = repo_root.parent

# Inline plotting
%matplotlib inline

## 2) Download Checkpoint from Hugging Face

In [None]:
print("Downloading model checkpoint from Hugging Face Hub...")

# hf_hub_download returns the local path of the fetched file; the model-loading
# cell reads the weights from CHECKPOINT_PATH.
CHECKPOINT_PATH = hf_hub_download(
    repo_id="AutoML-org/TempoPFN",
    filename="models/checkpoint_38M.pth",
)

print(f"Checkpoint is available at: {CHECKPOINT_PATH}")

## 3) Generate synthetic sine wave data

In [None]:
from src.synthetic_generation.generator_params import SineWaveGeneratorParams
from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
 SineWaveGeneratorWrapper,
)

# Settings for the synthetic data: batch size, length requested from the
# generator, and the seed shared by the generator params and the batch draw.
seed = 2025
batch_size = 3
total_length = 1024

# Configure the sine-wave generator and draw a single batch from it.
sine_params = SineWaveGeneratorParams(length=total_length, global_seed=seed)
wrapper = SineWaveGeneratorWrapper(sine_params)
batch = wrapper.generate_batch(seed=seed, batch_size=batch_size)

# Convert to a float32 tensor; a 2-D result gets a trailing axis -> [B, S, 1].
values = torch.from_numpy(batch.values).float()
if values.dim() == 2:
    values = values[..., None]

# Hold out the final `future_length` steps as the forecast target; everything
# before them is the history given to the model.
future_length = 256
history_values = values[:, :-future_length]
future_values = values[:, -future_length:]

print("History:", history_values.shape, "Future:", future_values.shape)

## 4) Build BatchTimeSeriesContainer

In [None]:
from src.data.containers import BatchTimeSeriesContainer

# Bundle the history window and the held-out future, both moved to the CUDA
# device, together with the start/frequency metadata of the generated batch.
container = BatchTimeSeriesContainer(
    start=batch.start,
    frequency=batch.frequency,
    history_values=history_values.to(device=device),
    future_values=future_values.to(device=device),
)

# Cell output: the container's batch size and history/future lengths.
(container.batch_size, container.history_length, container.future_length)

## 5) Load model and run inference

In [None]:
import yaml
from src.models.model import TimeSeriesModel

# Model hyperparameters come from the repository's example config. The file is
# opened as UTF-8 explicitly so parsing does not depend on the platform's
# default text encoding.
with open(repo_root / "configs/example.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

# NOTE(review): torch.load unpickles the checkpoint, and this file was fetched
# from a remote hub. If the checkpoint holds only tensors and plain containers,
# consider passing weights_only=True — confirm against the checkpoint contents.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# bfloat16 autocast on CUDA; gradient tracking is disabled for inference.
with (
    torch.no_grad(),
    torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True),
):
    output = model(container)

# Cast the forecast back to float32 before post-processing.
preds = output["result"].to(torch.float32)

# Map predictions back through the model's scaler, but only when the model has
# one and the forward pass returned the statistics it needs.
if hasattr(model, "scaler") and "scale_statistics" in output:
    preds = model.scaler.inverse_scale(preds, output["scale_statistics"])

preds.shape

## 6) Plot predictions

In [None]:
import matplotlib.pyplot as plt

plt.set_loglevel("error")

# preds: [B, P, N, Q] for quantiles (univariate -> N=1)
preds_np = preds.cpu().numpy()

batch_size = preds_np.shape[0]
prediction_length = preds_np.shape[1]
num_quantiles = preds_np.shape[-1]

# Copy history / ground truth to host memory once for the whole batch instead
# of issuing a device-to-host transfer per sample inside the loop.
history_np = container.history_values[:, :, 0].detach().cpu().numpy()
future_np = container.future_values[:, :, 0].detach().cpu().numpy()

# Loop-invariant: middle index of the quantile axis, drawn as the median line.
# Assumes the quantile axis is ordered low -> high and centred on the median —
# TODO confirm against the model's quantile configuration.
median_idx = num_quantiles // 2

for i in range(batch_size):
    fig, ax = plt.subplots(figsize=(12, 4))

    history = history_np[i]
    future = future_np[i]

    # Time axes. The forecast gets its own axis built from the prediction
    # length, so it is plotted against the horizon the model actually produced
    # rather than against the length of the ground-truth window.
    hist_t = np.arange(len(history))
    fut_t = np.arange(len(history), len(history) + len(future))
    pred_t = np.arange(len(history), len(history) + prediction_length)

    # Plot history and ground truth future
    ax.plot(hist_t, history, label="History", color="black")
    ax.plot(fut_t, future, label="Ground Truth", color="blue")

    # Plot quantiles
    ax.plot(
        pred_t,
        preds_np[i, :, 0, median_idx],
        label="Prediction (Median)",
        color="orange",
        linestyle="--",
    )
    # A band needs distinct lower/upper quantiles, hence at least three levels.
    if num_quantiles >= 3:
        ax.fill_between(
            pred_t,
            preds_np[i, :, 0, 0],
            preds_np[i, :, 0, -1],
            color="orange",
            alpha=0.2,
            label="Prediction Interval",
        )

    # Dotted vertical line marks where the history ends and the forecast begins.
    ax.axvline(x=len(history), color="k", linestyle=":", alpha=0.7)
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Value")
    ax.set_title(f"Sample {i + 1}")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()