Nefertury AccessAndrei commited on
Commit
9f0bc77
·
verified ·
1 Parent(s): cae05b4

without quantisation (#6)

Browse files

- without quantisation (ab57ebe0bba9f985cda9ee2e9329fc577c9edd5c)


Co-authored-by: Aksenov Andrei <[email protected]>

Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  import requests
5
  from typing import List, Dict, Iterator
6
  from threading import Thread
7
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
8
  from peft import PeftModel
9
  import json
10
 
@@ -16,15 +16,14 @@ YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")
16
  if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
17
  raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
18
 
19
- MAX_NEW_TOKENS = 2048
20
  TEMPERATURE = 1
21
  TOP_P = 0.9
22
  REPETITION_PENALTY = 1.05
23
 
24
  SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
25
 
26
- print("Загрузка модели с 4-битной квантизацией...")
27
- quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
28
 
29
  tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
30
  if tok.pad_token is None:
@@ -33,12 +32,13 @@ if tok.pad_token is None:
33
 
34
  base = AutoModelForCausalLM.from_pretrained(
35
  BASE_MODEL_ID,
36
- quantization_config=quantization_config,
37
- device_map="auto"
 
38
  )
39
 
40
  print("Применяем LoRA адаптер...")
41
- model = PeftModel.from_pretrained(base, ADAPTER_ID)
42
  model.config.use_cache = False
43
  model.eval()
44
  print("✅ Модель успешно загружена!")
@@ -86,9 +86,9 @@ def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
86
  **enc,
87
  streamer=streamer,
88
  max_new_tokens=MAX_NEW_TOKENS,
89
- do_sample=True,
90
- temperature=TEMPERATURE,
91
- top_p=TOP_P,
92
  repetition_penalty=REPETITION_PENALTY,
93
  eos_token_id=tok.eos_token_id,
94
  pad_token_id=tok.pad_token_id,
 
4
  import requests
5
  from typing import List, Dict, Iterator
6
  from threading import Thread
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
8
  from peft import PeftModel
9
  import json
10
 
 
16
  if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
17
  raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
18
 
19
+ MAX_NEW_TOKENS = 1024
20
  TEMPERATURE = 1
21
  TOP_P = 0.9
22
  REPETITION_PENALTY = 1.05
23
 
24
  SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
25
 
26
+
 
27
 
28
  tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
29
  if tok.pad_token is None:
 
32
 
33
  base = AutoModelForCausalLM.from_pretrained(
34
  BASE_MODEL_ID,
35
+ torch_dtype=torch.float16,
36
+ device_map="auto",
37
+ low_cpu_mem_usage=True
38
  )
39
 
40
  print("Применяем LoRA адаптер...")
41
+ model = PeftModel.from_pretrained(base, ADAPTER_ID, torch_dtype=torch.float16)
42
  model.config.use_cache = False
43
  model.eval()
44
  print("✅ Модель успешно загружена!")
 
86
  **enc,
87
  streamer=streamer,
88
  max_new_tokens=MAX_NEW_TOKENS,
89
+ do_sample=False,
90
+ # temperature=TEMPERATURE,
91
+ # top_p=TOP_P,
92
  repetition_penalty=REPETITION_PENALTY,
93
  eos_token_id=tok.eos_token_id,
94
  pad_token_id=tok.pad_token_id,