Spaces:
Runtime error
Runtime error
| import torch | |
| import clip | |
| from PIL import Image | |
| from io import BytesIO | |
| import os | |
| import requests | |
| # Model information dictionary containing model paths and language subcategories | |
| model_info = { | |
| "hindi": { | |
| "path": "models/clip_finetuned_hindienglish_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglish_real.pth", | |
| "subcategories": ["hindi", "english"] | |
| }, | |
| "hinengasm": { | |
| "path": "models/clip_finetuned_hindienglishassamese_real.pth", | |
| "url": "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishassamese_real.pth", | |
| "subcategories": ["hindi", "english", "assamese"] | |
| }, | |
| "hinengben": { | |
| "path": "models/clip_finetuned_hindienglishbengali_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishbengali_real.pth", | |
| "subcategories": ["hindi", "english", "bengali"] | |
| }, | |
| "hinengguj": { | |
| "path": "models/clip_finetuned_hindienglishgujarati_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishgujarati_real.pth", | |
| "subcategories": ["hindi", "english", "gujarati"] | |
| }, | |
| "hinengkan": { | |
| "path": "models/clip_finetuned_hindienglishkannada_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishkannada_real.pth", | |
| "subcategories": ["hindi", "english", "kannada"] | |
| }, | |
| "hinengmal": { | |
| "path": "models/clip_finetuned_hindienglishmalayalam_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmalayalam_real.pth", | |
| "subcategories": ["hindi", "english", "malayalam"] | |
| }, | |
| "hinengmar": { | |
| "path": "models/clip_finetuned_hindienglishmarathi_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmarathi_real.pth", | |
| "subcategories": ["hindi", "english", "marathi"] | |
| }, | |
| "hinengmei": { | |
| "path": "models/clip_finetuned_hindienglishmeitei_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishmeitei_real.pth", | |
| "subcategories": ["hindi", "english", "meitei"] | |
| }, | |
| "hinengodi": { | |
| "path": "models/clip_finetuned_hindienglishodia_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishodia_real.pth", | |
| "subcategories": ["hindi", "english", "odia"] | |
| }, | |
| "hinengpun": { | |
| "path": "models/clip_finetuned_hindienglishpunjabi_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishpunjabi_real.pth", | |
| "subcategories": ["hindi", "english", "punjabi"] | |
| }, | |
| "hinengtam": { | |
| "path": "models/clip_finetuned_hindienglishtamil_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtamil_real.pth", | |
| "subcategories": ["hindi", "english", "tamil"] | |
| }, | |
| "hinengtel": { | |
| "path": "models/clip_finetuned_hindienglishtelugu_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishtelugu_real.pth", | |
| "subcategories": ["hindi", "english", "telugu"] | |
| }, | |
| "hinengurd": { | |
| "path": "models/clip_finetuned_hindienglishurdu_real.pth", | |
| "url" : "https://github.com/anikde/STscriptdetect/releases/download/V1/clip_finetuned_hindienglishurdu_real.pth", | |
| "subcategories": ["hindi", "english", "urdu"] | |
| }, | |
| } | |
| # Set device to CUDA if available, otherwise use CPU | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| clip_model, preprocess = clip.load("ViT-B/32", device=device) | |
| class CLIPFineTuner(torch.nn.Module): | |
| """ | |
| Fine-tuning class for the CLIP model to adapt to specific tasks. | |
| Attributes: | |
| model (torch.nn.Module): The CLIP model to be fine-tuned. | |
| classifier (torch.nn.Linear): A linear classifier to map features to the desired number of classes. | |
| """ | |
| def __init__(self, model, num_classes): | |
| """ | |
| Initializes the fine-tuner with the CLIP model and classifier. | |
| Args: | |
| model (torch.nn.Module): The base CLIP model. | |
| num_classes (int): The number of target classes for classification. | |
| """ | |
| super(CLIPFineTuner, self).__init__() | |
| self.model = model | |
| self.classifier = torch.nn.Linear(model.visual.output_dim, num_classes) | |
| def forward(self, x): | |
| """ | |
| Forward pass for image classification. | |
| Args: | |
| x (torch.Tensor): Preprocessed input tensor for an image. | |
| Returns: | |
| torch.Tensor: Logits for each class. | |
| """ | |
| with torch.no_grad(): | |
| features = self.model.encode_image(x).float() # Extract image features from CLIP model | |
| return self.classifier(features) # Return class logits | |
| class CLIPidentifier: | |
| def __init__(self): | |
| pass | |
| # Ensure model file exists; download directly if not | |
| def ensure_model(self, model_name): | |
| model_path = model_info[model_name]["path"] | |
| url = model_info[model_name]["url"] | |
| root_model_dir = "IndicPhotoOCR/script_identification/" | |
| model_path = os.path.join(root_model_dir, model_path) | |
| if not os.path.exists(model_path): | |
| print(f"Model not found locally. Downloading {model_name} from {url}...") | |
| response = requests.get(url, stream=True) | |
| os.makedirs(f"{root_model_dir}/models", exist_ok=True) | |
| with open(f"{model_path}", "wb") as f: | |
| f.write(response.content) | |
| print(f"Downloaded model for {model_name}.") | |
| return model_path | |
| # Prediction function to verify and load the model | |
| def identify(self, image_path, model_name): | |
| """ | |
| Predicts the class of an input image using a fine-tuned CLIP model. | |
| Args: | |
| image_path (str): Path to the input image file. | |
| model_name (str): Name of the model (e.g., hineng, hinengpun, hinengguj) as specified in `model_info`. | |
| Returns: | |
| dict: Contains either `predicted_class` if successful or `error` if an exception occurs. | |
| Example usage: | |
| result = predict("sample_image.jpg", "hinengguj") | |
| print(result) # Output might be {'predicted_class': 'hindi'} | |
| """ | |
| try: | |
| # Validate model name and retrieve associated subcategories | |
| if model_name not in model_info: | |
| return {"error": "Invalid model name"} | |
| # Ensure the model file is downloaded and accessible | |
| model_path = self.ensure_model(model_name) | |
| subcategories = model_info[model_name]["subcategories"] | |
| num_classes = len(subcategories) | |
| # Load the fine-tuned model with the specified number of classes | |
| model_ft = CLIPFineTuner(clip_model, num_classes) | |
| model_ft.load_state_dict(torch.load(model_path, map_location=device)) | |
| model_ft = model_ft.to(device) | |
| model_ft.eval() | |
| # Load and preprocess the image | |
| image = Image.open(image_path).convert("RGB") | |
| input_tensor = preprocess(image).unsqueeze(0).to(device) | |
| # Run the model and get the prediction | |
| outputs = model_ft(input_tensor) | |
| _, predicted_idx = torch.max(outputs, 1) | |
| predicted_class = subcategories[predicted_idx.item()] | |
| return predicted_class | |
| except Exception as e: | |
| return {"error": str(e)} | |
| # if __name__ == "__main__": | |
| # import argparse | |
| # # Argument parser for command line usage | |
| # parser = argparse.ArgumentParser(description="Image classification using CLIP fine-tuned model") | |
| # parser.add_argument("image_path", type=str, help="Path to the input image") | |
| # parser.add_argument("model_name", type=str, choices=model_info.keys(), help="Name of the model (e.g., hineng, hinengpun, hinengguj)") | |
| # args = parser.parse_args() | |
| # # Execute prediction with command line inputs | |
| # result = predict(args.image_path, args.model_name) | |
| # print(result) |