"""
SafeGem Configuration

Configuration class for SafeGem models with safety classification capabilities.
"""

from typing import List, Optional

from transformers import Gemma3Config

# HoliSafe 20-category safety taxonomy, used when the caller supplies no
# categories. Kept as an immutable tuple so every config gets its own list copy.
_DEFAULT_SAFETY_CATEGORIES = (
    "safe",
    "gender",
    "race",
    "religion",
    "harassment",
    "disability_discrimination",
    "drug_crime",
    "property_crime",
    "facial_data",
    "identity_data",
    "physical_self_injury",
    "suicide",
    "animal_abuse",
    "obscene_gestures",
    "physical_altercation",
    "terrorism",
    "weapon_related_violence",
    "sexual_content",
    "financial_advice",
    "medical_advice",
)


class SafeGemConfig(Gemma3Config):
    """
    Configuration for SafeGem model.

    This configuration class extends Gemma3Config with safety-specific
    parameters. All other keyword arguments are forwarded unchanged to
    ``Gemma3Config.__init__``.

    Args:
        safety_categories: Names of the safety classes. When omitted (or
            empty), the HoliSafe 20-category taxonomy is used.
        safety_head_hidden_scale: Stored as-is on the config. Presumably a
            multiplier for the safety head's hidden width -- confirm against
            the model code.
        safety_loss_lambda: Stored as-is on the config. Presumably the weight
            of the safety loss term -- confirm against the training code.
        safety_num_hidden_layers: Stored as-is on the config. Presumably the
            number of hidden layers in the safety head -- confirm against the
            model code.
        num_safety_categories: Number of safety classes. When omitted (or
            zero), it is derived from ``len(safety_categories)`` so that the
            count always matches the category list; an explicit non-zero
            value is kept as given.
        **kwargs: Forwarded to ``Gemma3Config``.
    """

    model_type = "safegem"

    def __init__(
        self,
        # Safety specific parameters
        safety_categories: Optional[List[str]] = None,
        safety_head_hidden_scale: float = 1.0,
        safety_loss_lambda: float = 1.0,
        safety_num_hidden_layers: int = 1,
        num_safety_categories: Optional[int] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Fall back to the default taxonomy when nothing (or an empty list)
        # is supplied; copy the tuple so instances never share a list.
        self.safety_categories = safety_categories or list(
            _DEFAULT_SAFETY_CATEGORIES
        )

        self.safety_head_hidden_scale = safety_head_hidden_scale
        self.safety_loss_lambda = safety_loss_lambda
        self.safety_num_hidden_layers = safety_num_hidden_layers

        # The default used to be a hard-coded 20, which made this fallback
        # unreachable and left the count out of sync with a custom category
        # list. With a None default the count follows the list unless the
        # caller explicitly overrides it.
        self.num_safety_categories = num_safety_categories or len(
            self.safety_categories
        )