"""Train a small sentence encoder contrastively and use it for retrieval.

On first run this downloads a SentencePiece tokenizer and a text corpus, then
trains a two-block transformer encoder on (sentence, sentence) positive pairs
and (sentence, random sentence) negative pairs with a margin-based contrastive
loss. It then embeds the first ``LIMIT`` corpus lines and prints the nearest
neighbours of a sample query.
"""
import itertools
import json
import os
import random

import numpy as np
import requests
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

# ===============================
# 0. Configuration
# ===============================
TOKENIZER_PATH = "ko_unigram.model"
DATA_PATH = "corpus.txt"
MAX_LEN = 128  # every sentence is truncated / right-padded to this many ids
EMBED_DIM = 384
LATENT_DIM = 384
BATCH_SIZE = 384
NEGATIVE_RATIO = 1  # random negative pairs generated per positive pair

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ===============================
# 1. File download
# ===============================
def download_file(url, save_path, timeout=60):
    """Stream ``url`` to ``save_path``.

    The body is first written to ``save_path + ".part"`` and renamed into
    place only after it has been fully received. An interrupted download
    therefore never leaves a truncated file that the ``os.path.exists``
    checks below would mistake for a complete one.

    Args:
        url: HTTP(S) URL to fetch.
        save_path: destination file path.
        timeout: seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    tmp_path = save_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(8192 * 2):
                f.write(chunk)
    os.replace(tmp_path, save_path)  # atomic rename on the same filesystem
    print(f"Saved {save_path}")


if not os.path.exists(TOKENIZER_PATH):
    download_file(
        "https://huggingface.co/Yuchan5386/inlam-100m/resolve/main/ko_unigram.model?download=true",
        TOKENIZER_PATH,
    )
if not os.path.exists(DATA_PATH):
    download_file(
        "https://huggingface.co/datasets/Yuchan5386/1/resolve/main/shuffled_corpus.txt?download=true",
        DATA_PATH,
    )

# ===============================
# 2. Tokenizer
# ===============================
sp = spm.SentencePieceProcessor(TOKENIZER_PATH)


def _resolve_pad_id(processor):
    """Return the token id used for padding.

    ``piece_to_id`` returns ``unk_id`` (not -1) for a piece missing from the
    vocabulary, so a "-1 means missing" check never fires. Instead this
    prefers the model's configured pad id, then an explicit ``<pad>`` piece,
    and finally falls back to 0.
    """
    configured = processor.pad_id()  # -1 when padding is disabled in the model
    if configured >= 0:
        return configured
    # NOTE(review): assumes the tokenizer's padding piece is spelled "<pad>" - confirm.
    piece_id = processor.piece_to_id("<pad>")
    if piece_id != processor.unk_id():
        return piece_id
    return 0


pad_id = _resolve_pad_id(sp)
vocab_size = sp.get_piece_size()


def encode_sentence(sentence, max_len=MAX_LEN):
    """Tokenize ``sentence`` into at most ``max_len`` integer ids."""
    return sp.encode(sentence, out_type=int)[:max_len]


def pad_sentence(tokens):
    """Right-pad ``tokens`` with ``pad_id`` up to ``MAX_LEN`` entries."""
    return tokens + [pad_id] * (MAX_LEN - len(tokens))


# ===============================
# 3. Streaming dataset
# ===============================
class PairStream(IterableDataset):
    """Endless stream of ``(anchor, other, label)`` training triples.

    For every corpus sentence it yields one positive triple ``(x, x, 1.0)``,
    followed by ``negative_ratio`` negative triples ``(x, random_sentence, 0.0)``.
    Token tensors hold ``MAX_LEN`` ids each. The stripped, non-empty corpus
    lines are kept in memory; only tokenization happens lazily.

    Raises:
        ValueError: if the corpus contains no non-empty lines. Without this
            check, iteration would spin forever without yielding.
    """

    _NEGATIVE_RETRIES = 8  # bounded so a corpus of identical lines cannot hang

    def __init__(self, txt_path, negative_ratio):
        with open(txt_path, encoding="utf-8") as f:
            self.sentences = [s for s in (line.strip() for line in f) if s]
        if not self.sentences:
            raise ValueError(f"no non-empty lines in {txt_path!r}")
        self.neg_ratio = negative_ratio

    def _sample_negative(self, anchor):
        """Draw a random corpus sentence, re-drawing if it equals ``anchor``."""
        candidate = random.choice(self.sentences)
        for _ in range(self._NEGATIVE_RETRIES):
            if candidate != anchor:
                break
            candidate = random.choice(self.sentences)
        return candidate

    def __iter__(self):
        while True:
            for s1 in self.sentences:
                # Encode the anchor once and reuse the tensor for all its pairs.
                x1 = torch.tensor(pad_sentence(encode_sentence(s1)))
                yield x1, x1, torch.tensor(1.0)
                for _ in range(self.neg_ratio):
                    s2 = self._sample_negative(s1)
                    x2 = torch.tensor(pad_sentence(encode_sentence(s2)))
                    yield x1, x2, torch.tensor(0.0)


stream_ds = PairStream(DATA_PATH, NEGATIVE_RATIO)
loader = DataLoader(stream_ds, batch_size=BATCH_SIZE)


# ===============================
# 4. Sentence encoder
# ===============================
class EncoderBlock(nn.Module):
    """Transformer block: LayerNorm -> self-attention -> residual, then
    LayerNorm -> gated SiLU MLP -> LayerNorm -> residual.

    ``latent_dim`` is accepted for signature symmetry with ``SentenceEncoder``
    but is not used by the block.
    """

    def __init__(self, embed_dim, latent_dim):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        # WB projects to 3*embed_dim. The output is split into two halves of
        # 1.5*embed_dim (gate and value), which is why W takes 3*embed_dim//2 inputs.
        self.WB = nn.Linear(embed_dim, embed_dim * 3)
        self.W = nn.Linear(embed_dim * 3 // 2, embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ln3 = nn.LayerNorm(embed_dim)

    def forward(self, x, key_padding_mask=None):
        """Apply the block.

        Args:
            x: ``(batch, seq, embed_dim)`` activations (``batch_first=True``).
            key_padding_mask: optional ``(batch, seq)`` bool tensor that is
                ``True`` at positions attention must ignore. ``None`` attends
                to every position.
        """
        h = self.ln1(x)
        # Attention weights are discarded, so don't ask MHA to compute them.
        attn, _ = self.mha(h, h, h, key_padding_mask=key_padding_mask, need_weights=False)
        x = attn + x
        gate, value = torch.chunk(self.WB(self.ln2(x)), 2, dim=-1)
        out = self.W(F.silu(gate) * value)
        return self.ln3(out) + x


class SentenceEncoder(nn.Module):
    """Map ``(batch, seq)`` token ids to ``(batch, latent_dim)`` vectors in (-1, 1)."""

    def __init__(self, vocab_size, embed_dim, latent_dim, max_len):
        super().__init__()
        self.pad_id = pad_id
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
        self.pos = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList([EncoderBlock(embed_dim, latent_dim) for _ in range(2)])
        self.ln_f = nn.LayerNorm(embed_dim)
        self.latent = nn.Linear(embed_dim, latent_dim)

    def forward(self, x):
        """Encode token ids ``x`` (``seq`` must not exceed ``max_len``).

        Positions equal to ``pad_id`` are excluded both from attention and
        from the mean pool. A row made entirely of padding is treated as
        all-valid, so it still gets a finite vector (a fully masked attention
        row would yield NaN).
        """
        b, l = x.shape
        valid = x != self.pad_id
        all_pad = ~valid.any(dim=1, keepdim=True)
        valid = valid | all_pad
        pos_ids = torch.arange(l, device=x.device).unsqueeze(0).expand(b, l)
        h = self.embed(x) + self.pos(pos_ids)
        for block in self.blocks:
            h = block(h, key_padding_mask=~valid)
        h = self.ln_f(h)
        # Masked mean: every row has at least one valid position, so the divisor is >= 1.
        weights = valid.unsqueeze(-1).to(h.dtype)
        h = (h * weights).sum(dim=1) / weights.sum(dim=1)
        return torch.tanh(self.latent(h))


encoder = SentenceEncoder(vocab_size, EMBED_DIM, LATENT_DIM, MAX_LEN).to(device)


# ===============================
# 5. Cosine similarity + contrastive loss
# ===============================
def cosine_sim(v1, v2, eps=1e-8):
    """Row-wise cosine similarity along the last dim; ``eps`` guards zero norms."""
    dot = (v1 * v2).sum(dim=-1)
    norm = v1.norm(dim=-1) * v2.norm(dim=-1) + eps
    return dot / norm


def contrastive_loss(pred, label, margin=0.7):
    """Margin contrastive loss over cosine similarities.

    With ``dist = 1 - pred``, positives (``label == 1``) are penalised by
    ``dist**2``. Negatives (``label == 0``) are penalised by
    ``max(margin - dist, 0)**2``, so they stop contributing once
    ``dist >= margin``. Returns the batch mean.
    """
    dist = 1 - pred
    pos_loss = label * dist.pow(2)
    neg_loss = (1 - label) * (torch.clamp(margin - dist, min=0).pow(2))
    return (pos_loss + neg_loss).mean()


optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-5)
encoder = torch.compile(encoder)
cosine_sim = torch.compile(cosine_sim)
contrastive_loss = torch.compile(contrastive_loss)

# ===============================
# 6. Training loop
# ===============================
# One epoch visits each corpus sentence once as an anchor. Each anchor yields
# (1 + NEGATIVE_RATIO) examples, so the step count is derived from the corpus
# rather than hard-coded.
steps_per_epoch = max(1, len(stream_ds.sentences) * (1 + NEGATIVE_RATIO) // BATCH_SIZE)

# NOTE(review): positive pairs feed the same token ids to the encoder twice,
# and the model has no dropout (nn.MultiheadAttention defaults to dropout=0).
# So v1 == v2 and the positive loss term is ~0; only the negative pairs
# produce gradient. Consider augmented or paraphrase positives.
encoder.train()
progress = tqdm(range(steps_per_epoch), desc="Training", ncols=120)
for step, batch in zip(progress, loader):
    x1, x2, y = [t.to(device) for t in batch]

    # forward
    v1 = encoder(x1)
    v2 = encoder(x2)
    pred = cosine_sim(v1, v2)
    loss = contrastive_loss(pred, y)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Show the current loss in the tqdm postfix.
    progress.set_postfix({"loss": f"{loss.item():.4f}"})

# ===============================
# 7. Build retrieval vectors
# ===============================
LIMIT = 4000
with open(DATA_PATH, "r", encoding="utf-8") as f:
    prompts = [s for s in (line.strip() for line in itertools.islice(f, LIMIT)) if s]


@torch.no_grad()
def get_sentence_vectors(sentences, batch_size=BATCH_SIZE):
    """Encode ``sentences`` in batches.

    Returns a ``(len(sentences), LATENT_DIM)`` float32 array; an empty input
    gives a ``(0, LATENT_DIM)`` array.
    """
    chunks = []
    for start in range(0, len(sentences), batch_size):
        ids = [pad_sentence(encode_sentence(s)) for s in sentences[start:start + batch_size]]
        x = torch.tensor(ids, device=device)
        chunks.append(encoder(x).float().cpu().numpy())
    if not chunks:
        return np.zeros((0, LATENT_DIM), dtype=np.float32)
    return np.concatenate(chunks)


def get_sentence_vector(sentence):
    """Encode one sentence into a ``(LATENT_DIM,)`` float32 array."""
    return get_sentence_vectors([sentence])[0]


encoder.eval()
# The encoder was just (re)trained above, so any corpus_vectors.npy left by an
# earlier run came from a different model and cannot be compared with new
# query vectors. Always rebuild, then persist (float16) for later use.
corpus_vectors = get_sentence_vectors(prompts).astype(np.float16)
np.save("corpus_vectors.npy", corpus_vectors)
corpus_norms = np.linalg.norm(corpus_vectors.astype(np.float32), axis=1)


# ===============================
# 8. Search
# ===============================
def search(query, top_k=3):
    """Return up to ``top_k`` ``(prompt, cosine_similarity)`` pairs, best first.

    The stored float16 vectors are upcast, and the dot products and norms are
    computed in float32, to avoid float16 accumulation error.
    """
    q_vec = get_sentence_vector(query).astype(np.float32)
    sims = corpus_vectors.astype(np.float32) @ q_vec
    sims /= corpus_norms * np.linalg.norm(q_vec) + 1e-8
    top_idx = np.argsort(sims)[::-1][:top_k]
    return [(prompts[i], float(sims[i])) for i in top_idx]


# ===============================
# 9. Smoke test
# ===============================
query = "점심이나 저녁을 우리와 함께 먹을 건가요?"
results = search(query)
for p, s in results:
    print(f"Prompt: {p}\n유사도: {s:.3f}\n---")