resnet18_anomaly / src /gradcam.py
Swarno-Coder's picture
Upload folder using huggingface_hub
1a4a0d3 verified
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from src.model import get_model
class GradCAM:
def __init__(self, model_path, target_layer='layer4'):
self.model = get_model(pretrained=False)
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
self.model.eval()
self.target_layer = target_layer
self.gradients = None
self.activations = None
self.hooks = []
self._register_hooks()
def _register_hooks(self):
def backward_hook(module, grad_input, grad_output):
self.gradients = grad_output[0]
def forward_hook(module, input, output):
self.activations = output
for name, module in self.model.named_modules():
if target_layer in name:
self.hooks.append(module.register_forward_hook(forward_hook))
self.hooks.append(module.register_backward_hook(backward_hook))
def generate(self, input_tensor, class_idx=None):
self.model.zero_grad()
output = self.model(input_tensor)
if class_idx is None:
class_idx = output.argmax().item()
score = output[0, class_idx]
score.backward()
gradients = self.gradients[0]
activations = self.activations[0]
weights = torch.mean(gradients, dim=(1, 2), keepdim=True)
cam = torch.sum(weights * activations, dim=1, keepdim=True)
cam = F.relu(cam)
cam = cam.squeeze().detach().cpu().numpy()
cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
# Remove hooks
for hook in self.hooks:
hook.remove()
return cam
def __del__(self):
for hook in self.hooks:
hook.remove()
# Usage example:
# gradcam = GradCAM('../models/resnet18_anomaly.pth')
# input_tensor = torch.randn(1, 3, 224, 224) # From preprocess
# heatmap = gradcam.generate(input_tensor)