import gradio as gr
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import torch

# NLLB-200 600M checkpoint fine-tuned for Tamazight translation
MODEL_NAME = "Tamazight-NLP/NLLB-200-600M-Tamazight-All-Data-3-epoch"

# Human-readable language names mapped to the NLLB language codes used by the tokenizer
NLLB_LANG_MAPPING = {
    "English": "eng_Latn",
    "Standard Moroccan Tamazight": "tzm_Tfng",
    "Tachelhit/Central Atlas Tamazight": "taq_Tfng",
    "Tachelhit/Central Atlas Tamazight (Latin)": "taq_Latn",
    "Tarifit": "kab_Tfng",
    "Tarifit (Latin)": "kab_Latn",
    "Moroccan Darija": "ary_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Catalan": "cat_Latn",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Dutch": "nld_Latn",
    "Russian": "rus_Cyrl",
    "Italian": "ita_Latn",
    "Turkish": "tur_Latn",
    "Esperanto": "epo_Latn",
}

# Load the model on GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME)


def translate(text, source_lang="English", target_lang="Tachelhit/Central Atlas Tamazight",
              max_length=238, num_beams=4, repetition_penalty=1.0):
    """
    Translate multi-line text while preserving line breaks.
    Each line is translated independently.
    """
    translations = []
    for line in text.split("\n"):
        if line.strip() == "":
            translations.append("")  # preserve empty lines
        else:
            # Tell the tokenizer which language the source text is in
            tokenizer.src_lang = NLLB_LANG_MAPPING[source_lang]
            inputs = tokenizer(line, return_tensors="pt").to(model.device)
            # Force the decoder to start with the target-language token
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(NLLB_LANG_MAPPING[target_lang]),
                max_length=max_length,
                num_beams=num_beams,
                repetition_penalty=float(repetition_penalty),
            )
            translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
            translations.append(translation)
    return "\n".join(translations)
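
# Illustrative usage (not part of the UI; the exact output depends on the loaded checkpoint):
# print(translate("Hello.\nHow are you?", "English", "Standard Moroccan Tamazight"))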

# Build the Gradio UI around the translate() function
gradio_ui = gr.Interface(
    fn=translate,
    title="NLLB Tamazight Translation Demo",
    inputs=[
        gr.components.Textbox(label="Text", lines=4, placeholder="ⵙⵙⴽⵛⵎ ⴰⴹⵕⵉⵚ...\nEnter text to translate..."),
        gr.components.Dropdown(label="Source Language", choices=list(NLLB_LANG_MAPPING.keys()), value="English"),
        gr.components.Dropdown(label="Target Language", choices=list(NLLB_LANG_MAPPING.keys()), value="Standard Moroccan Tamazight"),
        gr.components.Slider(8, 400, value=238, step=8, label="Max Length (in tokens). Increase if the output looks truncated."),
        gr.components.Slider(1, 25, value=4, step=1, label="Number of beams. Higher values might improve translation accuracy at the cost of speed."),
        gr.components.Slider(1, 10, value=1.0, step=0.1, label="Repetition penalty."),
    ],
    outputs=gr.components.Textbox(label="Translated text", lines=4),
)

gradio_ui.launch()