Refactor to match paper's implementation pattern exactly
Key changes to align with examples/quick_start_tempo_pfn.py and examples/utils.py:
1. GPU function now takes container and returns model_output dict (not numpy arrays)
2. Container preparation happens OUTSIDE the GPU-decorated function
3. Post-processing (inverse scaling, CPU conversion) happens AFTER GPU call
4. Model stays loaded in global state between calls
5. Only the actual inference (model(container)) is GPU-decorated
This matches the exact pattern from the working reference implementation:
- examples/quick_start_tempo_pfn.py lines 66-75 (container prep outside)
- examples/utils.py lines 56-69 (inference with autocast, then post-process)
The key insight: ZeroGPU works best when the decorated function is minimal
and doesn't handle data preparation or complex post-processing.
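In miniature, the pattern looks like this (a hedged sketch, not the app's code: `spaces.GPU` is the real Hugging Face ZeroGPU decorator, while `torch.nn.Identity` stands in for the TempoPFN model and `forecast` for the Gradio handler):

```python
import spaces  # Hugging Face ZeroGPU helper package
import torch

model = None  # global state: stays loaded between calls on the same worker


@spaces.GPU  # only the actual inference is GPU-decorated
def run_gpu_inference(batch: torch.Tensor) -> torch.Tensor:
    global model
    if model is None:  # load once on the first call, reuse afterwards
        model = torch.nn.Identity().to("cuda")  # stand-in for the real load_model()
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return model(batch.to("cuda"))


def forecast(raw_values):
    # Data preparation happens OUTSIDE the GPU-decorated function ...
    batch = torch.as_tensor(raw_values, dtype=torch.float32).unsqueeze(0)
    output = run_gpu_inference(batch)
    # ... and post-processing (fp32 cast, CPU copy) happens AFTER the GPU call.
    return output.to(torch.float32).cpu().numpy()
```

On a ZeroGPU Space the GPU is attached only for the duration of the decorated call, which is why everything else has to live outside it.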
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
@@ -459,12 +459,14 @@ def create_gradio_app():
         return metrics
 
     @spaces.GPU
-    def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
+    def run_gpu_inference(container):
         """
-        GPU-only inference function
-
+        GPU-only inference function matching the paper's implementation.
+        Takes a container and returns model output dict.
         """
         global model, device
+
+        # Load model once on first call
         if model is None:
             print("--- Loading TempoPFN model for the first time ---")
             device = torch.device("cuda:0")
@@ -474,34 +476,11 @@ def create_gradio_app():
             model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
             print("--- Model loaded successfully ---")
 
-        #
-        container = BatchTimeSeriesContainer(
-            history_values=history_values_tensor.to(device),
-            future_values=future_values_tensor.to(device),
-            start=[start],
-            frequency=[freq_object],
-        )
-
-        # Run inference with autocast
+        # Run inference with bfloat16 autocast (exactly like the paper's examples/utils.py)
         with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
             model_output = model(container)
 
-
-        preds_full = model_output["result"].to(torch.float32)
-        if hasattr(model, "scaler") and "scale_statistics" in model_output:
-            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
-
-        # Convert to numpy arrays before returning (detach from GPU)
-        preds_np = preds_full.detach().cpu().numpy()
-        history_np = history_values_tensor.cpu().numpy()
-        future_np = future_values_tensor.cpu().numpy()
-
-        # Get model quantiles if available
-        model_quantiles = None
-        if hasattr(model, "loss_type") and model.loss_type == "quantile":
-            model_quantiles = model.quantiles
-
-        return history_np, future_np, preds_np, model_quantiles
+        return model_output
 
     def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
         """
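For context on the autocast line this hunk keeps: it is stock PyTorch mixed-precision inference, nothing project-specific. A generic sketch (the `Linear` module here is just a demo, not repo code):

```python
import torch

net = torch.nn.Linear(8, 3).to("cuda")
x = torch.randn(2, 8, device="cuda")

# no_grad: pure inference, no autograd bookkeeping.
# autocast(bfloat16): matmul-heavy ops run in bf16 on the GPU.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
    y = net(x)

print(y.dtype)  # torch.bfloat16 -> hence the later .to(torch.float32) cast
```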
@@ -679,16 +658,37 @@
         if not isinstance(start, np.datetime64):
             start = np.datetime64(start)
 
-        #
-        history_np, future_np, preds_np, model_quantiles = run_gpu_inference(
-            history_values_tensor, future_values_tensor, start, freq_object
+        # Prepare container (exactly like the paper's example)
+        global device
+        if device is None:
+            device = torch.device("cuda:0")
+
+        container = BatchTimeSeriesContainer(
+            history_values=history_values_tensor.to(device),
+            future_values=future_values_tensor.to(device),
+            start=[start],
+            frequency=[freq_object],
         )
 
-        #
-        history_np = history_np.squeeze(0)
-        future_np = future_np.squeeze(0)
+        # Run GPU inference (returns model_output dict)
+        model_output = run_gpu_inference(container)
+
+        # Post-process predictions (exactly like examples/utils.py lines 65-69)
+        preds_full = model_output["result"].to(torch.float32)
+        if model is not None and hasattr(model, "scaler") and "scale_statistics" in model_output:
+            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
+
+        # Convert to numpy for plotting
+        preds_np = preds_full.detach().cpu().numpy()
+        history_np = history_values_tensor.cpu().numpy().squeeze(0)
+        future_np = future_values_tensor.cpu().numpy().squeeze(0)
         preds_squeezed = preds_np.squeeze(0)
 
+        # Get model quantiles if available
+        model_quantiles = None
+        if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
+            model_quantiles = model.quantiles
+
         try:
             forecast_plot = plot_multivariate_timeseries(
                 history_values=history_np,
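On the post-processing side, `model.scaler.inverse_scale` undoes whatever normalization was applied before the forward pass. The repo's scaler API is used as-is above; purely as an illustration of the idea, here is a stand-in with hypothetical `loc`/`scale` statistics (not the project's real implementation):

```python
import torch

def inverse_scale(preds: torch.Tensor, stats: dict) -> torch.Tensor:
    # Undo per-series standardization: x = x_scaled * scale + loc
    return preds * stats["scale"] + stats["loc"]

stats = {"loc": torch.tensor(10.0), "scale": torch.tensor(2.5)}
scaled = torch.tensor([0.0, 1.0, -1.0])
print(inverse_scale(scaled, stats))  # tensor([10.0000, 12.5000,  7.5000])
```

Doing this after `run_gpu_inference` returns keeps the decorated function free of anything that does not need the GPU.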