altpuppet Claude committed
Commit 8e6f2eb · Parent: c87b856

Refactor to match paper's implementation pattern exactly


Key changes to align with examples/quick_start_tempo_pfn.py and examples/utils.py:

1. The GPU function now takes a container and returns the model_output dict (not numpy arrays)
2. Container preparation happens OUTSIDE the GPU-decorated function
3. Post-processing (inverse scaling, CPU conversion) happens AFTER the GPU call
4. The model stays loaded in global state between calls
5. Only the actual inference (model(container)) is GPU-decorated

This matches the exact pattern from the working reference implementation:
- examples/quick_start_tempo_pfn.py lines 66-75 (container prep outside)
- examples/utils.py lines 56-69 (inference with autocast, then post-process)

The key insight: ZeroGPU works best when the decorated function is minimal
and doesn't handle data preparation or complex post-processing.
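
For illustration, a minimal sketch of that split. `spaces.GPU` is the real Hugging Face ZeroGPU decorator and the autocast call mirrors the one in this commit; the tiny linear model and plain-tensor "container" are hypothetical stand-ins for TempoPFN and its BatchTimeSeriesContainer:

```python
# Sketch of the ZeroGPU split: only the forward pass is GPU-decorated.
# The linear model and length-8 tensor are stand-ins, not the real TempoPFN API.
import spaces
import torch

model = None   # kept in module-global state so it survives between GPU calls
device = None

@spaces.GPU
def run_gpu_inference(batch):
    """Minimal decorated body: load once, run the forward pass, return raw output."""
    global model, device
    if model is None:
        device = torch.device("cuda:0")
        model = torch.nn.Linear(8, 8).to(device)  # stand-in for load_model(...)
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        return {"result": model(batch.to(device))}

def forecast(history):
    # Container/batch preparation stays OUTSIDE the decorated function ...
    batch = history.reshape(1, 8)
    model_output = run_gpu_inference(batch)
    # ... and so does post-processing (inverse scaling, CPU conversion, plotting).
    return model_output["result"].to(torch.float32).detach().cpu().numpy()

preds = forecast(torch.randn(8))  # usage: a single length-8 series
```

Since ZeroGPU attaches a GPU only for the duration of the decorated call, trimming the body down to the forward pass keeps that window as short as possible.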

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (1)
  1. app.py +34 -34
app.py CHANGED
```diff
@@ -459,12 +459,14 @@ def create_gradio_app():
         return metrics
 
     @spaces.GPU
-    def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
+    def run_gpu_inference(container):
         """
-        GPU-only inference function. Returns only simple serializable types.
-        Returns: (history_np, future_np, preds_np, model_quantiles)
+        GPU-only inference function matching the paper's implementation.
+        Takes a container and returns model output dict.
         """
         global model, device
+
+        # Load model once on first call
         if model is None:
             print("--- Loading TempoPFN model for the first time ---")
             device = torch.device("cuda:0")
@@ -474,34 +476,11 @@ def create_gradio_app():
             model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
             print("--- Model loaded successfully ---")
 
-        # Prepare container
-        container = BatchTimeSeriesContainer(
-            history_values=history_values_tensor.to(device),
-            future_values=future_values_tensor.to(device),
-            start=[start],
-            frequency=[freq_object],
-        )
-
-        # Run inference with autocast
+        # Run inference with bfloat16 autocast (exactly like the paper's examples/utils.py)
         with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
             model_output = model(container)
 
-        # Post-process predictions
-        preds_full = model_output["result"].to(torch.float32)
-        if hasattr(model, "scaler") and "scale_statistics" in model_output:
-            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
-
-        # Convert to numpy arrays before returning (detach from GPU)
-        preds_np = preds_full.detach().cpu().numpy()
-        history_np = history_values_tensor.cpu().numpy()
-        future_np = future_values_tensor.cpu().numpy()
-
-        # Get model quantiles if available
-        model_quantiles = None
-        if hasattr(model, "loss_type") and model.loss_type == "quantile":
-            model_quantiles = model.quantiles
-
-        return history_np, future_np, preds_np, model_quantiles
+        return model_output
 
     def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
         """
@@ -679,16 +658,37 @@ def create_gradio_app():
         if not isinstance(start, np.datetime64):
             start = np.datetime64(start)
 
-        # Call GPU inference function (returns only simple types)
-        history_np, future_np, preds_np, model_quantiles = run_gpu_inference(
-            history_values_tensor, future_values_tensor, start, freq_object
+        # Prepare container (exactly like the paper's example)
+        global device
+        if device is None:
+            device = torch.device("cuda:0")
+
+        container = BatchTimeSeriesContainer(
+            history_values=history_values_tensor.to(device),
+            future_values=future_values_tensor.to(device),
+            start=[start],
+            frequency=[freq_object],
         )
 
-        # Squeeze arrays for plotting
-        history_np = history_np.squeeze(0)
-        future_np = future_np.squeeze(0)
+        # Run GPU inference (returns model_output dict)
+        model_output = run_gpu_inference(container)
+
+        # Post-process predictions (exactly like examples/utils.py lines 65-69)
+        preds_full = model_output["result"].to(torch.float32)
+        if model is not None and hasattr(model, "scaler") and "scale_statistics" in model_output:
+            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
+
+        # Convert to numpy for plotting
+        preds_np = preds_full.detach().cpu().numpy()
+        history_np = history_values_tensor.cpu().numpy().squeeze(0)
+        future_np = future_values_tensor.cpu().numpy().squeeze(0)
         preds_squeezed = preds_np.squeeze(0)
 
+        # Get model quantiles if available
+        model_quantiles = None
+        if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
+            model_quantiles = model.quantiles
+
         try:
             forecast_plot = plot_multivariate_timeseries(
                 history_values=history_np,
```