{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "231c6227",
   "metadata": {},
   "source": [
    "# Quick Start: Univariate Quantile Forecasting (CUDA, bfloat16)\n",
    "\n",
    "This notebook demonstrates how to:\n",
    "- Generate synthetic sine wave time series data\n",
    "- Pack data into `BatchTimeSeriesContainer`\n",
    "- Load a pretrained model (from Hugging Face)\n",
    "- Run inference with bfloat16 on CUDA\n",
    "- Visualize predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb6c5424-1c63-4cb0-a818-45d4199914e5",
   "metadata": {},
   "source": [
    "## 1) Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612a78e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import yaml\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "# Ensure CUDA is available\n",
    "if not torch.cuda.is_available():\n",
    "    raise RuntimeError(\"CUDA is required to run this demo. No CUDA device detected.\")\n",
    "\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Resolve the repository root so the notebook works when launched from the repo root\n",
    "# or from a direct subdirectory (e.g., examples/). Fail early with a clear message\n",
    "# instead of a confusing FileNotFoundError when the config is opened later.\n",
    "_cwd = Path.cwd()\n",
    "repo_root = next((p for p in (_cwd, _cwd.parent) if (p / \"configs\").is_dir()), None)\n",
    "if repo_root is None:\n",
    "    raise FileNotFoundError(\n",
    "        \"Could not locate the repository root (a directory containing 'configs/') \"\n",
    "        f\"from {_cwd} or its parent.\"\n",
    "    )\n",
    "\n",
    "# Make the project's `src` package importable even when running from a subdirectory.\n",
    "if str(repo_root) not in sys.path:\n",
    "    sys.path.insert(0, str(repo_root))\n",
    "\n",
    "# Inline plotting\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3facf37d-0a77-4222-8464-6e42182547f8",
   "metadata": {},
   "source": [
    "## 2) Download Checkpoint from Hugging Face"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16dcb883",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Downloading model checkpoint from Hugging Face Hub...\")\n",
    "\n",
    "CHECKPOINT_PATH = hf_hub_download(repo_id=\"AutoML-org/TempoPFN\", filename=\"models/checkpoint_38M.pth\")\n",
    "\n",
    "print(f\"Checkpoint is available at: {CHECKPOINT_PATH}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9be77e34-0c7a-4056-822f-ed2e3e090c40",
   "metadata": {},
   "source": [
    "## 3) Generate synthetic sine wave data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1127526c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.synthetic_generation.generator_params import SineWaveGeneratorParams\n",
    "from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (\n",
    "    SineWaveGeneratorWrapper,\n",
    ")\n",
    "\n",
    "batch_size = 3\n",
    "total_length = 1024\n",
    "seed = 2025\n",
    "\n",
    "sine_params = SineWaveGeneratorParams(global_seed=seed, length=total_length)\n",
    "wrapper = SineWaveGeneratorWrapper(sine_params)\n",
    "\n",
    "batch = wrapper.generate_batch(batch_size=batch_size, seed=seed)\n",
    "values = torch.from_numpy(batch.values).to(torch.float32)\n",
    "if values.ndim == 2:\n",
    "    values = values.unsqueeze(-1)  # [B, S, 1]\n",
    "\n",
    "future_length = 256\n",
    "history_values = values[:, :-future_length, :]\n",
    "future_values = values[:, -future_length:, :]\n",
    "\n",
    "print(\"History:\", history_values.shape, \"Future:\", future_values.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8844488-e51c-4805-baa9-491bfc67e8ca",
   "metadata": {},
   "source": [
    "## 4) Build BatchTimeSeriesContainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b4d361",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.data.containers import BatchTimeSeriesContainer\n",
    "\n",
    "container = BatchTimeSeriesContainer(\n",
    "    history_values=history_values.to(device),\n",
    "    future_values=future_values.to(device),\n",
    "    start=batch.start,\n",
    "    frequency=batch.frequency,\n",
    ")\n",
    "\n",
    "container.batch_size, container.history_length, container.future_length"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5e7e790-a9aa-49c2-9d45-2dc823036883",
   "metadata": {},
   "source": [
    "## 5) Load model and run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dd4e0e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models.model import TimeSeriesModel\n",
    "\n",
    "with open(repo_root / \"configs/example.yaml\", encoding=\"utf-8\") as f:\n",
    "    config = yaml.safe_load(f)\n",
    "\n",
    "model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n",
    "\n",
    "# NOTE(review): depending on the installed PyTorch version's `weights_only` default,\n",
    "# torch.load may unpickle arbitrary objects - only load checkpoints from trusted sources.\n",
    "ckpt = torch.load(CHECKPOINT_PATH, map_location=device)\n",
    "model.load_state_dict(ckpt[\"model_state_dict\"])\n",
    "model.eval()\n",
    "\n",
    "# bfloat16 autocast on CUDA\n",
    "with (\n",
    "    torch.no_grad(),\n",
    "    torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16, enabled=True),\n",
    "):\n",
    "    output = model(container)\n",
    "\n",
    "# Cast back to float32 before (optional) inverse scaling and plotting.\n",
    "preds = output[\"result\"].to(torch.float32)\n",
    "if hasattr(model, \"scaler\") and \"scale_statistics\" in output:\n",
    "    preds = model.scaler.inverse_scale(preds, output[\"scale_statistics\"])\n",
    "\n",
    "preds.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba16120f-27c8-4462-91cb-c9b3e0630a9d",
   "metadata": {},
   "source": [
    "## 6) Plot predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bf02a0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.set_loglevel(\"error\")\n",
    "\n",
    "# preds: [B, P, N, Q] for quantiles (univariate -> N=1)\n",
    "preds_np = preds.cpu().numpy()\n",
    "num_series = preds_np.shape[0]\n",
    "num_quantiles = preds_np.shape[-1]\n",
    "\n",
    "# Middle quantile index. It is the median only if the model's quantile levels are\n",
    "# symmetric around 0.5 with an odd count - confirm against the model config.\n",
    "median_idx = num_quantiles // 2\n",
    "\n",
    "# Move history / ground truth to the host once, rather than once per sample.\n",
    "history_np = container.history_values[:, :, 0].detach().cpu().numpy()\n",
    "future_np = container.future_values[:, :, 0].detach().cpu().numpy()\n",
    "\n",
    "for i in range(num_series):\n",
    "    history = history_np[i]\n",
    "    future = future_np[i]\n",
    "\n",
    "    # Time axes: the forecast window starts right after the history\n",
    "    hist_t = np.arange(len(history))\n",
    "    fut_t = np.arange(len(history), len(history) + len(future))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "    # Plot history and ground truth future\n",
    "    ax.plot(hist_t, history, label=\"History\", color=\"black\")\n",
    "    ax.plot(fut_t, future, label=\"Ground Truth\", color=\"blue\")\n",
    "\n",
    "    # Plot the middle quantile and, when available, the outermost quantile band\n",
    "    ax.plot(\n",
    "        fut_t,\n",
    "        preds_np[i, :, 0, median_idx],\n",
    "        label=\"Prediction (Median)\",\n",
    "        color=\"orange\",\n",
    "        linestyle=\"--\",\n",
    "    )\n",
    "    if num_quantiles >= 3:\n",
    "        ax.fill_between(\n",
    "            fut_t,\n",
    "            preds_np[i, :, 0, 0],\n",
    "            preds_np[i, :, 0, -1],\n",
    "            color=\"orange\",\n",
    "            alpha=0.2,\n",
    "            label=\"Prediction Interval\",\n",
    "        )\n",
    "\n",
    "    ax.axvline(x=len(history), color=\"k\", linestyle=\":\", alpha=0.7)\n",
    "    ax.set_xlabel(\"Time Steps\")\n",
    "    ax.set_ylabel(\"Value\")\n",
    "    ax.set_title(f\"Sample {i + 1}\")\n",
    "    ax.legend()\n",
    "    ax.grid(True, alpha=0.3)\n",
    "    plt.show()\n",
    "    # Release the figure explicitly so figures do not accumulate under non-inline backends.\n",
    "    plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}