""" SafeGem: Vision-Language Model with Visual Guard Module This implementation extends Gemma3ForConditionalGeneration with image safety classification capabilities using a pooling-based approach for safety feature extraction. """ import torch import torch.nn as nn from typing import Optional, Tuple, List, Union from dataclasses import dataclass from transformers.modeling_outputs import CausalLMOutputWithPast from transformers import Gemma3ForConditionalGeneration from transformers.utils import logging from .configuration_safegem import SafeGemConfig logger = logging.get_logger(__name__) local_rank = None def rank0_print(*args): if local_rank == 0 or local_rank == '0' or local_rank is None: print(*args) @dataclass class SafeGemOutput(CausalLMOutputWithPast): """ Output class for SafeGem with safety classification results. """ loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None img_safety_logits: Optional[torch.FloatTensor] = None img_safety_probs: Optional[torch.FloatTensor] = None class SafetyMLP(nn.Module): """ Multi-layer perceptron for safety classification (Visual Guard Module). """ def __init__( self, input_size: int, hidden_size: int, output_size: int, num_hidden_layers: int = 1 ): super().__init__() layers = [] # First layer layers.append(nn.Linear(input_size, hidden_size)) layers.append(nn.GELU()) layers.append(nn.Dropout(0.1)) # Additional hidden layers for _ in range(num_hidden_layers - 1): layers.append(nn.Linear(hidden_size, hidden_size)) layers.append(nn.GELU()) layers.append(nn.Dropout(0.1)) # Output layer layers.append(nn.Linear(hidden_size, output_size)) self.mlp = nn.Sequential(*layers) # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: torch.nn.init.constant_(module.bias, 0) def forward(self, x): return self.mlp(x) class SafeGemForConditionalGeneration(Gemma3ForConditionalGeneration): """ SafeGem model with Visual Guard Module for image safety classification. This model extends Gemma3ForConditionalGeneration with: 1. Visual Guard Module (VGM) - a safety classification head 2. Pooling-based safety feature extraction from image tokens 3. 
class SafeGemForConditionalGeneration(Gemma3ForConditionalGeneration):
    """
    SafeGem model with Visual Guard Module for image safety classification.

    This model extends Gemma3ForConditionalGeneration with:
    1. Visual Guard Module (VGM) - a safety classification head
    2. Pooling-based safety feature extraction from image tokens
    3. Simultaneous text generation and safety classification

    Key design principles:
    - Minimal modification to the base Gemma3 forward pass
    - Extract safety features from visual tokens using mean pooling
    - Non-invasive architecture that maintains full base model capabilities
    """

    config_class = SafeGemConfig

    def __init__(self, config: SafeGemConfig):
        super().__init__(config)

        # Add the safety head (Visual Guard Module) only if a safety configuration is present
        num_safety_categories = getattr(config, 'num_safety_categories', None)
        if num_safety_categories and num_safety_categories > 0:
            hidden_size = config.text_config.hidden_size
            safety_head_hidden_scale = getattr(config, 'safety_head_hidden_scale', 1.0)
            safety_hidden_size = int(hidden_size * safety_head_hidden_scale)
            safety_num_hidden_layers = getattr(config, 'safety_num_hidden_layers', 1)

            rank0_print(
                f"🔧 [INIT] Initializing Visual Guard Module: "
                f"{hidden_size} -> {safety_hidden_size} -> {num_safety_categories}"
            )

            self.img_safety_head = SafetyMLP(
                input_size=hidden_size,
                hidden_size=safety_hidden_size,
                output_size=num_safety_categories,
                num_hidden_layers=safety_num_hidden_layers,
            )
        else:
            rank0_print("🔧 [INIT] No safety configuration found, Visual Guard Module not initialized")
            self.img_safety_head = None

    def _extract_image_features_pooling(
        self,
        hidden_states: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """
        Extract image features by mean pooling over visual tokens.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]
            input_ids: [batch_size, seq_len]
            image_hidden_states: [batch_size, num_images, num_patches, hidden_size]
                or [batch_size, num_patches, hidden_size]

        Returns:
            image_features: [batch_size, hidden_size] or None
        """
        # Prefer image_hidden_states from the vision tower when available
        if image_hidden_states is not None:
            if image_hidden_states.dim() == 3:
                # [batch_size, num_patches, hidden_size] -> mean over patches
                pooled_features = image_hidden_states.mean(dim=1)  # [batch_size, hidden_size]
                return pooled_features
            elif image_hidden_states.dim() == 4:
                # [batch_size, num_images, num_patches, hidden_size]
                pooled_per_image = image_hidden_states.mean(dim=2)  # mean over patches
                pooled_features = pooled_per_image.mean(dim=1)      # mean over images
                rank0_print(f"🔧 [POOL] 4D pooled features shape: {pooled_features.shape}")
                return pooled_features
            else:
                rank0_print(f"🔧 [POOL] Unexpected image_hidden_states shape: {image_hidden_states.shape}")
                return None

        # Fallback: without image_hidden_states we cannot recover visual features here
        # (token-level detection from hidden_states/input_ids is not implemented)
        if input_ids is None:
            rank0_print("🔧 [POOL] No input_ids available for image token detection")
            return None

        rank0_print("🔧 [POOL] No image_hidden_states available, cannot extract image features")
        return None
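
    # Pooling walkthrough (illustrative sizes, chosen only for this sketch):
    #
    #     x = torch.randn(2, 3, 256, 2048)  # [batch, num_images, patches, hidden]
    #     x.mean(dim=2)                     # -> [2, 3, 2048]  pool over patches
    #     x.mean(dim=2).mean(dim=1)         # -> [2, 2048]     then over images
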
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        do_safety: bool = True,  # Default to True for training; can be overridden for generation
        safety_labels: Optional[torch.LongTensor] = None,  # accepted for API compatibility; unused in this forward
        **kwargs,
    ) -> Union[Tuple, SafeGemOutput]:
        """
        Forward pass with optional safety classification.

        Args:
            do_safety: Whether to perform safety classification (default: True)
            All other args: Same as Gemma3ForConditionalGeneration

        Returns:
            SafeGemOutput with optional safety classification results
        """
        # Force output_hidden_states when safety classification is needed,
        # but only during the initial forward pass, not during generation
        if do_safety and self.img_safety_head is not None and past_key_values is None:
            output_hidden_states = True
            return_dict = True

        # Standard Gemma3 forward pass - NO MODIFICATIONS
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            pixel_values=pixel_values,
            return_dict=True,
            **kwargs,
        )

        if outputs.logits is not None:
            # Replace any NaN/Inf entries in the logits with small random noise
            nan_count = torch.isnan(outputs.logits).sum()
            inf_count = torch.isinf(outputs.logits).sum()
            if nan_count > 0 or inf_count > 0:
                if past_key_values is None:
                    print(f"[CRITICAL] Found NaN or Inf in logits! NaN count: {nan_count}, Inf count: {inf_count}")
                replacement_values = torch.randn_like(outputs.logits) * 0.001
                outputs.logits = torch.where(
                    torch.isnan(outputs.logits) | torch.isinf(outputs.logits),
                    replacement_values,
                    outputs.logits,
                )

            # Squeeze a spurious singleton dimension if the logits come back 4D
            if outputs.logits.dim() == 4 and outputs.logits.shape[1] == 1:
                outputs.logits = outputs.logits.squeeze(1)

        # Initialize safety outputs
        img_safety_logits = None
        img_safety_probs = None

        # Safety classification runs only on the initial forward pass with images present
        is_generation = past_key_values is not None
        has_images = pixel_values is not None
        should_do_safety = (
            do_safety
            and self.img_safety_head is not None
            and (outputs.hidden_states is not None or outputs.image_hidden_states is not None)
            and has_images
            and not is_generation
        )

        if should_do_safety:
            # Extract pooled image features
            image_features = self._extract_image_features_pooling(
                hidden_states=outputs.hidden_states[-1] if outputs.hidden_states else None,
                attention_mask=attention_mask,
                input_ids=input_ids,
                image_hidden_states=outputs.image_hidden_states,
            )

            if image_features is not None:
                # Run through the Visual Guard Module
                img_safety_logits = self.img_safety_head(image_features)
                img_safety_probs = torch.softmax(img_safety_logits, dim=-1)
            else:
                rank0_print("🔧 [SafeGem] ❌ Image features extraction failed")

        # Return results
        if return_dict is False:
            output = (outputs.loss, outputs.logits, outputs.past_key_values,
                      outputs.hidden_states, outputs.attentions)
            if img_safety_logits is not None:
                output += (img_safety_logits, img_safety_probs)
            return output

        # During generation, return the standard Gemma3 output
        if is_generation:
            return outputs

        # During training/inference, return the custom output with safety info
        return SafeGemOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
            img_safety_logits=img_safety_logits,
            img_safety_probs=img_safety_probs,
        )
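

# Minimal smoke test for the Visual Guard Module MLP. The sizes are illustrative
# assumptions, not values from a real config; run with
# `python -m <package>.modeling_safegem` so the relative import above resolves.
if __name__ == "__main__":
    head = SafetyMLP(input_size=64, hidden_size=64, output_size=5, num_hidden_layers=2)
    head.eval()  # disable dropout for a deterministic check
    pooled = torch.randn(2, 64)  # stands in for mean-pooled image features
    logits = head(pooled)
    assert logits.shape == (2, 5), logits.shape
    print("SafetyMLP OK:", logits.shape)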