# QIUQIU-miniGPT / app.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import gradio as gr
import os
import urllib.request
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Download input.txt (vocabulary source)
if not os.path.exists('input.txt'):
    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip'  # your dataset
    urllib.request.urlretrieve(url, 'wikitext.zip')
    import zipfile
    with zipfile.ZipFile('wikitext.zip', 'r') as z:
        z.extract('wikitext-2-raw/wiki.train.raw', '.')
    os.rename('wikitext-2-raw/wiki.train.raw', 'input.txt')
    os.remove('wikitext.zip')
    os.rmdir('wikitext-2-raw')
# load vocab
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
lines = text.split('\n')
text = '\n'.join(line for line in lines if not line.startswith('='))
words = text.split()
words_set = sorted(set(words))
unk_token = '<UNK>'
words_set = [unk_token] + words_set
vocab_size = len(words_set)
stoi = {word: i for i, word in enumerate(words_set)}
itos = {i: word for word, i in stoi.items()}
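# Word-level tokenizer: heading lines (starting with '=') are dropped, then every
# whitespace-separated token becomes a vocabulary entry. <UNK> is placed at index 0,
# so stoi.get(word, 0) maps out-of-vocabulary words to <UNK> when encoding prompts.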
# Model classes (simplified)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
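# The buffer above follows the fixed sinusoidal scheme from "Attention Is All You Need":
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# As a registered buffer it moves with .to(device) but is never updated by the optimizer.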
class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
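# Each block is a post-norm Transformer decoder layer: masked self-attention plus a
# position-wise feed-forward net, each wrapped in residual + dropout + LayerNorm.
# The causal mask is supplied by the caller (GPT.forward / generate), not built here.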
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, dropout):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.blocks = nn.ModuleList([DecoderBlock(d_model, n_heads, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, x, mask=None, targets=None):
        B, L = x.shape
        x = self.token_emb(x) * math.sqrt(self.d_model)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits
        else:
            return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
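# forward() returns raw logits at inference time and a scalar cross-entropy loss when
# targets are given (positions labelled -1 are ignored). Token embeddings are scaled by
# sqrt(d_model), the same convention as the original Transformer.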
# Hyperparameters
d_model = 64
n_heads = 4
n_layers = 2
dropout = 0.1
block_size = 32
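# These values are assumed to match the configuration model.pth was trained with;
# if they differ, load_state_dict below will fail with shape-mismatch errors.
# block_size is the training context length; note that generate() does not crop to it.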
# Load model (upload model.pth alongside this file)
model = GPT(vocab_size, d_model, n_heads, n_layers, dropout).to(device)
if os.path.exists('model.pth'):
    model.load_state_dict(torch.load('model.pth', map_location=device))
model.eval()
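# If model.pth is missing, the app still starts but generates from randomly
# initialized weights, so output will be gibberish until a checkpoint is uploaded.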
# Causal mask
def generate_causal_mask(block_size):
    mask = torch.triu(torch.ones(block_size, block_size), diagonal=1).bool().to(device)
    return mask
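# For a length of 3 the boolean mask is (True = attention is blocked):
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# so each query position can only attend to itself and earlier positions.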
# generate (continuation mode)
def generate(model, prompt_str, max_new_tokens=50, top_k=40, temperature=0.8):
    model.eval()
    words = prompt_str.split()
    context = torch.tensor([stoi.get(w, 0) for w in words], dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            L = context.shape[1]
            mask = generate_causal_mask(L)
            logits = model(context, mask=mask)
            logits = logits[0, -1] / temperature
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, top_k)
            next_id = topk_ids[torch.multinomial(topk_probs, 1)].item()
            context = torch.cat([context, torch.tensor([[next_id]], device=device)], dim=1)
            if next_id == stoi['<UNK>']:
                break
    gen_words = [itos[i.item()] for i in context[0]]
    return ' '.join(gen_words)
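# Sampling strategy: logits for the last position are divided by the temperature,
# softmaxed, and restricted to the top_k most likely words; torch.multinomial then
# draws one of them (it accepts unnormalized weights). Generation stops early if
# <UNK> is sampled, although that final <UNK> still appears in the returned text.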
# qa_generate (Q&A mode)
def qa_generate(question):
    prompt = f"Q: {question} A:"
    generated = generate(model, prompt, max_new_tokens=30, top_k=30, temperature=0.7)
    answer = generated.split("A:")[-1].strip() if "A:" in generated else generated
    return answer
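# The "Q: ... A:" template is only a plain-text prompt; this word-level model has no
# instruction tuning, so the "answer" is simply a likely continuation of that text.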
# Dual mode
def generate_text(mode, input_text):
    if mode == "Continue":
        return generate(model, input_text, 50)
    elif mode == "Q&A":
        return qa_generate(input_text)
    else:
        return "Please pick a mode!"
# Gradio
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Dropdown(choices=["Continue", "Q&A"], label="Mode", value="Continue"),
        gr.Textbox(label="Input (prompt / question)", placeholder="e.g., my little panda / what is panda", lines=2)
    ],
    outputs=gr.Textbox(label="Output", lines=10),
    title="Mini-GPT Dual Mode",
    description="Continue a story or ask a question!",
    examples=[
        ["Continue", "my little panda"],
        ["Q&A", "what is panda"]
    ],
    theme=gr.themes.Soft()
)
demo.launch(server_name="0.0.0.0", server_port=7860)