from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Initialize the Qwen model and tokenizer (trust_remote_code is required for Qwen's custom code)
model_id = "Qwen/Qwen-1_8B-Chat"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 載入模型:{model_id} on {device}")
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float32
).to(device)
# Create the FastAPI application
app = FastAPI()
chat_history = []
class Prompt(BaseModel):
    text: str
    reset: bool = False
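# Illustrative request body for the /chat endpoint below (field names come from the
# Prompt model above; the example text itself is made up, and "reset" defaults to false):
#   {"text": "Hello", "reset": false}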
@app.post("/chat")
async def chat(prompt: Prompt):
    global chat_history
    print(f"\n📝 User input: {prompt.text}")
    if prompt.reset:
        chat_history = []
        print("🔄 Chat history has been reset")
    chat_history.append({"role": "user", "content": prompt.text})
    # Assemble the conversation into a ChatML-format prompt
    chatml = ""
    for msg in chat_history:
        chatml += f"<|im_start|>{msg['role']}\n{msg['content']}\n<|im_end|>\n"
    chatml += "<|im_start|>assistant\n"
    try:
        inputs = tokenizer(chatml, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        # Decode only the newly generated tokens; because skip_special_tokens=True strips
        # the ChatML markers, the prompt is excluded by slicing rather than string-splitting.
        input_len = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
        print("🧠 Raw model reply:", reply)
        if not reply:
            reply = "⚠️ The model produced no reply. Please try again later."
            print("⚠️ Reply was an empty string")
        chat_history.append({"role": "assistant", "content": reply})
        print("✅ Final reply:", reply)
        return {"reply": reply}
    except Exception as e:
        print("❌ Model response error:", e)
        return {"reply": "Unable to get a reply from the model right now. Please try again later."}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
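# Example client call (a minimal sketch, assuming the server is running locally on
# port 7860 and the `requests` package is installed; neither is part of this app):
#
#   import requests
#   r = requests.post("http://localhost:7860/chat", json={"text": "Hello", "reset": True})
#   print(r.json()["reply"])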