FlareSense-v2
This model predicts whether a 15-minute spectrogram contains a solar radio burst; see the paper:
Usage
pip install torch torchvision huggingface_hub ecallisto_ng
"""
FlareSense v2 - Simple Usage Example
This script demonstrates how to use the FlareSense model to predict solar radio bursts
on e-Callisto data. The model is automatically downloaded from HuggingFace and cached locally.
Usage:
python example_usage.py
The model will predict on a 15-minute window of data from a specific instrument.
"""
import torch
import numpy as np
from datetime import datetime
from huggingface_hub import hf_hub_download
from ecallisto_ng.data_download.downloader import get_ecallisto_data
from ecallisto_ng.data_processing.utils import subtract_constant_background
from ecallisto_ng.plotting.plotting import plot_spectrogram
from plotly.io import show
import torch.nn as nn
from torchvision import models
import os
# ============================================================================
# Model Definition
# ============================================================================
class GrayScaleResNet(nn.Module):
    """
    ResNet classifier that also accepts single-channel (grayscale) images.

    A torchvision ResNet is created with its default pretrained weights and
    its final fully connected layer is replaced by a new ``nn.Linear`` with
    ``n_classes`` outputs.

    Args:
        n_classes: Number of outputs of the replaced final layer
        resnet_type: One of "resnet18", "resnet34" or "resnet50"
    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names
    """

    def __init__(self, n_classes=1, resnet_type="resnet34"):
        super().__init__()
        # Supported architectures mapped to the name of their torchvision
        # weights enum. Unsupported names are rejected before torchvision
        # is touched at all.
        weights_enum_names = {
            "resnet34": "ResNet34_Weights",
            "resnet18": "ResNet18_Weights",
            "resnet50": "ResNet50_Weights",
        }
        if resnet_type not in weights_enum_names:
            raise ValueError(f"Unsupported resnet_type: {resnet_type}")
        build_backbone = getattr(models, resnet_type)
        default_weights = getattr(models, weights_enum_names[resnet_type]).DEFAULT
        self.resnet = build_backbone(weights=default_weights)
        # Swap the classification head for one sized to our number of classes
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, n_classes)

    def forward(self, x):
        """Run the backbone, expanding 1-channel input to 3 channels first."""
        is_single_channel = x.size(1) == 1
        if is_single_channel:
            # expand() repeats the channel as a view; batch/height/width
            # are left untouched (-1 keeps the existing size).
            x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)
# ============================================================================
# Data Processing Functions
# ============================================================================
def remove_background(df_spectrogram) -> torch.Tensor:
    """
    Subtract the constant background and return the spectrogram as a tensor.

    The subtraction itself is delegated to ecallisto_ng's
    ``subtract_constant_background`` with ``n=300`` (presumably the median of
    the first 300 timepoints is used as the background; confirm against the
    ecallisto_ng documentation).

    Args:
        df_spectrogram: Pandas DataFrame with time as index and frequency as columns
    Returns:
        Float torch tensor with background removed (frequency x time)
    """
    background_free = subtract_constant_background(df_spectrogram, n=300)
    # The DataFrame is laid out (time, frequency); the model pipeline works
    # on (frequency, time), hence the transpose.
    freq_by_time = background_free.values.T
    return torch.from_numpy(freq_by_time).float()
def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subtract each row's median from that row.

    In the preprocessing pipeline this runs AFTER the constant background
    subtraction.

    Args:
        spectrogram_tensor: Tensor of shape (frequency, time)
    Returns:
        Tensor of the same shape with the per-row median removed
    """
    # keepdim=True leaves a trailing axis of size 1 so the medians broadcast
    # across the columns of their own row.
    row_medians = spectrogram_tensor.median(dim=1, keepdim=True).values
    return spectrogram_tensor - row_medians
def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
    """
    Rescale a 2-D spectrogram to ``target_size`` with bilinear interpolation.

    Args:
        spectrogram_tensor: Input tensor (frequency, time)
        target_size: Target size (height, width)
    Returns:
        Resized tensor with a leading channel axis (1, height, width)
    """
    # interpolate() needs (batch, channel, H, W), so prepend two singleton axes
    batched = spectrogram_tensor[None, None]
    resized = torch.nn.functional.interpolate(
        batched,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    # Drop only the batch axis; the channel axis stays for the model
    return resized[0]
def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
    """
    Apply Min-Max scaling to a tensor.

    The global minimum of the tensor is mapped to the lower bound of
    ``feature_range`` and the global maximum to the upper bound.

    Args:
        tensor: Input tensor
        feature_range: Desired range as (min, max) (default: (0, 1))
    Returns:
        Scaled tensor. For a constant tensor (max == min) every element is
        set to the lower bound of ``feature_range``.
    """
    min_val, max_val = feature_range
    tensor_min = tensor.min()
    value_span = tensor.max() - tensor_min
    if value_span == 0:
        # Constant input: scaling would divide by zero. Return the lower
        # bound so the result stays inside feature_range (for the default
        # range this is all zeros, as before).
        return torch.zeros_like(tensor) + min_val
    unit_scaled = (tensor - tensor_min) / value_span
    return unit_scaled * (max_val - min_val) + min_val
def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
    """
    Run the full preprocessing chain on a spectrogram DataFrame.

    The steps mirror the training pipeline: constant background removal,
    row-wise median removal, resize to (128, 512), min-max scaling to [0, 1].

    Args:
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
    Returns:
        Preprocessed tensor ready for model input (1, 128, 512)
    """
    # Steps 1 + 2: constant background, then per-frequency median; the result
    # is a (frequency x time) tensor.
    freq_by_time = remove_background_median(remove_background(df_spectrogram))
    # Step 3: plain resize (normal_resize, since custom_resize is False in config)
    resized = resize_spectrogram(freq_by_time, target_size=(128, 512))
    # Step 4: scale into [0, 1]
    return min_max_scale(resized, feature_range=(0, 1))
# ============================================================================
# Model Loading and Prediction
# ============================================================================
def load_flaresense_model(device="cpu"):
    """
    Fetch the FlareSense checkpoint from HuggingFace Hub and build the model.

    ``hf_hub_download`` handles the download and local caching; the path it
    returns is then loaded into a freshly constructed ``GrayScaleResNet``.

    Args:
        device: Device to load model on ('cpu' or 'cuda')
    Returns:
        Loaded model in evaluation mode, moved to ``device``
    """
    # Model configuration (from best_v2.yml)
    REPO_ID = "i4ds/flaresense-v2"
    MODEL_FILENAME = "model.ckpt"
    RESNET_TYPE = "resnet34"

    print(f"Downloading model from {REPO_ID}...")
    checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    print(f"Model cached at: {checkpoint_path}")

    model = GrayScaleResNet(n_classes=1, resnet_type=RESNET_TYPE)

    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
    # The weights may be nested under "state_dict" or be the checkpoint itself
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    # torch.compile prefixes parameter names with '_orig_mod.'; strip it so
    # the names match the uncompiled model.
    cleaned_state_dict = {
        name.replace("_orig_mod.", ""): weights
        for name, weights in state_dict.items()
    }
    model.load_state_dict(cleaned_state_dict)

    model.eval()
    model.to(device)
    print(f"Model loaded successfully on {device}")
    return model
def sigmoid(x, temperature=0.4974):
    """
    Convert logit to probability using temperature-scaled sigmoid.

    Computes ``1 / (1 + exp(-x / temperature))`` in a form that never
    exponentiates a positive number, so extreme logits saturate to 0 or 1
    without an overflow warning.

    Args:
        x: Logit value (scalar or NumPy array)
        temperature: Temperature parameter for calibration (must be non-zero)
    Returns:
        Probability [0, 1]; a NumPy scalar for scalar input, an array of the
        same shape for array input
    Raises:
        ZeroDivisionError: If ``temperature`` is zero
    """
    if np.any(np.asarray(temperature) == 0):
        raise ZeroDivisionError("temperature must be non-zero")
    z = np.asarray(x) / temperature
    # exp(-|z|) is always in (0, 1], so it cannot overflow
    decay = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + e^-z);  z < 0: e^z / (1 + e^z) -- the same function
    probability = np.where(z >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # Indexing with () turns a 0-d array back into a NumPy scalar and leaves
    # higher-dimensional arrays unchanged.
    return probability[()]
def predict_burst(model, df_spectrogram, device="cpu"):
    """
    Score a single spectrogram DataFrame for a solar radio burst.

    Args:
        model: Loaded FlareSense model
        df_spectrogram: Pandas DataFrame (time x frequency) from get_ecallisto_data
        device: Device to run prediction on
    Returns:
        tuple: (logit, probability)
            - logit: Raw model output
            - probability: Calibrated probability [0, 1]
    """
    features = preprocess_spectrogram(df_spectrogram)
    # The model expects a leading batch axis
    batch = features[None].to(device)
    # Inference only: no gradient tracking needed
    with torch.no_grad():
        raw_output = model(batch)
    logit = raw_output.squeeze().item()
    return logit, sigmoid(logit)
# ============================================================================
# Main Example
# ============================================================================
def main():
    """
    Main example demonstrating how to use FlareSense for prediction.

    Loads the model, downloads a 15-minute e-Callisto window for one
    instrument, prints the prediction, and shows a plot of the
    background-subtracted spectrogram. If no data is returned for the
    requested instrument and time window, a message is printed and the
    function returns without predicting.
    """
    # Configuration
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Example: Predict on data from May 7, 2021
    # Create a 15-minute window centered around 03:40:30
    # This gives us exactly 15 minutes: 03:33:00 to 03:48:00
    start_time = datetime(2021, 5, 7, 3, 33, 0)
    end_time = datetime(2021, 5, 7, 3, 48, 0)
    instrument = "Australia-ASSA_01"
    print(f"Example prediction on instrument: {instrument}")

    # Load model (downloaded and cached automatically)
    model = load_flaresense_model(device=device)

    # Fetch data from e-Callisto
    print("Fetching data from e-Callisto...")
    df_dict = get_ecallisto_data(start_time, end_time, instrument)

    # Guard against a missing or empty download so the user gets a clear
    # message instead of a KeyError / IndexError further down.
    if instrument not in df_dict or df_dict[instrument].empty:
        print(f"No data available for {instrument} between {start_time} and {end_time}.")
        return
    df_spectrogram = df_dict[instrument]

    print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
    print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
    print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")

    # Predict (pass the DataFrame directly)
    print("Running prediction...")
    logit, probability = predict_burst(model, df_spectrogram, device=device)

    # Display results
    print("\n" + "="*60)
    print("PREDICTION RESULTS")
    print("="*60)
    print(f"Logit: {logit:.4f}")
    print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
    burst_detected = probability > 0.5
    print(f"Prediction: {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
    print("="*60)

    # Plot the background-subtracted spectrogram (shown, not saved)
    print("\nGenerating spectrogram plot...")
    df_processed = subtract_constant_background(df_spectrogram)
    fig = plot_spectrogram(df_processed)
    show(fig)
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
- Downloads last month
- 3