altpuppet
Claude committed
Commit c87b856 · 1 parent: 426ae1a
Separate GPU inference from post-processing to fix ZeroGPU 404 error
- Created dedicated run_gpu_inference() function with @spaces.GPU decorator
- GPU function returns only simple serializable types (numpy arrays, lists)
- Main forecast_time_series() handles data loading and post-processing
- Prevents complex objects (Plotly figures, DataFrames, dicts) from crossing the GPU boundary
- This matches the pattern from the original working implementation (a minimal sketch of the pattern follows below)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
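
For readers unfamiliar with ZeroGPU: a `@spaces.GPU`-decorated function runs in a separate GPU worker process, so its arguments and return values are serialized across a process boundary. Below is a minimal sketch of the pattern this commit applies, not the app's actual code; `heavy_inference`, `make_forecast`, and the toy forward pass are hypothetical, and only the `@spaces.GPU` decorator is the real Hugging Face `spaces` API.

```python
import numpy as np
import spaces
import torch


@spaces.GPU
def heavy_inference(history: np.ndarray) -> np.ndarray:
    """Runs in the ZeroGPU worker; takes and returns only numpy arrays."""
    x = torch.from_numpy(history).to("cuda")
    with torch.no_grad():
        preds = x * 2.0  # stand-in for the real model forward pass
    return preds.cpu().numpy()  # only a simple serializable type crosses back


def make_forecast(history: np.ndarray) -> dict:
    """Runs in the main process; post-processing and complex objects stay here."""
    preds = heavy_inference(history)
    return {"forecast": preds.tolist()}  # build Plotly figures etc. on this side
```

Keeping the GPU boundary to arrays and lists avoids the serialization failures that the commit title attributes to complex objects crossing it (surfacing here as a ZeroGPU 404).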
app.py
CHANGED
@@ -459,6 +459,50 @@ def create_gradio_app():
         return metrics
 
     @spaces.GPU
+    def run_gpu_inference(history_values_tensor, future_values_tensor, start, freq_object):
+        """
+        GPU-only inference function. Returns only simple serializable types.
+        Returns: (history_np, future_np, preds_np, model_quantiles)
+        """
+        global model, device
+        if model is None:
+            print("--- Loading TempoPFN model for the first time ---")
+            device = torch.device("cuda:0")
+            print(f"Downloading model...")
+            model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
+            print(f"Loading model from {model_path} to {device}...")
+            model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
+            print("--- Model loaded successfully ---")
+
+        # Prepare container
+        container = BatchTimeSeriesContainer(
+            history_values=history_values_tensor.to(device),
+            future_values=future_values_tensor.to(device),
+            start=[start],
+            frequency=[freq_object],
+        )
+
+        # Run inference with autocast
+        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+            model_output = model(container)
+
+        # Post-process predictions
+        preds_full = model_output["result"].to(torch.float32)
+        if hasattr(model, "scaler") and "scale_statistics" in model_output:
+            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
+
+        # Convert to numpy arrays before returning (detach from GPU)
+        preds_np = preds_full.detach().cpu().numpy()
+        history_np = history_values_tensor.cpu().numpy()
+        future_np = future_values_tensor.cpu().numpy()
+
+        # Get model quantiles if available
+        model_quantiles = None
+        if hasattr(model, "loss_type") and model.loss_type == "quantile":
+            model_quantiles = model.quantiles
+
+        return history_np, future_np, preds_np, model_quantiles
+
     def forecast_time_series(data_source, stock_ticker, uploaded_file, forecast_horizon, history_length, seed, synth_generator="Sine Waves", synth_complexity=5):
         """
         Runs the TempoPFN forecast.

@@ -631,49 +675,20 @@ def create_gradio_app():
         history_values_tensor = values_tensor[:, :-future_length, :]
         future_values_tensor = values_tensor[:, -future_length:, :]
 
-        # Ensure start is np.datetime64
+        # Ensure start is np.datetime64
         if not isinstance(start, np.datetime64):
             start = np.datetime64(start)
 
-        #
-
-
-        print("--- Loading TempoPFN model for the first time ---")
-        device = torch.device("cuda:0")
-        print(f"Downloading model...")
-        model_path = hf_hub_download(repo_id="AutoML-org/TempoPFN", filename="models/checkpoint_38M.pth")
-        print(f"Loading model from {model_path} to {device}...")
-        model = load_model(config_path="configs/example.yaml", model_path=model_path, device=device)
-        print("--- Model loaded successfully ---")
-
-        # Prepare container and run inference
-        container = BatchTimeSeriesContainer(
-            history_values=history_values_tensor.to(device),
-            future_values=future_values_tensor.to(device),
-            start=[start],  # List with single np.datetime64 element
-            frequency=[freq_object],  # List with single Frequency enum element
+        # Call GPU inference function (returns only simple types)
+        history_np, future_np, preds_np, model_quantiles = run_gpu_inference(
+            history_values_tensor, future_values_tensor, start, freq_object
         )
 
-        #
-
-
-
-        # Post-process predictions
-        preds_full = model_output["result"].to(torch.float32)
-        if hasattr(model, "scaler") and "scale_statistics" in model_output:
-            preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])
-
-        preds_full_cpu = preds_full.detach().cpu()
-        preds_np = preds_full_cpu.numpy()
-        history_np = history_values_tensor.squeeze(0).numpy()
-        future_np = future_values_tensor.squeeze(0).numpy()
+        # Squeeze arrays for plotting
+        history_np = history_np.squeeze(0)
+        future_np = future_np.squeeze(0)
         preds_squeezed = preds_np.squeeze(0)
 
-        # Get model quantiles (if available)
-        model_quantiles = None
-        if model is not None and hasattr(model, "loss_type") and model.loss_type == "quantile":
-            model_quantiles = model.quantiles
-
         try:
             forecast_plot = plot_multivariate_timeseries(
                 history_values=history_np,