ghosthets committed
Commit 02f2b8a · verified · 1 Parent(s): 4637032

Update app.py

Files changed (1)
  app.py +33 -8
app.py CHANGED
@@ -1,16 +1,18 @@
 import flask
 from flask import request, jsonify
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+# Use AutoModelForCausalLM for decoder-only models like TinyLlama
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
 app = flask.Flask(__name__)
 
 model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
-print("🔄 Loading fast chat model...")
+print("🔄 Loading TinyLlama model...")
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+# Load using AutoModelForCausalLM
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)  # Using bfloat16 for better memory/speed on GPU
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
@@ -26,16 +28,39 @@ def chat():
         if not msg:
             return jsonify({"error": "No message sent"}), 400
 
-        inputs = tokenizer(msg, return_tensors="pt").to(device)
+        # --- Key Change 1: Apply Chat Template ---
+        # Format the user message into the model's required chat template
+        chat_history = [{"role": "user", "content": msg}]
+        # add_generation_prompt=True ensures the model knows it needs to respond
+        formatted_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
+
+        # Tokenize the formatted prompt
+        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
+
+        # Generation
         output = model.generate(
             **inputs,
-            max_length=200,
+            max_length=256,
             do_sample=True,
-            top_p=0.92,
-            temperature=0.7
+            top_p=0.9,
+            temperature=0.7,
+            eos_token_id=tokenizer.eos_token_id
         )
 
-        reply = tokenizer.decode(output[0], skip_special_tokens=True)
+        # Decode the output
+        full_reply = tokenizer.decode(output[0], skip_special_tokens=False)
+
+        # --- Key Change 2: Extract only the generated response ---
+        # The output includes the input prompt, so we extract only the response part.
+
+        # Identify the assistant marker used by TinyLlama's chat template
+        if "[/INST]" in full_reply:
+            # This structure is often used: <s>[INST] User Prompt [/INST] Assistant Reply
+            reply = full_reply.split("[/INST]")[-1].strip()
+        else:
+            # Fallback: decode only the newly generated tokens
+            reply = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()
+
         return jsonify({"reply": reply})
 
     except Exception as e:
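
For context on the two key changes above, a minimal standalone sketch of the chat-template formatting and response extraction, runnable outside Flask. It mirrors the loading and generation code committed in app.py; the example user message is invented, and the reply is recovered by slicing off the prompt tokens (the commit's fallback path), which works regardless of which marker tokens the chat template uses.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Wrap the raw user message in the model's chat format; add_generation_prompt=True
# appends the assistant turn marker so the model knows it should respond next.
chat_history = [{"role": "user", "content": "Explain what Flask is in one sentence."}]  # invented example message
prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(
    **inputs,
    max_length=256,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# model.generate returns prompt + completion; slice off the prompt tokens
# before decoding so only the newly generated reply remains.
new_tokens = output[0][inputs["input_ids"].shape[1]:]
reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
print(reply)

Slicing by the prompt length avoids depending on any particular template marker string, so the same extraction works even if the model's chat template changes.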