Yuchan committed
Commit cc78280 · verified · 1 Parent(s): 6075db7

Update Model_torch.py

Files changed (1)
  1. Model_torch.py +169 -189
Model_torch.py CHANGED
@@ -1,15 +1,23 @@
- import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch.utils.data import Dataset, DataLoader
- import numpy as np
  import sentencepiece as spm
  import requests
- import os

  TOKENIZER_PATH = "ko_unigram.model"
- DATA_PATH = "corpus.txt"  # text file with 36M sentences
- max_len = 128
  # ===============================
  # 1️⃣ File download
  # ===============================
@@ -19,215 +27,187 @@ def download_file(url, save_path):
      with open(save_path, "wb") as f:
          for chunk in r.iter_content(8192*2):
              f.write(chunk)
-     print(f"✅ {save_path} saved")

  if not os.path.exists(TOKENIZER_PATH):
      download_file(
          "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
-         TOKENIZER_PATH
      )
  if not os.path.exists(DATA_PATH):
      download_file(
          "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
-         DATA_PATH
      )
  # ===============================
- # SentencePiece
  # ===============================
- sp = spm.SentencePieceProcessor("ko_unigram.model")
-
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
- start_id = sp.piece_to_id("<start>")
- end_id = sp.piece_to_id("<end>")
  vocab_size = sp.get_piece_size()
- max_len = 512
- batch_size = 32
- device = 'cuda' if torch.cuda.is_available() else 'cpu'

- def text_to_ids(text):
-     return sp.encode(text, out_type=int)

- def ids_to_text(ids):
-     return sp.decode(ids)

  # ===============================
- # Dataset
  # ===============================
- class TextDataset(Dataset):
-     def __init__(self, file_path, num_lines=None):
-         self.lines = []
-         with open(file_path, "r", encoding="utf-8") as f:
-             for i, line in enumerate(f):
-                 if num_lines is not None and i >= num_lines:
-                     break
-                 line = line.strip()
-                 if line:
-                     self.lines.append(line)
-
-     def __len__(self):
-         return len(self.lines)
-
-     def __getitem__(self, idx):
-         text = self.lines[idx]
-         ids = text_to_ids(text)[:max_len-1]
-         full_input = ids + [end_id]
-         pad_len = max_len - len(full_input)
-         full_input += [pad_id]*pad_len
-         target = full_input[1:] + [pad_id]
-         return torch.tensor(full_input, dtype=torch.long), torch.tensor(target, dtype=torch.long)
-
- dataset = TextDataset("corpus.txt", num_lines=100000)
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

  # ===============================
- # Model definition
  # ===============================
- class SwiGLU(nn.Module):
-     def __init__(self, d_model):
          super().__init__()
-         self.W = nn.Linear(d_model, 3500)
-         self.W1 = nn.Linear(1750, d_model)
-     def forward(self, x):
-         x = self.W(x.float())
-         a,b = x.chunk(2, dim=-1)
-         return self.W1(F.silu(a)*b).to(x.dtype)

- class SparseCausalAttention(nn.Module):
-     def __init__(self, num_heads, head_dim, window_size=8):
          super().__init__()
-         self.num_heads = num_heads
-         self.head_dim = head_dim
-         self.window_size = window_size
-         self.q = nn.Linear(head_dim*num_heads, num_heads*head_dim)
-         self.k = nn.Linear(head_dim*num_heads, num_heads*head_dim)
-         self.v = nn.Linear(head_dim*num_heads, num_heads*head_dim)
-         self.out = nn.Linear(num_heads*head_dim, head_dim*num_heads)

      def forward(self, x):
-         B,L,D = x.shape
-         q = self.q(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)
-         k = self.k(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)
-         v = self.v(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)
-         q = q / (self.head_dim ** 0.5)
-
-         attn_scores = torch.matmul(q, k.transpose(-2,-1))
-         mask = torch.tril(torch.ones(L,L, device=x.device))
-         band_mask = torch.triu(mask, -self.window_size)
-         attn_scores = attn_scores.masked_fill(band_mask==0, float('-inf'))
-         attn_probs = F.softmax(attn_scores, dim=-1)
-         out = torch.matmul(attn_probs, v)
-         out = out.transpose(1,2).reshape(B,L,D)
-         return self.out(out)
-
- class Lo(nn.Module):
-     def __init__(self,d_model):
-         super().__init__()
-         self.d = nn.Linear(d_model,64)
-         self.w = nn.Linear(64,d_model)
-         self.norm = nn.LayerNorm(d_model)
-     def forward(self,x):
-         return self.norm(self.w(F.silu(self.d(x))) + x)
-
- class Block(nn.Module):
-     def __init__(self,d_model):
-         super().__init__()
-         self.attn = SparseCausalAttention(num_heads=2, head_dim=64)
-         self.glu = SwiGLU(d_model)
-         self.norm = nn.LayerNorm(d_model)
-         self.lo = Lo(d_model)
-     def forward(self,x):
-         x = self.attn(x)
-         x = self.norm(self.glu(x)+x)
-         x = self.lo(x)
-         return x
-
- class ReLM(nn.Module):
-     def __init__(self,vocab_size,max_seq_len,d_model,n_layers):
-         super().__init__()
-         self.token_embedding = nn.Embedding(vocab_size,d_model)
-         self.pos_embedding = nn.Embedding(max_seq_len,d_model)
-         self.blocks = nn.ModuleList([Block(d_model) for _ in range(n_layers)])
-         self.ln_f = nn.LayerNorm(d_model)
-         self.d_model = d_model
-
-     def forward(self,x):
-         B,L = x.shape
-         positions = torch.arange(L,device=x.device).unsqueeze(0)
-         x = self.token_embedding(x) + self.pos_embedding(positions)
          for block in self.blocks:
              x = block(x)
          x = self.ln_f(x)
-         logits = x @ self.token_embedding.weight.T
-         return logits
-
- # Model, optimizer, scheduler, loss function
- model = ReLM(vocab_size, max_len, 128, 2).to(device)
- optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
- scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
- loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
-
- # Compile to a static graph
- model = torch.compile(model, mode="default")
-
- scaler = torch.cuda.amp.GradScaler()
- epochs = 1
- for epoch in range(epochs):
-     model.train()
-     total_loss = 0
-     for step, (x, y) in enumerate(dataloader):
-         x, y = x.to(device), y.to(device)
-         optimizer.zero_grad()
-
-         with torch.cuda.amp.autocast():  # mixed precision
-             logits = model(x)
-             loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
-
-         scaler.scale(loss).backward()
-         scaler.unscale_(optimizer)
-         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-         scaler.step(optimizer)
-         scaler.update()
-
-         total_loss += loss.item()
-         if step % 100 == 0:
-             print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")
-
-     scheduler.step()
-     print(f"Epoch {epoch+1} done, average loss: {total_loss/len(dataloader):.4f}")
-
- torch.save(model.state_dict(), "relm_model.pth")
- print("✅ Model saved!")
-
- # ===============================
- # Top-p sampling generation
- # ===============================
- def generate_text_topp(model, prompt, max_len=150, max_gen=150, p=0.9, temperature=0.6, min_len=20):
-     model.eval()
-     model_input = text_to_ids(f"<start> {prompt}")
-     model_input = model_input[:max_len]
-     generated = list(model_input)
-     with torch.no_grad():
-         for step in range(max_gen):
-             input_seq = generated[-max_len:] if len(generated)>max_len else generated
-             input_tensor = torch.tensor([input_seq + [pad_id]*(max_len-len(input_seq))], device=device)
-             logits = model(input_tensor)
-             next_logits = logits[0,len(input_seq)-1]
-             next_logits[end_id] -= 5.0
-             next_logits[pad_id] -= 10.0
-             probs = F.softmax(next_logits/temperature, dim=-1).cpu().numpy()
-             sorted_indices = np.argsort(probs)[::-1]
-             sorted_probs = probs[sorted_indices]
-             cumulative_probs = np.cumsum(sorted_probs)
-             cutoff = np.searchsorted(cumulative_probs,p)
-             top_indices = sorted_indices[:cutoff+1]
-             top_probs = sorted_probs[:cutoff+1]
-             top_probs /= top_probs.sum()
-             next_token = np.random.choice(top_indices, p=top_probs)
-             if next_token==end_id and len(generated)>=min_len:
-                 break
-             generated.append(int(next_token))
-     return ids_to_text(generated)
-
- # Test
- print("\n===== Generation result =====")
- print(generate_text_topp(model, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
+ import os, json, random, numpy as np, torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from torch.utils.data import IterableDataset, DataLoader
  import sentencepiece as spm
  import requests

+ # ===============================
+ # 0️⃣ Environment setup
+ # ===============================
  TOKENIZER_PATH = "ko_unigram.model"
+ DATA_PATH = "corpus.txt"
+ MAX_LEN = 128
+ EMBED_DIM = 384
+ LATENT_DIM = 384
+ BATCH_SIZE = 384
+ NEGATIVE_RATIO = 1
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
  # ===============================
  # 1️⃣ File download
  # ===============================
      with open(save_path, "wb") as f:
          for chunk in r.iter_content(8192*2):
              f.write(chunk)
+     print(f"Saved {save_path}")

  if not os.path.exists(TOKENIZER_PATH):
      download_file(
          "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
+         TOKENIZER_PATH,
      )
  if not os.path.exists(DATA_PATH):
      download_file(
          "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
+         DATA_PATH,
      )
+
  # ===============================
+ # 2️⃣ Tokenizer preparation
  # ===============================
+ sp = spm.SentencePieceProcessor(TOKENIZER_PATH)
  pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
  vocab_size = sp.get_piece_size()

+ def encode_sentence(sentence, max_len=MAX_LEN):
+     return sp.encode(sentence, out_type=int)[:max_len]

+ def pad_sentence(tokens):
+     return tokens + [pad_id] * (MAX_LEN - len(tokens))

  # ===============================
+ # 3️⃣ Streaming Dataset
  # ===============================
+ class PairStream(IterableDataset):
+     def __init__(self, txt_path, negative_ratio):
+         self.sentences = [line.strip() for line in open(txt_path, encoding="utf-8") if line.strip()]
+         self.neg_ratio = negative_ratio
+
+     def __iter__(self):
+         while True:
+             for s1 in self.sentences:
+                 x1 = pad_sentence(encode_sentence(s1))
+                 yield (torch.tensor(x1), torch.tensor(x1), torch.tensor(1.0))
+                 for _ in range(self.neg_ratio):
+                     s2 = random.choice(self.sentences)
+                     x2 = pad_sentence(encode_sentence(s2))
+                     yield (torch.tensor(x1), torch.tensor(x2), torch.tensor(0.0))
+
+ stream_ds = PairStream(DATA_PATH, NEGATIVE_RATIO)
+ loader = DataLoader(stream_ds, batch_size=BATCH_SIZE)

  # ===============================
+ # 4️⃣ Sentence Encoder definition
  # ===============================
+ class EncoderBlock(nn.Module):
+     def __init__(self, embed_dim, latent_dim):
          super().__init__()
+         self.mha = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
+         self.WB = nn.Linear(embed_dim, embed_dim * 3)
+         self.W = nn.Linear(embed_dim * 3 // 2, embed_dim)
+         self.ln1 = nn.LayerNorm(embed_dim)
+         self.ln2 = nn.LayerNorm(embed_dim)
+         self.ln3 = nn.LayerNorm(embed_dim)

+     def forward(self, x):
+         x1 = self.ln1(x)
+         attn, _ = self.mha(x1, x1, x1)
+         x = attn + x
+         x2 = self.ln2(x)
+         w = self.WB(x2)
+         a, b = torch.chunk(w, 2, dim=-1)
+         g = F.silu(a) * b
+         out = self.W(g)
+         return self.ln3(out) + x
+
+ class SentenceEncoder(nn.Module):
+     def __init__(self, vocab_size, embed_dim, latent_dim, max_len):
          super().__init__()
+         self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
+         self.pos = nn.Embedding(max_len, embed_dim)
+         self.blocks = nn.ModuleList([EncoderBlock(embed_dim, latent_dim) for _ in range(2)])
+         self.ln_f = nn.LayerNorm(embed_dim)
+         self.latent = nn.Linear(embed_dim, latent_dim)

      def forward(self, x):
+         b, l = x.shape
+         pos_ids = torch.arange(l, device=x.device).unsqueeze(0).expand(b, l)
+         x = self.embed(x) + self.pos(pos_ids)
          for block in self.blocks:
              x = block(x)
          x = self.ln_f(x)
+         x = x.mean(dim=1)
+         return torch.tanh(self.latent(x))
+
+ encoder = SentenceEncoder(vocab_size, EMBED_DIM, LATENT_DIM, MAX_LEN).to(device)
+
+ # ===============================
+ # 5️⃣ Cosine + Contrastive Loss
+ # ===============================
+ def cosine_sim(v1, v2, eps=1e-8):
+     dot = (v1 * v2).sum(dim=-1)
+     norm = v1.norm(dim=-1) * v2.norm(dim=-1) + eps
+     return dot / norm
+
+ def contrastive_loss(pred, label, margin=0.7):
+     dist = 1 - pred
+     pos_loss = label * dist.pow(2)
+     neg_loss = (1 - label) * (torch.clamp(margin - dist, min=0).pow(2))
+     return (pos_loss + neg_loss).mean()
+
+ optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-5)
+
+ encoder = torch.compile(encoder)
+ cosine_sim = torch.compile(cosine_sim)
+ contrastive_loss = torch.compile(contrastive_loss)
+ # ===============================
+ # 6️⃣ Training loop
+ # ===============================
+ steps_per_epoch = 23119910 // BATCH_SIZE
+
+ from tqdm import tqdm
+
+ encoder.train()
+
+ progress = tqdm(range(steps_per_epoch), desc="Training", ncols=120)
+
+ for step, batch in zip(progress, loader):
+     x1, x2, y = [b.to(device) for b in batch]
+
+     # forward
+     v1 = encoder(x1)
+     v2 = encoder(x2)
+     pred = cosine_sim(v1, v2)
+
+     loss = contrastive_loss(pred, y)
+
+     # backward
+     optimizer.zero_grad()
+     loss.backward()
+     optimizer.step()
+
+     # 📉 show the loss in the tqdm bar
+     progress.set_postfix({"loss": f"{loss.item():.4f}"})
+
+ # ===============================
+ # 7️⃣ Build vectors for retrieval
+ # ===============================
+ LIMIT = 4000
+ prompts = []
+ for i, line in enumerate(open(DATA_PATH, "r", encoding="utf-8")):
+     if i >= LIMIT: break
+     line = line.strip()
+     if line:
+         prompts.append(line)
+
+ @torch.no_grad()
+ def get_sentence_vector(sentence):
+     tokens = pad_sentence(encode_sentence(sentence))
+     x = torch.tensor([tokens]).to(device)
+     return encoder(x).cpu().numpy()[0]
+
+ if os.path.exists("corpus_vectors.npy"):
+     corpus_vectors = np.load("corpus_vectors.npy")
+ else:
+     corpus_vectors = np.stack([get_sentence_vector(p) for p in prompts]).astype(np.float16)
+     np.save("corpus_vectors.npy", corpus_vectors)
+
+ corpus_norms = np.linalg.norm(corpus_vectors, axis=1)
+
+ # ===============================
+ # 8️⃣ Search function
+ # ===============================
+ def search(query, top_k=3):
+     q_vec = get_sentence_vector(query).astype(np.float16)
+     sims = corpus_vectors @ q_vec
+     sims /= (corpus_norms * np.linalg.norm(q_vec) + 1e-8)
+     top_idx = np.argsort(sims)[::-1][:top_k]
+     return [(prompts[i], float(sims[i])) for i in top_idx]
+
+ # ===============================
+ # 🔟 Test
+ # ===============================
+ query = "점심이나 저녁을 우리와 함께 먹을 건가요?"
+ results = search(query)
+ for p, s in results:
+     print(f"Prompt: {p}\nSimilarity: {s:.3f}\n---")
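For reference only, not part of this commit: a minimal standalone sketch of how the cosine + contrastive objective added in the new Model_torch.py behaves on one positive and one negative pair. The cosine_sim and contrastive_loss functions are copied from the diff above; the toy vectors and labels are invented purely for illustration.

    import torch

    # Copied from the new Model_torch.py
    def cosine_sim(v1, v2, eps=1e-8):
        dot = (v1 * v2).sum(dim=-1)
        norm = v1.norm(dim=-1) * v2.norm(dim=-1) + eps
        return dot / norm

    # Copied from the new Model_torch.py: margin-based contrastive loss on cosine distance
    def contrastive_loss(pred, label, margin=0.7):
        dist = 1 - pred                          # cosine distance
        pos_loss = label * dist.pow(2)           # label 1: pull cosine similarity toward 1
        neg_loss = (1 - label) * (torch.clamp(margin - dist, min=0).pow(2))  # label 0: push distance past the margin
        return (pos_loss + neg_loss).mean()

    # Toy 2-d "sentence vectors" (hypothetical values, for illustration only)
    anchor = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
    other  = torch.tensor([[0.9, 0.1], [-0.5, 0.8]])  # row 0: near-duplicate, row 1: unrelated
    label  = torch.tensor([1.0, 0.0])

    pred = cosine_sim(anchor, other)       # roughly [0.99, -0.53]
    print(contrastive_loss(pred, label))   # near zero: both pairs already behave as their labels require

With margin=0.7, a negative pair contributes loss only while its cosine distance (1 - cos) is below the margin, i.e. while its cosine similarity stays above 0.3; once pushed past that point it stops contributing, so the objective separates mismatched sentence pairs without forcing them to exactly opposite directions.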