from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Initialize the Qwen model and tokenizer (with trust_remote_code)
model_id = "Qwen/Qwen-1_8B-Chat"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"🚀 載入模型:{model_id} on {device}")

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float32
).to(device)

# Create the FastAPI application
app = FastAPI()

# In-memory conversation history, shared by all requests to this process
chat_history = []

class Prompt(BaseModel):
    text: str
    reset: bool = False

@app.post("/chat")
async def chat(prompt: Prompt):
    global chat_history

    print(f"\n📝 使用者輸入:{prompt.text}")
    if prompt.reset:
        chat_history = []
        print("🔄 Chat history 已重置")

    chat_history.append({"role": "user", "content": prompt.text})

    # Assemble the prompt in ChatML format
    chatml = ""
    for msg in chat_history:
        chatml += f"<|im_start|>{msg['role']}\n{msg['content']}\n<|im_end|>\n"
    chatml += "<|im_start|>assistant\n"

    try:
        inputs = tokenizer(chatml, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )
        # Keep special tokens so the ChatML markers survive decoding and the
        # assistant turn can be extracted below
        response = tokenizer.decode(outputs[0], skip_special_tokens=False).strip()

        print("🧠 Raw model output:", response)

        # Extract the content of the last assistant turn
        if "<|im_start|>assistant\n" in response:
            reply = response.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0].strip()
        else:
            reply = response  # fallback: return the raw decoded text

        if not reply:
            reply = "⚠️ The model did not produce a reply, please try again later."
            print("⚠️ Reply was an empty string")

        chat_history.append({"role": "assistant", "content": reply})
        print("✅ 最終回覆:", reply)
        return {"reply": reply}

    except Exception as e:
        print("❌ 模型回應錯誤:", e)
        return {"reply": "目前無法取得模型回覆,請稍後再試。"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
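
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original application): once the server is
# running, the /chat endpoint can be exercised with a small client like the
# one below. The URL and port mirror the uvicorn defaults above, the JSON
# fields mirror the Prompt model, and `requests` is an assumed extra
# dependency.
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7860/chat",
#         json={"text": "Hello, please introduce yourself.", "reset": True},
#         timeout=300,
#     )
#     print(resp.json()["reply"])
#
# Sending a follow-up message with "reset": False continues the same
# in-memory conversation history.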