altpuppet Claude commited on
Commit
c87b856
·
1 Parent(s): 426ae1a

Separate GPU inference from post-processing to fix ZeroGPU 404 error

Browse files

- Created dedicated run_gpu_inference() function with @spaces.GPU decorator
- GPU function returns only simple serializable types (numpy arrays, list)
- Main forecast_time_series() handles data loading and post-processing
- Prevents complex objects (Plotly figures, DataFrames, dicts) from crossing GPU boundary
- This matches the pattern from the original working implementation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

Files changed (1) hide show
  1. app.py +51 -36
app.py CHANGED
@@ -459,6 +459,50 @@ def create_gradio_app():
459
  return metrics
460
 
461
  @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
462
  def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
463
  """
464
  Runs the TempoPFN forecast.
@@ -631,49 +675,20 @@ def create_gradio_app():
631
  history_values_tensor = values_tensor[:, :-future_length, :]
632
  future_values_tensor = values_tensor[:, -future_length:, :]
633
 
634
- # Ensure start is np.datetime64 and wrap in list
635
  if not isinstance(start, np.datetime64):
636
  start = np.datetime64(start)
637
 
638
- # Load model if needed (GPU-only app)
639
- global model, device
640
- if model is None:
641
- print("--- Loading TempoPFN model for the first time ---")
642
- device = torch.device("cuda:0")
643
- print(f"Downloading model...")
644
- model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
645
- print(f"Loading model from {model_path} to {device}...")
646
- model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
647
- print("--- Model loaded successfully ---")
648
-
649
- # Prepare container and run inference
650
- container = BatchTimeSeriesContainer(
651
- history_values=history_values_tensor.to(device),
652
- future_values=future_values_tensor.to(device),
653
- start=[start], # List with single np.datetime64 element
654
- frequency=[freq_object], # List with single Frequency enum element
655
  )
656
 
657
- # Run inference with autocast
658
- with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
659
- model_output = model(container)
660
-
661
- # Post-process predictions
662
- preds_full = model_output["result"].to(torch.float32)
663
- if hasattr(model, "scaler") and "scale_statistics" in model_output:
664
- preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
665
-
666
- preds_full_cpu = preds_full.detach().cpu()
667
- preds_np = preds_full_cpu.numpy()
668
- history_np = history_values_tensor.squeeze(0).numpy()
669
- future_np = future_values_tensor.squeeze(0).numpy()
670
  preds_squeezed = preds_np.squeeze(0)
671
 
672
- # Get model quantiles (if available)
673
- model_quantiles = None
674
- if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
675
- model_quantiles = model.quantiles
676
-
677
  try:
678
  forecast_plot = plot_multivariate_timeseries(
679
  history_values=history_np,
 
459
  return metrics
460
 
461
  @spaces.GPU
462
+ def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
463
+ """
464
+ GPU-only inference function. Returns only simple serializable types.
465
+ Returns: (history_np, future_np, preds_np, model_quantiles)
466
+ """
467
+ global model, device
468
+ if model is None:
469
+ print("--- Loading TempoPFN model for the first time ---")
470
+ device = torch.device("cuda:0")
471
+ print(f"Downloading model...")
472
+ model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
473
+ print(f"Loading model from {model_path} to {device}...")
474
+ model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
475
+ print("--- Model loaded successfully ---")
476
+
477
+ # Prepare container
478
+ container = BatchTimeSeriesContainer(
479
+ history_values=history_values_tensor.to(device),
480
+ future_values=future_values_tensor.to(device),
481
+ start=[start],
482
+ frequency=[freq_object],
483
+ )
484
+
485
+ # Run inference with autocast
486
+ with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
487
+ model_output = model(container)
488
+
489
+ # Post-process predictions
490
+ preds_full = model_output["result"].to(torch.float32)
491
+ if hasattr(model, "scaler") and "scale_statistics" in model_output:
492
+ preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
493
+
494
+ # Convert to numpy arrays before returning (detach from GPU)
495
+ preds_np = preds_full.detach().cpu().numpy()
496
+ history_np = history_values_tensor.cpu().numpy()
497
+ future_np = future_values_tensor.cpu().numpy()
498
+
499
+ # Get model quantiles if available
500
+ model_quantiles = None
501
+ if hasattr(model, "loss_type") and model.loss_type == "quantile":
502
+ model_quantiles = model.quantiles
503
+
504
+ return history_np, future_np, preds_np, model_quantiles
505
+
506
  def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
507
  """
508
  Runs the TempoPFN forecast.
 
675
  history_values_tensor = values_tensor[:, :-future_length, :]
676
  future_values_tensor = values_tensor[:, -future_length:, :]
677
 
678
+ # Ensure start is np.datetime64
679
  if not isinstance(start, np.datetime64):
680
  start = np.datetime64(start)
681
 
682
+ # Call GPU inference function (returns only simple types)
683
+ history_np, future_np, preds_np, model_quantiles = run_gpu_inference(
684
+ history_values_tensor, future_values_tensor, start, freq_object
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  )
686
 
687
+ # Squeeze arrays for plotting
688
+ history_np = history_np.squeeze(0)
689
+ future_np = future_np.squeeze(0)
 
 
 
 
 
 
 
 
 
 
690
  preds_squeezed = preds_np.squeeze(0)
691
 
 
 
 
 
 
692
  try:
693
  forecast_plot = plot_multivariate_timeseries(
694
  history_values=history_np,