Image Classification

FlareSense-v2

This model predicts whether a 15-minute spectrogram contains a solar radio burst. See the paper:

Usage

pip install torch torchvision numpy plotly huggingface_hub ecallisto_ng
"""
FlareSense v2 - Simple Usage Example

This script demonstrates how to use the FlareSense model to predict solar radio bursts
on e-Callisto data. The model is automatically downloaded from HuggingFace and cached locally.

Usage:
    python example_usage.py

The model will predict on a 15-minute window of data from a specific instrument.
"""

import torch
import numpy as np
from datetime import datetime
from huggingface_hub import hf_hub_download
from ecallisto_ng.data_download.downloader import get_ecallisto_data
from ecallisto_ng.data_processing.utils import subtract_constant_background
from ecallisto_ng.plotting.plotting import plot_spectrogram
from plotly.io import show
import torch.nn as nn
from torchvision import models
import os

# ============================================================================
# Model Definition
# ============================================================================

class GrayScaleResNet(nn.Module):
    """ResNet classifier adapted for single-channel (grayscale) input.

    A torchvision ResNet backbone whose final fully connected layer is replaced
    by ``nn.Linear(num_features, n_classes)``. In ``forward``, single-channel
    inputs are repeated across three channels so the unmodified 3-channel stem
    of the backbone can be used.

    Args:
        n_classes: Number of output logits produced by the new head.
        resnet_type: Backbone name: ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone with torchvision's ``DEFAULT`` weights,
            which may require a download. Pass False when every parameter will
            be overwritten anyway (e.g. by ``load_state_dict``) to skip it.

    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names.
    """

    def __init__(self, n_classes=1, resnet_type="resnet34", pretrained=True):
        super().__init__()

        # Supported backbones: name -> (torchvision builder, default weights).
        backbones = {
            "resnet18": (models.resnet18, models.ResNet18_Weights.DEFAULT),
            "resnet34": (models.resnet34, models.ResNet34_Weights.DEFAULT),
            "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT),
        }
        if resnet_type not in backbones:
            raise ValueError(f"Unsupported resnet_type: {resnet_type}")

        builder, default_weights = backbones[resnet_type]
        # The builder is called without num_classes so the pretrained head
        # shape stays valid; the head is swapped out below instead.
        self.resnet = builder(weights=default_weights if pretrained else None)

        # Replace the final fully connected layer for our number of classes.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, n_classes)

    def forward(self, x):
        """Return backbone logits for a 4-D input batch ``x`` (batch, channels, H, W)."""
        # The stem expects 3 channels; expand() repeats a single channel as a
        # view, without copying memory.
        if x.size(1) == 1:
            x = x.expand(-1, 3, -1, -1)
        return self.resnet(x)


# ============================================================================
# Data Processing Functions
# ============================================================================

def remove_background(df_spectrogram) -> torch.Tensor:
    """
    Subtract a constant background from a spectrogram DataFrame and return it
    as a float32 tensor laid out as (frequency, time).

    The subtraction itself is delegated to ecallisto_ng's
    ``subtract_constant_background`` with ``n=300``. See that function for
    exactly how the background level is estimated.

    Args:
        df_spectrogram: Pandas DataFrame indexed by time, one column per frequency.

    Returns:
        Float tensor of shape (n_frequencies, n_times).
    """
    cleaned = subtract_constant_background(df_spectrogram, n=300)

    # DataFrame rows are time steps; transpose so rows become frequencies.
    freq_by_time = cleaned.values.T

    return torch.from_numpy(freq_by_time).float()


def remove_background_median(spectrogram_tensor: torch.Tensor) -> torch.Tensor:
    """
    Subtract each frequency row's median from that row.

    Intended to run after the constant-background subtraction. For an even
    number of time steps, ``torch.median`` picks the lower of the two middle
    values rather than averaging them.

    Args:
        spectrogram_tensor: 2-D tensor laid out as (frequency, time).

    Returns:
        Tensor of the same shape with every row's median removed.
    """
    # keepdim=True yields a (frequency, 1) column that broadcasts over time.
    row_medians = spectrogram_tensor.median(dim=1, keepdim=True).values
    return spectrogram_tensor - row_medians


def resize_spectrogram(spectrogram_tensor: torch.Tensor, target_size=(128, 512)) -> torch.Tensor:
    """
    Bilinearly resample a 2-D spectrogram to ``target_size``.

    Args:
        spectrogram_tensor: 2-D tensor laid out as (frequency, time).
        target_size: Output (height, width).

    Returns:
        Tensor of shape (1, height, width). The single leading channel axis is
        kept so the result can be batched directly as model input.
    """
    # F.interpolate expects (batch, channel, H, W): prepend both axes at once.
    as_image = spectrogram_tensor[None, None]

    resampled = torch.nn.functional.interpolate(
        as_image,
        size=target_size,
        mode="bilinear",
        align_corners=False,
    )

    # Drop only the batch axis; the channel axis stays.
    return resampled[0]


def min_max_scale(tensor: torch.Tensor, feature_range=(0, 1)) -> torch.Tensor:
    """
    Linearly rescale ``tensor`` so its global minimum maps to ``feature_range[0]``
    and its global maximum maps to ``feature_range[1]``.

    Args:
        tensor: Input tensor of any shape; min and max are taken over all elements.
        feature_range: ``(low, high)`` target range. Defaults to ``(0, 1)``.

    Returns:
        Scaled tensor with the same shape as the input. The edge cases are:
        - Constant input (max == min): every element is set to ``low``.
        - Empty input: returned as an empty copy.
    """
    low, high = feature_range

    # min()/max() raise on an empty tensor, and there is nothing to scale.
    if tensor.numel() == 0:
        return tensor.clone()

    tensor_min = tensor.min()
    tensor_max = tensor.max()
    span = tensor_max - tensor_min

    # A constant tensor has no spread to normalise, and dividing would give 0/0.
    # Map it to the lower bound rather than to 0, so the result stays inside
    # the requested range even when low != 0.
    if span == 0:
        return torch.full_like(tensor, low)

    scaled = (tensor - tensor_min) / span
    return scaled * (high - low) + low


def preprocess_spectrogram(df_spectrogram) -> torch.Tensor:
    """
    Turn a raw spectrogram DataFrame into a model-ready tensor.

    The steps are intended to mirror the training pipeline, in this order:
      1. Subtract the constant background and convert to a (frequency, time) tensor.
      2. Subtract each frequency row's median.
      3. Resize bilinearly to (128, 512).
      4. Min-max scale to [0, 1].

    Args:
        df_spectrogram: Pandas DataFrame (time x frequency), e.g. as returned
            by ``get_ecallisto_data``.

    Returns:
        Tensor of shape (1, 128, 512).
    """
    constant_removed = remove_background(df_spectrogram)
    median_removed = remove_background_median(constant_removed)

    # Plain bilinear resize: the training config has custom_resize disabled.
    resized = resize_spectrogram(median_removed, target_size=(128, 512))

    return min_max_scale(resized, feature_range=(0, 1))


# ============================================================================
# Model Loading and Prediction
# ============================================================================

def load_flaresense_model(
    device="cpu",
    *,
    repo_id="i4ds/flaresense-v2",
    filename="model.ckpt",
    resnet_type="resnet34",
):
    """
    Download (or reuse the cached copy of) a FlareSense checkpoint from the
    HuggingFace Hub and return the model ready for inference.

    The defaults reproduce the published FlareSense-v2 configuration
    (from best_v2.yml).

    Args:
        device: Device to load the model on ('cpu' or 'cuda').
        repo_id: HuggingFace Hub repository holding the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        resnet_type: Backbone passed to ``GrayScaleResNet``. It must match
            the architecture the checkpoint was trained with.

    Returns:
        The loaded model, in evaluation mode and on ``device``.
    """
    print(f"Downloading model from {repo_id}...")
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    print(f"Model cached at: {checkpoint_path}")

    model = GrayScaleResNet(n_classes=1, resnet_type=resnet_type)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code on load.
    checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)

    # Training checkpoints may nest the weights under "state_dict";
    # otherwise the file is treated as a bare state dict.
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # torch.compile wraps the module and adds "_orig_mod." to parameter
    # names; strip it so the keys match the uncompiled model.
    state_dict = {
        key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
    }

    model.load_state_dict(state_dict)

    # Set to evaluation mode and move to device.
    model.eval()
    model.to(device)

    print(f"Model loaded successfully on {device}")
    return model


def sigmoid(x, temperature=0.4974):
    """
    Convert a logit to a probability using a temperature-scaled sigmoid,
    ``1 / (1 + exp(-x / temperature))``.

    The value is computed as ``exp(-logaddexp(0, -x / temperature))``. This is
    the same function, but it cannot overflow. The direct form takes ``exp`` of
    a large positive number for strongly negative logits and emits a
    RuntimeWarning.

    Args:
        x: Logit, either a scalar or array-like.
        temperature: Calibration temperature. Larger values push probabilities
            toward 0.5; smaller values push them toward 0 or 1.

    Returns:
        Probability in [0, 1]: a NumPy scalar for scalar input, an array otherwise.
    """
    scaled = np.asarray(x) / temperature
    # log(1 + exp(-scaled)) evaluated stably; negating and exponentiating gives the sigmoid.
    return np.exp(-np.logaddexp(0.0, -scaled))


def predict_burst(model, df_spectrogram, device="cpu"):
    """
    Run FlareSense on one spectrogram and return its score.

    Args:
        model: FlareSense model, e.g. as returned by ``load_flaresense_model``.
        df_spectrogram: Pandas DataFrame (time x frequency), e.g. from ``get_ecallisto_data``.
        device: Device the model lives on; the input is moved there.

    Returns:
        tuple: ``(logit, probability)``, where:
            - logit: Raw model output as a Python float.
            - probability: Its temperature-calibrated sigmoid, in [0, 1].
    """
    # (1, 128, 512) -> (1, 1, 128, 512): a batch holding a single image.
    batch = preprocess_spectrogram(df_spectrogram)[None].to(device)

    with torch.no_grad():
        raw_output = model(batch)

    # One sample with one class collapses to a single Python float.
    logit = raw_output.squeeze().item()
    return logit, sigmoid(logit)


# ============================================================================
# Main Example
# ============================================================================

def main():
    """Run an end-to-end FlareSense prediction on one 15-minute e-Callisto window.

    The example performs these steps:
      1. Load the model.
      2. Fetch the spectrogram for a fixed instrument and time window.
      3. Print the prediction.
      4. Show a plot of the spectrogram with the constant background subtracted.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Example window on 2021-05-07: 03:33:00 to 03:48:00, exactly 15 minutes
    # centred on 03:40:30.
    start_time = datetime(2021, 5, 7, 3, 33, 0)
    end_time = datetime(2021, 5, 7, 3, 48, 0)
    instrument = "Australia-ASSA_01"

    print(f"Example prediction on instrument: {instrument}")

    # Load model (downloaded and cached automatically).
    model = load_flaresense_model(device=device)

    print("Fetching data from e-Callisto...")
    df_dict = get_ecallisto_data(start_time, end_time, instrument)

    # Stop with a clear message instead of a bare KeyError/IndexError when the
    # instrument has no data for this window. Assumes df_dict is dict-like,
    # keyed by instrument name, as the indexing below implies.
    df_spectrogram = df_dict[instrument] if instrument in df_dict else None
    if df_spectrogram is None or df_spectrogram.empty:
        print(f"No data available for {instrument} between {start_time} and {end_time}.")
        return

    print(f"Data shape: {df_spectrogram.shape} (time x frequency)")
    print(f"Time range: {df_spectrogram.index[0]} to {df_spectrogram.index[-1]}")
    print(f"Frequency range: {df_spectrogram.columns[0]:.2f} - {df_spectrogram.columns[-1]:.2f} MHz\n")

    # Predict (pass the DataFrame directly).
    print("Running prediction...")
    logit, probability = predict_burst(model, df_spectrogram, device=device)

    # Display results.
    print("\n" + "=" * 60)
    print("PREDICTION RESULTS")
    print("=" * 60)
    print(f"Logit:       {logit:.4f}")
    print(f"Probability: {probability:.4f} ({probability*100:.2f}%)")
    burst_detected = probability > 0.5
    print(f"Prediction:  {'BURST DETECTED ☀️' if burst_detected else 'No burst'}")
    print("=" * 60)

    # Show (not save) the spectrogram with the constant background removed.
    # NOTE(review): this uses subtract_constant_background's default n, not the
    # n=300 used for the model input, so the plot may differ slightly from what
    # the model saw.
    print("\nGenerating spectrogram plot...")
    df_processed = subtract_constant_background(df_spectrogram)

    fig = plot_spectrogram(df_processed)
    show(fig)


if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    main()
Downloads last month
3
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Dataset used to train i4ds/flaresense-v2