nougat / app.py
heerjtdev's picture
Update app.py
5759d4b verified
# # app.py
# import gradio as gr
# from transformers import pipeline
# import torch
# from PIL import Image
# import io
# import fitz # PyMuPDF
# # --- Model Loading ---
# # Nougat is typically used for PDF/document image OCR.
# # The `facebook/nougat-small` model is a good starting point.
# # Using 'facebook/nougat-base' or 'facebook/nougat-large' is more accurate but requires more GPU memory/power.
# try:
# # Set up the device based on availability
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# # Load the Nougat pipeline
# # The task is technically 'document-image-to-text' but can be inferred by the model name
# nougat_pipeline = pipeline(
# "image-to-text",
# model="facebook/nougat-small",
# device=device,
# # Set max_new_tokens for the output length
# max_new_tokens=1024,
# # Set to False to prevent a warning about the model not having an image-to-text pipeline
# # (The pipeline can still wrap the VisionEncoderDecoder model)
# trust_remote_code=True
# )
# print(f"Nougat model loaded successfully on {device}")
# except Exception as e:
# # Fallback/error handling for model loading
# print(f"Error loading Nougat model: {e}")
# nougat_pipeline = None
# # --- OCR Function ---
# def nougat_ocr(document):
# """Performs Nougat OCR on a single-page document image or PDF."""
# if nougat_pipeline is None:
# return "Error: Nougat model failed to load. Check your Space hardware and dependencies."
# # Handle File object from Gradio (could be an image or a PDF)
# file_path = document.name
# # 1. Convert PDF (or first page of PDF) to an image
# if file_path.lower().endswith(('.pdf')):
# try:
# # Open PDF using PyMuPDF (fitz)
# doc = fitz.open(file_path)
# if len(doc) == 0:
# return "Error: PDF contains no pages."
# # Render the first page at a high DPI for better OCR
# page = doc.load_page(0)
# pix = page.get_pixmap(dpi=300)
# # Convert pixmap to PIL Image
# img_data = pix.tobytes("png")
# image = Image.open(io.BytesIO(img_data))
# doc.close()
# except Exception as e:
# return f"Error processing PDF: {e}"
# # 2. Handle image file (png, jpg, etc.)
# elif file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
# image = Image.open(file_path).convert("RGB")
# else:
# return "Error: Unsupported file format. Please upload an image or a PDF."
# # 3. Perform OCR inference
# try:
# # Pass the PIL image to the pipeline
# output = nougat_pipeline(image)
# # The output is typically a list of dicts: [{'generated_text': '...'}]
# markdown_text = output[0]['generated_text'] if output else "OCR failed to generate text."
# return markdown_text
# except Exception as e:
# return f"An error occurred during OCR: {e}"
# # --- Gradio Interface ---
# title = "🍫 Nougat OCR for Documents"
# description = "Upload a single-page document image (PNG/JPG) or a PDF to transcribe it into Markdown format using the Nougat-small model. **Note: For multi-page PDFs, only the first page is processed.**"
# iface = gr.Interface(
# fn=nougat_ocr,
# inputs=gr.File(
# label="Upload Document (Image or PDF)",
# file_types=["image", ".pdf"],
# file_count="single"
# ),
# outputs=gr.Markdown(label="Generated Markdown Output"),
# title=title,
# description=description,
# allow_flagging="auto",
# theme=gr.themes.Soft()
# )
# if __name__ == "__main__":
# iface.launch()
import gradio as gr
from transformers import pipeline
import torch
from PIL import Image
import io
import fitz # PyMuPDF
import os
import tempfile # Used for creating temporary files for download
# --- Model Loading ---
# Build the Nougat image-to-text pipeline once at import time. If anything in
# this section raises, the error is printed and ``nougat_pipeline`` is set to
# None so the rest of the module can detect the failure.
try:
    # Use the first CUDA device when torch reports one, otherwise the CPU.
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    nougat_pipeline = pipeline(
        "image-to-text",
        model="facebook/nougat-small",
        device=device,
        max_new_tokens=1024,
        trust_remote_code=True,
    )
    print(f"Nougat model loaded successfully on {device}")
except Exception as load_error:
    print(f"Error loading Nougat model: {load_error}")
    nougat_pipeline = None
# --- OCR Function (Revised for Multi-Page PDF and TXT Output) ---
def _resolve_upload_path(document):
    """Return the filesystem path of an uploaded file, or None if unavailable.

    Accepts either an object exposing the path as ``.name`` (a tempfile-like
    wrapper) or the path string itself, so the caller does not depend on
    which of the two forms the UI layer hands over.
    """
    if document is None:
        return None
    path = getattr(document, "name", document)
    return path if isinstance(path, str) else None


def _render_page(page, dpi=300):
    """Rasterize one PDF page at ``dpi`` and return it as an RGB PIL image."""
    pix = page.get_pixmap(dpi=dpi)
    # Encode the pixmap to PNG in memory so PIL can decode it directly.
    return Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")


def nougat_ocr(document):
    """Run Nougat OCR over every page of an uploaded PDF.

    Each page is rendered to an image, passed through ``nougat_pipeline``,
    and the per-page results are concatenated under ``# --- PAGE n ---``
    headers. The combined text is also written to a temporary ``.txt`` file
    (created with ``delete=False`` so it still exists after this function
    returns; this function does not remove it).

    Args:
        document: The uploaded file, either an object whose ``.name`` is the
            file path or the path string itself. May be None when nothing
            was uploaded.

    Returns:
        A ``(txt_path, text)`` tuple on success. On any failure the tuple is
        ``(None, error_message)``; exceptions raised during processing are
        caught and reported through the message rather than propagated.
    """
    if nougat_pipeline is None:
        return None, "Error: Nougat model failed to load. Check dependencies."

    file_path = _resolve_upload_path(document)
    if file_path is None:
        return None, "Error: No file was uploaded. Please upload a PDF file."

    # Only PDF processing is supported for the multi-page output logic.
    if not file_path.lower().endswith(".pdf"):
        return None, "Error: Please upload a PDF file for multi-page processing."

    try:
        page_sections = []
        # The context manager closes the document on every exit path,
        # including the early return below and any exception in the loop.
        with fitz.open(file_path) as doc:
            if len(doc) == 0:
                return None, "Error: PDF contains no pages."

            for page_number, page in enumerate(doc, start=1):
                output = nougat_pipeline(_render_page(page))
                if output:
                    markdown_text = output[0]["generated_text"]
                else:
                    markdown_text = f"[OCR FAILED FOR PAGE {page_number}]"
                # Add a clear separator and the page content.
                page_sections.append(
                    f"\n\n\n# --- PAGE {page_number} ---\n\n{markdown_text}"
                )

        aggregated_text = "".join(page_sections)

        # Persist the result so the UI can offer it as a download.
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".txt", encoding="utf-8"
        ) as tmp_file:
            tmp_file.write(aggregated_text)
            temp_file_path = tmp_file.name

        return temp_file_path, aggregated_text
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return None, f"An error occurred during processing: {e}"
# --- Gradio Interface ---
title = "🍫 Multi-Page Nougat OCR to TXT"
description = (
    "Upload a PDF. The model will process each page sequentially and output "
    "a TXT file for download, along with a Markdown preview."
)

# Input: a single uploaded file, restricted to the .pdf extension.
_pdf_upload = gr.File(
    label="Upload PDF Document",
    file_types=[".pdf"],
    file_count="single",
)

# Outputs, in the order nougat_ocr returns them: the downloadable TXT file
# first, then the Markdown preview of the same text.
_result_components = [
    gr.File(
        label="Download OCR Output (.txt)",
        file_count="single",
        file_types=[".txt"],
    ),
    gr.Markdown(label="Preview (Formatted Markdown)", visible=True),
]

iface = gr.Interface(
    fn=nougat_ocr,
    inputs=_pdf_upload,
    outputs=_result_components,
    title=title,
    description=description,
    allow_flagging="auto",
    theme=gr.themes.Soft(),
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()