import gradio as gr
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
import torch

MODEL_NAME = "Tamazight-NLP/NLLB-200-600M-Tamazight-All-Data-3-epoch"

NLLB_LANG_MAPPING = {
    "English": "eng_Latn",
    "Standard Moroccan Tamazight": "tzm_Tfng",
    "Tachelhit/Central Atlas Tamazight": "taq_Tfng",
    "Tachelhit/Central Atlas Tamazight (Latin)": "taq_Latn",
    "Tarifit": "kab_Tfng",
    "Tarifit (Latin)": "kab_Latn",
    "Moroccan Darija": "ary_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Catalan": "cat_Latn",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "German": "deu_Latn",
    "Dutch": "nld_Latn",
    "Russian": "rus_Cyrl",
    "Italian": "ita_Latn",
    "Turkish": "tur_Latn",
    "Esperanto": "epo_Latn"
}
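# Note: the mapping above ties UI display names to the NLLB language codes this
# fine-tuned checkpoint expects. The Moroccan Amazigh variants appear to reuse
# or extend existing NLLB code slots (e.g. taq_* and kab_*) rather than follow
# NLLB-200's original language assignments.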

# Load the fine-tuned checkpoint and its tokenizer once at startup, on the GPU
# when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME)


def translate(text, source_lang="English", target_lang="Tachelhit/Central Atlas Tamazight",
              max_length=238, num_beams=4, repetition_penalty=1.0):
    """
    Translate multi-line text while preserving line breaks.
    Each line is translated independently.
    """
    translations = []
    for line in text.split("\n"):
        if line.strip() == "":
            translations.append("")  # preserve empty lines
        else:
            # Set the source language on the tokenizer so the right language
            # code token is included when the line is encoded.
            tokenizer.src_lang = NLLB_LANG_MAPPING[source_lang]
            inputs = tokenizer(line, return_tensors="pt").to(model.device)
            # The target language is selected by forcing its language code as
            # the first generated token.
            translated_tokens = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids(NLLB_LANG_MAPPING[target_lang]),
                max_length=max_length,
                num_beams=num_beams,
                repetition_penalty=float(repetition_penalty),
            )
            translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
            translations.append(translation)
    return "\n".join(translations)
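
# Illustrative usage outside the Gradio UI (an example call, not executed here;
# it assumes the model and tokenizer loaded above):
#
#   translate("Hello.\nHow are you?",
#             source_lang="English",
#             target_lang="Standard Moroccan Tamazight")
#
# returns two translated lines joined by "\n", mirroring the input's line breaks.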


gradio_ui = gr.Interface(
    fn=translate,
    title="NLLB Tamazight Translation Demo",
    inputs=[
        gr.components.Textbox(label="Text", lines=4, placeholder="ⵙⵙⴽⵛⵎ ⴰⴹⵕⵉⵚ...\nEnter text to translate..."),
        gr.components.Dropdown(label="Source Language", choices=list(NLLB_LANG_MAPPING.keys()), value="English"),
        gr.components.Dropdown(label="Target Language", choices=list(NLLB_LANG_MAPPING.keys()), value="Standard Moroccan Tamazight"),
        gr.components.Slider(8, 400, value=238, step=8, label="Max length (in tokens). Increase this if the output looks truncated."),
        gr.components.Slider(1, 25, value=4, step=1, label="Number of beams. Higher values might improve translation accuracy at the cost of speed."),
        gr.components.Slider(1, 10, value=1.0, step=0.1, label="Repetition penalty. Values above 1.0 discourage repeated words."),
    ],
    outputs=gr.components.Textbox(label="Translated text", lines=4)
)

# launch() serves the demo locally; gradio_ui.launch(share=True) would also
# create a temporary public link.
gradio_ui.launch()