#!/usr/bin/env python3
import threading

import spaces  # HF Spaces integration; kept even though not referenced directly
import torch
import gradio as gr
from PIL import Image
import pypdfium2 as pdfium
from transformers import (
    LightOnOCRForConditionalGeneration,
    LightOnOCRProcessor,
    TextIteratorStreamer,
)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Choose the best attention implementation based on the device
if device == "cuda":
    attn_implementation = "sdpa"
    dtype = torch.bfloat16
    print("Using sdpa for GPU")
else:
    attn_implementation = "eager"  # best for CPU
    dtype = torch.float32
    print("Using eager attention for CPU")
# Initialize the LightOnOCR model and processor
print(f"Loading model on {device} with {attn_implementation} attention...")
model = LightOnOCRForConditionalGeneration.from_pretrained(
    "lightonai/LightOnOCR-1B-1025",
    attn_implementation=attn_implementation,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

processor = LightOnOCRProcessor.from_pretrained(
    "lightonai/LightOnOCR-1B-1025",
    trust_remote_code=True,
)
print("Model loaded successfully!")
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to a PIL image, capping both sides at max_resolution."""
    # PDF user space is 72 DPI, so scale=2.77 renders at roughly 200 DPI.
    width, height = page.get_size()
    pixel_width = width * scale
    pixel_height = height * scale
    resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
    target_scale = scale * resize_factor
    return page.render(scale=target_scale, rev_byteorder=True).to_pil()
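# Example usage (sketch; "sample.pdf" is a hypothetical local file):
#   pdf = pdfium.PdfDocument("sample.pdf")
#   img = render_pdf_page(pdf[0])  # PIL image with both sides capped at 1540 px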
def process_pdf(pdf_path, page_num=1):
    """Extract a specific page from a PDF as a PIL image."""
    pdf = pdfium.PdfDocument(pdf_path)
    total_pages = len(pdf)
    # Clamp the 1-based page number into the valid range, then convert to 0-based
    page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
    page = pdf[page_idx]
    img = render_pdf_page(page)
    pdf.close()
    return img, total_pages, page_idx + 1
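# Example usage (sketch; "doc.pdf" is a hypothetical local file):
#   img, total_pages, actual_page = process_pdf("doc.pdf", page_num=3)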
def clean_output_text(text):
    """Remove chat template artifacts from the model output."""
    # Common chat template role markers
    markers_to_remove = ["system", "user", "assistant"]
    # Split into lines and drop lines that are bare markers
    lines = text.split('\n')
    cleaned_lines = []
    for line in lines:
        stripped = line.strip()
        if stripped.lower() not in markers_to_remove:
            cleaned_lines.append(line)
    # Join back and strip leading/trailing whitespace
    cleaned = '\n'.join(cleaned_lines).strip()
    # If an "assistant" marker is present, prefer everything after it
    if "assistant" in text.lower():
        parts = text.split("assistant", 1)
        if len(parts) > 1:
            cleaned = parts[1].strip()
    return cleaned
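# Example (sketch): clean_output_text("system\nuser\nassistant\nHello **world**")
# returns "Hello **world**": everything after the first "assistant" marker.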
def extract_text_from_image(image, temperature=0.2, stream=False):
    """Extract text from an image using the LightOnOCR model."""
    # Prepare the chat-format input
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    # Apply the chat template and tokenize
    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Move inputs to the device AND convert float tensors to the correct dtype
    inputs = {
        k: v.to(device=device, dtype=dtype)
        if isinstance(v, torch.Tensor) and v.dtype in (torch.float32, torch.float16, torch.bfloat16)
        else v.to(device) if isinstance(v, torch.Tensor)
        else v
        for k, v in inputs.items()
    }
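    # Float tensors such as pixel_values must match the model dtype (e.g. bfloat16
    # on GPU); integer tensors such as input_ids are only moved to the device.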
    do_sample = temperature > 0
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=2048,
        do_sample=do_sample,
        # Greedy decoding ignores temperature, so pass it only when sampling
        temperature=temperature if do_sample else None,
        use_cache=True,
    )

    if stream:
        # Set up a streamer for incremental generation
        streamer = TextIteratorStreamer(
            processor.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generation_kwargs["streamer"] = streamer

        # Run generation in a separate thread
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield cleaned, accumulated text as chunks arrive
        full_text = ""
        for new_text in streamer:
            full_text += new_text
            yield clean_output_text(full_text)
        thread.join()
    else:
        # Non-streaming generation
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)
        # Decode and clean the output
        output_text = processor.decode(outputs[0], skip_special_tokens=True)
        yield clean_output_text(output_text)
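# Example usage (sketch; "page.png" is a hypothetical file). The function is a
# generator, so a non-streaming call yields exactly one cleaned string:
#   text = next(extract_text_from_image(Image.open("page.png"), temperature=0.0))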
def process_input(file_input, temperature, page_num, enable_streaming):
    """Process an uploaded file (image or PDF) and extract text, optionally streaming."""
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", None, gr.update()
        return

    image_to_process = None
    page_info = ""
    file_path = file_input if isinstance(file_input, str) else file_input.name

    if file_path.lower().endswith('.pdf'):
        # Handle PDF files
        try:
            image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
            page_info = f"Processing page {actual_page} of {total_pages}"
        except Exception as e:
            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
            return
    else:
        # Handle image files
        try:
            image_to_process = Image.open(file_path)
            page_info = "Processing image"
        except Exception as e:
            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
            return

    try:
        # Extract text using LightOnOCR, streaming if enabled
        for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
            yield extracted_text, extracted_text, page_info, image_to_process, gr.update()
    except Exception as e:
        error_msg = f"Error during text extraction: {str(e)}"
        yield error_msg, error_msg, page_info, image_to_process, gr.update()
def update_slider(file_input):
    """Update the page slider based on the PDF's page count."""
    if file_input is None:
        return gr.update(maximum=20, value=1)
    file_path = file_input if isinstance(file_input, str) else file_input.name
    if file_path.lower().endswith('.pdf'):
        try:
            pdf = pdfium.PdfDocument(file_path)
            total_pages = len(pdf)
            pdf.close()
            return gr.update(maximum=total_pages, value=1)
        except Exception:
            return gr.update(maximum=20, value=1)
    else:
        return gr.update(maximum=1, value=1)
# Create the Gradio interface
with gr.Blocks(title="📄 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # ⚠️ **HEADS UP: This space is now on CPU and runs very slowly.**
    For much faster results, check out the [GPU version here](https://huggingface.co/spaces/lightonai/LightOnOCR-1B-Demo-zero).

    ---

    # 📄 Image/PDF to Text Extraction with LightOnOCR

    **💡 How to use:**
    1. Upload an image or PDF
    2. For PDFs: select which page to extract
    3. Adjust the temperature if needed
    4. Click "Extract Text"

    **Note:** Markdown rendering of tables may not always be perfect. Check the raw output for complex tables!

    **Model:** LightOnOCR-1B-1025 by LightOn AI
    **Device:** {device.upper()}
    **Attention:** {attn_implementation}
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="🖼️ Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="📄 Preview",
                type="pil",
                height=400,
                interactive=False,
            )
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(
                label="Processing Info",
                value="",
                interactive=False,
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic; higher = more varied",
            )
            enable_streaming = gr.Checkbox(
                label="Enable Streaming",
                value=False,
                info="Show text progressively as it is generated",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*",
            )

    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
                show_copy_button=True,
            )

    # Event handlers
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, temperature, num_pages, enable_streaming],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages],
    )
    file_input.change(
        fn=update_slider,
        inputs=[file_input],
        outputs=[num_pages],
    )
    clear_btn.click(
        fn=lambda: (None, "*Extracted text will appear here...*", "", "", None, 1),
        outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages],
    )

if __name__ == "__main__":
    demo.launch()
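    # When running locally, demo.launch(share=True) would also expose a temporary public URL.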