File size: 7,243 Bytes
3297f8d 7327516 9f0bc77 3297f8d 5d4bb2c 3297f8d 2f2eda6 3297f8d 9f0bc77 cae05b4 2f2eda6 cae05b4 3297f8d 9f0bc77 2f2eda6 3297f8d 2f2eda6 9f0bc77 2f2eda6 3297f8d 9f0bc77 2f2eda6 3297f8d 2f2eda6 166f868 2f2eda6 7327516 3297f8d 7327516 3297f8d 7327516 3297f8d 7327516 2f2eda6 3297f8d 2f2eda6 3297f8d 7327516 2f2eda6 7327516 2f2eda6 7327516 9f0bc77 7327516 2f2eda6 7327516 2f2eda6 15de1d7 5d4bb2c 2f2eda6 cae05b4 2f2eda6 15de1d7 5d4bb2c 4ab7e58 cc4a1f7 a059d9b cc4a1f7 2f2eda6 cc4a1f7 2f2eda6 cc4a1f7 2f2eda6 3297f8d cc4a1f7 2f2eda6 3297f8d 2f2eda6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import os
import torch
import gradio as gr
import requests
from typing import List, Dict, Iterator
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from peft import PeftModel
import json
# --- Configuration ----------------------------------------------------------
# Base checkpoint on the Hugging Face Hub; the LoRA adapter is applied on top.
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"

# Settings that must be supplied through the environment.
ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID = (
    os.getenv(var_name)
    for var_name in ("ADAPTER_ID", "YANDEX_API_KEY", "YANDEX_FOLDER_ID")
)
if not (ADAPTER_ID and YANDEX_API_KEY and YANDEX_FOLDER_ID):
    # Fail fast at import time: neither the model nor the translator can work
    # without these values.
    raise ValueError(
        "Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID"
    )

# Generation settings. TEMPERATURE and TOP_P only matter for sampled decoding.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 1
TOP_P = 0.9
REPETITION_PENALTY = 1.05
# System prompt (in Tatar) that opens every conversation. Rough English
# meaning: "You are a digital assistant (male), created by Sberbank's ESG
# directorate; an expert in all fields, especially Tatarstan-related ones. The
# user will ask many questions; your job is to explain, answer, give
# step-by-step advice and examples and, when needed, ask clarifying questions."
SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
# --- Model loading ----------------------------------------------------------
# The tokenizer comes from the adapter repository (presumably it carries the
# chat template used during fine-tuning — TODO confirm).
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
if tok.pad_token is None:
    # No dedicated pad token: reuse EOS so generate() receives a valid pad id.
    tok.pad_token = tok.eos_token

base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
print("Применяем LoRA адаптер...")
model = PeftModel.from_pretrained(base, ADAPTER_ID, torch_dtype=torch.float16)
# This process only runs inference, so keep the KV cache on. With
# use_cache=False every decoding step re-runs the model over the whole prefix,
# which makes each streamed reply quadratically slower in its length.
model.config.use_cache = True
model.eval()
print("✅ Модель успешно загружена!")
# Yandex Cloud Translate v2 endpoints used by detect_language() and ru2tt().
_YANDEX_TRANSLATE_API = "https://translate.api.cloud.yandex.net/translate/v2"
YANDEX_TRANSLATE_URL = f"{_YANDEX_TRANSLATE_API}/translate"
YANDEX_DETECT_URL = f"{_YANDEX_TRANSLATE_API}/detect"
def detect_language(text: str) -> str:
    """Ask Yandex Translate which language *text* is written in.

    Returns:
        The ``languageCode`` field of the detect response, or ``"ru"`` when
        that field is missing or the HTTP request fails.
    """
    fallback = "ru"
    auth = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    request_body = {"folderId": YANDEX_FOLDER_ID, "text": text}
    try:
        response = requests.post(
            YANDEX_DETECT_URL, headers=auth, json=request_body, timeout=10
        )
        response.raise_for_status()
        data = response.json()
    except requests.exceptions.RequestException:
        # Network/HTTP problems are non-fatal: assume Russian input.
        return fallback
    return data.get("languageCode", fallback)
def ru2tt(text: str, source_lang: str = "ru") -> str:
    """Translate *text* into Tatar using Yandex Translate.

    Args:
        text: Text to translate.
        source_lang: Yandex language code of *text*; defaults to ``"ru"``.

    Returns:
        The Tatar translation, or *text* unchanged in any of these cases:
        it is blank, the HTTP request fails, or the response lacks the
        expected ``translations[0].text`` field.
    """
    # Nothing to translate: skip the network round-trip entirely.
    if not text or not text.strip():
        return text

    headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    payload = {
        "folderId": YANDEX_FOLDER_ID,
        "texts": [text],
        "sourceLanguageCode": source_lang,
        "targetLanguageCode": "tt",
    }
    try:
        resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        return resp.json()["translations"][0]["text"]
    except requests.exceptions.RequestException:
        # Translation is best-effort: fall back to the original text.
        return text
    except (KeyError, IndexError, TypeError, ValueError):
        # A 2xx response with an unexpected body shape must not crash the chat
        # turn either; fall back to the original text.
        return text
def render_prompt(messages: List[Dict[str, str]]) -> str:
    """Render *messages* to a prompt string with the tokenizer's chat template.

    The template is applied untokenized, with ``add_generation_prompt=True`` so
    that the assistant-turn header is appended for the model to continue.
    """
    template_options = {"tokenize": False, "add_generation_prompt": True}
    return tok.apply_chat_template(messages, **template_options)
# --- 4) Streaming generation (the streamed text is not trimmed) ---
@torch.inference_mode()
def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
    """Stream the model's reply to *messages* as a growing string.

    ``model.generate`` runs in a background thread and feeds decoded text into
    a ``TextIteratorStreamer`` (prompt and special tokens skipped); this
    generator yields the full reply decoded so far after each chunk.

    Args:
        messages: Conversation as ``{"role": ..., "content": ...}`` dicts,
            rendered through ``render_prompt``.

    Yields:
        The accumulated reply text.

    Raises:
        Whatever ``model.generate`` raised in the worker thread. It is
        re-raised here after the stream is closed, instead of leaving this
        loop blocked on the streamer forever.
    """
    prompt = render_prompt(messages)
    # The rendered chat template may already begin with the BOS token; letting
    # the tokenizer add special tokens again would prepend a second BOS.
    bos = tok.bos_token
    add_special = not (bos and prompt.startswith(bos))
    enc = tok(prompt, return_tensors="pt", add_special_tokens=add_special)
    enc = {k: v.to(model.device) for k, v in enc.items()}
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # Greedy decoding: sampling knobs (TEMPERATURE / TOP_P) are intentionally
    # not passed because they have no effect with do_sample=False.
    gen_kwargs = dict(
        **enc,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        repetition_penalty=REPETITION_PENALTY,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    errors: List[BaseException] = []

    def _generate() -> None:
        # Run generation; on any failure record the error and close the
        # streamer so the consuming loop below terminates.
        try:
            model.generate(**gen_kwargs)
        except BaseException as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    acc = ""
    for chunk in streamer:
        acc += chunk
        yield acc
    thread.join()
    if errors:
        raise errors[0]
def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
    """Handle one chat turn and stream the assistant's reply to the UI.

    Any message that ``detect_language`` does not report as Tatar (``"tt"``)
    is passed through ``ru2tt`` first. The resulting Tatar text is what is
    shown in the transcript and sent to the model.

    Args:
        message: Raw text from the input box.
        ui_history: Chatbot transcript as ``[user, assistant]`` pairs.
        messages_state: Model-side conversation as role/content dicts.

    Yields:
        ``(ui_history, messages_state)`` after every streamed chunk. The state
        includes the new user turn and the partial assistant reply. A blank
        message yields the inputs unchanged, once.
    """
    if not message or not message.strip():
        # Nothing to send: keep the conversation exactly as it is.
        yield ui_history, messages_state
        return

    history = list(messages_state or [])
    if not history or history[0].get("role") != "system":
        # Prepend the system prompt rather than replacing the state, so any
        # existing turns are not silently discarded.
        history = [{"role": "system", "content": SYS_PROMPT_TT}] + history

    detected = detect_language(message)
    user_tt = message if detected == "tt" else ru2tt(message)
    messages = history + [{"role": "user", "content": user_tt}]

    ui_history = list(ui_history or []) + [[user_tt, ""]]
    last = ""
    for partial in generate_tt_reply_stream(messages):
        last = partial
        ui_history[-1][1] = partial
        yield ui_history, messages + [{"role": "assistant", "content": partial}]

    # Log the finished conversation (non-ASCII kept readable).
    final_state = messages + [{"role": "assistant", "content": last}]
    print("STATE:", json.dumps(final_state, ensure_ascii=False))
# Custom CSS: larger fonts for bot messages, the transcript, the input box, the
# clear button and the title. Several class-name variants are targeted
# (presumably to cover different Gradio versions — TODO confirm).
_APP_CSS = """
#chatbot .message.bot,
#chatbot .message.bot .markdown,
#chatbot .message.bot .prose,
#chatbot .message.bot p,
#chatbot .message.bot li,
#chatbot .message.bot pre,
#chatbot .message.bot code {
font-size: 22px !important;
line-height: 1.7 !important;
}
#chatbot .gr-chatbot_message.gr-chatbot_message__bot,
#chatbot .gr-chatbot_message.gr-chatbot_message__bot .gr-chatbot_markdown > *,
#chatbot .gr-chatbot_message--assistant,
#chatbot .gr-chatbot_message--assistant .gr-chatbot_markdown > * {
font-size: 22px !important;
line-height: 1.7 !important;
}
#chatbot .gr-chatbot { font-size: 18px !important; line-height: 1.5; }
#chatbot .gr-chatbot_message { font-size: 18px !important; }
#chatbot .gr-chatbot_markdown > * { font-size: 18px !important; line-height: 1.6; }
#msg textarea { font-size: 24px !important; }
#clear { font-size: 16px !important; }
#title h2 { font-size: 28px !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=_APP_CSS) as demo:
    gr.Markdown("## Татарский чат-бот от команды Сбера", elem_id="title")
    # Model-side conversation (role/content dicts), seeded with the system prompt.
    messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])
    chatbot = gr.Chatbot(
        label="Диалог",
        height=500,
        bubble_full_width=False,
        elem_id="chatbot",
    )
    msg = gr.Textbox(
        label="Хәбәрегезне рус яки татар телендә языгыз",
        placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?",
        elem_id="msg",
    )
    clear = gr.Button("🗑️ Чистарту", elem_id="clear")

    # Stream the reply into the chatbot and the conversation state.
    reply_event = msg.submit(
        chat_fn,
        inputs=[msg, chatbot, messages_state],
        outputs=[chatbot, messages_state],
    )
    # Empty the input box right after submission.
    msg.submit(lambda: "", None, msg)

    def _reset():
        """Return an empty transcript and a conversation holding only the system prompt."""
        return [], [{"role": "system", "content": SYS_PROMPT_TT}]

    # Cancel a reply that is still streaming; otherwise its later yields would
    # write the old conversation back into the chatbot and state after reset.
    clear.click(
        _reset,
        inputs=None,
        outputs=[chatbot, messages_state],
        queue=False,
        cancels=[reply_event],
    )
    clear.click(lambda: "", None, msg, queue=False)
if __name__ == "__main__":
    # Listen on all interfaces; the port is taken from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)
|