import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import gradio as gr
import os
import urllib.request

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download input.txt (vocabulary source)
if not os.path.exists('input.txt'):
    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip'  # your dataset
    urllib.request.urlretrieve(url, 'wikitext.zip')
    import zipfile
    with zipfile.ZipFile('wikitext.zip', 'r') as z:
        z.extract('wikitext-2-raw/wiki.train.raw', '.')
    os.rename('wikitext-2-raw/wiki.train.raw', 'input.txt')
    os.remove('wikitext.zip')
    os.rmdir('wikitext-2-raw')

# Load vocabulary
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
lines = text.split('\n')
text = '\n'.join(line for line in lines if not line.startswith('='))  # drop WikiText section headings
words = text.split()
words_set = sorted(set(words))
unk_token = '<UNK>'
words_set = [unk_token] + words_set  # reserve index 0 for unknown words
vocab_size = len(words_set)
stoi = {word: i for i, word in enumerate(words_set)}
itos = {i: word for word, i in stoi.items()}
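
# Word-level tokenization sketch (illustrative only, not executed): stoi maps a word to its
# integer id and itos inverts the mapping, so encoding a prompt is roughly
#   ids = [stoi.get(w, stoi[unk_token]) for w in "my little panda".split()]
# and decoding is ' '.join(itos[i] for i in ids); words not seen in the corpus map to <UNK> (id 0).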

# Model classes (simplified)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
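
# The buffer above implements the standard sinusoidal encoding,
#   pe[pos, 2i] = sin(pos / 10000^(2i / d_model)),  pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model)),
# which is added to the (scaled) token embeddings so the model can tell positions apart.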

class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        return x
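
# Post-norm residual block: x + Dropout(SelfAttention(x)) -> LayerNorm, then x + Dropout(FFN(x)) -> LayerNorm.
# In nn.MultiheadAttention, a boolean attn_mask entry of True marks a position the query may NOT attend to;
# the causal mask defined further down uses exactly that convention.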

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, dropout):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.blocks = nn.ModuleList([DecoderBlock(d_model, n_heads, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    def forward(self, x, mask=None, targets=None):
        B, L = x.shape
        x = self.token_emb(x) * math.sqrt(self.d_model)
        x = self.pos_enc(x)
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        logits = self.head(x)
        if targets is None:
            return logits
        else:
            return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
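
# When targets are provided, logits of shape (B, L, vocab_size) and targets of shape (B, L) are
# flattened to (B*L, vocab_size) and (B*L,) so one cross-entropy call scores every position;
# target entries equal to -1 are ignored (e.g. padded positions during training).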

# Hyperparameters
d_model = 64
n_heads = 4
n_layers = 2
dropout = 0.1
block_size = 32
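
# Note: these values must match the configuration used to train model.pth; otherwise
# load_state_dict below fails with shape-mismatch errors.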

# Load model (upload model.pth to the Space alongside this file)
model = GPT(vocab_size, d_model, n_heads, n_layers, dropout).to(device)
if os.path.exists('model.pth'):
    model.load_state_dict(torch.load('model.pth', map_location=device))
model.eval()
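
# If model.pth is missing, the app still launches but generates from randomly initialised
# weights, so the output will be meaningless until a trained checkpoint is uploaded.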

# Causal mask (True = position blocked from attention)
def generate_causal_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool().to(device)
    return mask
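
# For size=3 the mask looks like (True = blocked, i.e. strictly future positions):
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]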

# generate (free-form continuation)
def generate(model, prompt_str, max_new_tokens=50, top_k=40, temperature=0.8):
    model.eval()
    words = prompt_str.split()
    context = torch.tensor([stoi.get(w, 0) for w in words], dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            L = context.shape[1]
            mask = generate_causal_mask(L)
            logits = model(context, mask=mask)
            logits = logits[0, -1] / temperature  # logits for the last position, sharpened by temperature
            probs = F.softmax(logits, dim=-1)
            topk_probs, topk_ids = torch.topk(probs, top_k)  # keep only the k most probable words
            next_id = topk_ids[torch.multinomial(topk_probs, 1)].item()
            context = torch.cat([context, torch.tensor([[next_id]], device=device)], dim=1)
            if next_id == stoi[unk_token]:
                break
    gen_words = [itos[i.item()] for i in context[0]]
    return ' '.join(gen_words)
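
# Sampling detail: temperature < 1 sharpens the next-word distribution and top-k restricts it
# to the k most probable words; torch.multinomial accepts unnormalised weights, so the selected
# top-k probabilities do not need to be re-normalised before sampling.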

# qa_generate (Q&A mode)
def qa_generate(question):
    prompt = f"Q: {question} A:"
    generated = generate(model, prompt, max_new_tokens=30, top_k=30, temperature=0.7)
    answer = generated.split("A:")[-1].strip() if "A:" in generated else generated
    return answer
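
# Q&A here is just prompted continuation of "Q: ... A:" by a small word-level language model,
# not an instruction-tuned chat model, so answers are best-effort text completions.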

# Dual-mode dispatch
def generate_text(mode, input_text):
    if mode == "Continue":
        return generate(model, input_text, 50)
    elif mode == "Q&A":
        return qa_generate(input_text)
    else:
        return "Please select a mode!"

# Gradio UI
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Dropdown(choices=["Continue", "Q&A"], label="Mode", value="Continue"),
        gr.Textbox(label="Input (prompt / question)", placeholder="e.g., my little panda / what is panda", lines=2)
    ],
    outputs=gr.Textbox(label="Output", lines=10),
    title="Mini-GPT Dual Mode",
    description="Continue a story or answer a question!",
    examples=[
        ["Continue", "my little panda"],
        ["Q&A", "what is panda"]
    ],
    theme=gr.themes.Soft()
)

demo.launch(server_name="0.0.0.0", server_port=7860)