Update app.py
Browse files
app.py
CHANGED
|
@@ -186,10 +186,12 @@ elif selected_model == "Генерация текста GPT-моделью по
|
|
| 186 |
user_text_input = st.text_area('Введите информацию о себе для формиорования гороскопа:')
|
| 187 |
|
| 188 |
# GPT2
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
| 193 |
if st.button('Сделать гороскоп'):
|
| 194 |
start_time = time.time()
|
| 195 |
input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
|
|
|
|
| 186 |
user_text_input = st.text_area('Введите информацию о себе для формиорования гороскопа:')
|
| 187 |
|
| 188 |
# GPT2
|
| 189 |
+
model_name_or_path = "sberbank-ai/rugpt3small_based_on_gpt2"
|
| 190 |
+
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
|
| 191 |
+
model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)
|
| 192 |
+
model = GPT2LMHeadModel(config)
|
| 193 |
+
model.load_state_dict(torch.load('model_dict.pt', map_location=device))
|
| 194 |
+
|
| 195 |
if st.button('Сделать гороскоп'):
|
| 196 |
start_time = time.time()
|
| 197 |
input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
|