Yuchan committed on
Commit 7d7e323 · verified
1 Parent(s): be37607

Update Mo_jax.py

Files changed (1)
  1. Mo_jax.py +164 -291
Mo_jax.py CHANGED
@@ -1,21 +1,13 @@
- # Flax + JAX TPU-ready reimplementation of your ReLM model and training loop.
- # Requirements:
- # pip install --upgrade "jax[tpu]" flax optax sentencepiece
-
- import os
- import math
- import numpy as np
- import sentencepiece as spm
  from functools import partial
- from typing import Any, Callable, Optional, Tuple, Sequence
- import requests
- import jax
- import jax.numpy as jnp
  from jax import random
  from flax import linen as nn
  from flax.training import train_state, checkpoints
  import optax
- import tqdm

  def download_file(url, save_path):
      r = requests.get(url, stream=True)
@@ -28,13 +20,10 @@ def download_file(url, save_path):
  # Config
  # ------------------
  SEQ_LEN = 512
- # global batch size (across all devices)
  GLOBAL_BATCH = 256
- # adjust for memory
- LIMIT = 200_000  # number of sequences to load (reduce if OOM)
  VOCAB_MODEL = "ko_unigram.model"
  CORPUS_PATH = "corpus.txt"
- DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
  SEED = 42
  LEARNING_RATE = 1e-4
  EPOCHS = 1
@@ -51,351 +40,235 @@ if not os.path.exists(VOCAB_MODEL):
          VOCAB_MODEL
      )

- # Derived
  NUM_DEVICES = jax.device_count()
- assert GLOBAL_BATCH % NUM_DEVICES == 0, "GLOBAL_BATCH must be divisible by device count"
  PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
-
- print("devices:", jax.devices())
- print("num_devices:", NUM_DEVICES, "per_device_batch:", PER_DEVICE_BATCH, "dtype:", DTYPE)

  # ------------------
- # Tokenizer loader
  # ------------------
  sp = spm.SentencePieceProcessor()
  sp.load(VOCAB_MODEL)
-
- pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
  start_id = sp.piece_to_id("<start>")
  end_id = sp.piece_to_id("<end>")
  vocab_size = sp.get_piece_size()
  print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)

  # ------------------
- # Data pipeline (simple, numpy-based)
- # - Reads corpus line-by-line, tokenizes, pads/truncates to SEQ_LEN.
- # - Builds a numpy array (N, SEQ_LEN) for inputs and targets (shifted by 1).
- # - Shards batches across devices for pmap.
  # ------------------
- def line_to_ids(line: str, max_len: int = SEQ_LEN):
      ids = sp.encode(line.strip(), out_type=int)
-     if len(ids) > max_len - 1:
-         ids = ids[: max_len - 1]
-     ids = ids + [end_id]
-     pad_len = max_len - len(ids)
-     ids = ids + [pad_id] * pad_len
      return np.array(ids, dtype=np.int32)

- def build_dataset(corpus_path: str, limit: int = LIMIT):
      arr = []
      with open(corpus_path, "r", encoding="utf-8") as f:
          for i, line in enumerate(f):
-             if i >= limit:
-                 break
-             line = line.strip()
-             if not line:
-                 continue
              arr.append(line_to_ids(line))
-     data = np.stack(arr, axis=0)  # (N, SEQ_LEN)
-     print("Loaded dataset shape:", data.shape)
      return data

- # create inputs and targets
  data_np = build_dataset(CORPUS_PATH, LIMIT)
  inputs = data_np
- targets = np.concatenate([data_np[:,1:], np.full((data_np.shape[0],1), pad_id, dtype=np.int32)], axis=1)

- # shuffle and create batches
- def create_batch_iter(inputs: np.ndarray, targets: np.ndarray, batch_size: int, rng: np.random.Generator):
-     idx = np.arange(inputs.shape[0])
-     rng.shuffle(idx)
-     for i in range(0, len(idx) - batch_size + 1, batch_size):
          batch_idx = idx[i:i+batch_size]
-         x = inputs[batch_idx]
-         y = targets[batch_idx]
-         yield x, y

- # helper to shard numpy batch for pmap: shape (num_devices, per_device, ...)
- def shard(xs: np.ndarray):
-     return xs.reshape((NUM_DEVICES, -1) + xs.shape[1:])

  # ------------------
- # Flax model implementation
  # ------------------
  class SwiGLU(nn.Module):
      d_model: int
-
      @nn.compact
-     def __call__(self, x):
-         # project to 2*intermediate, then split
-         proj = nn.Dense(self.d_model * 2, dtype=jnp.float32)(x)  # keep proj in float32
-         x_val, x_gate = jnp.split(proj, 2, axis=-1)
          out = x_val * nn.silu(x_gate)
-         out = nn.Dense(self.d_model, dtype=jnp.float32)(out)
-         return out.astype(x.dtype)

  class LoU(nn.Module):
-     d_model: int
-     clip_value: float = 5.0
-     eps: float = 1e-6
-
      @nn.compact
-     def __call__(self, x):
-         # x: (batch, seq, d)
-         x_f32 = x.astype(jnp.float32)
-         residual = x_f32
-
-         norm1 = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)
-         x_norm = norm1(x_f32)
-
-         Q = nn.Dense(self.d_model, dtype=jnp.float32)
-         K = nn.Dense(self.d_model, dtype=jnp.float32)
-         V = nn.Dense(self.d_model, dtype=jnp.float32)
-
-         q = Q(x_norm)
-         k = K(x_norm)
-         v = V(x_norm)
-
-         g_q = (jnp.tanh(q) + 1.0) / 2.0
-         g_k = (jnp.tanh(k) + 1.0) / 2.0
-         score = g_q * g_k  # (b, seq, d)
-
-         alpha_linear = nn.Dense(1, dtype=jnp.float32)
-         alpha_dynamic = alpha_linear(x_norm)  # (b, seq, 1)
-
-         # EMA over time: use scan across sequence axis
-         # transpose to (seq, batch, d) to scan over time
-         score_t = jnp.transpose(score, (1,0,2))
-         alpha_t = jnp.transpose(alpha_dynamic, (1,0,2))
-
-         def step(carry, inputs):
-             prev_ema = carry
-             x_t, a_t = inputs
-             new = a_t * x_t + (1.0 - a_t) * prev_ema
-             return new, new
-
-         init = score_t[0]
-         _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
-         ema_full = jnp.concatenate([init[None, ...], ema_seq], axis=0)  # (seq, batch, d)
-         ema = jnp.transpose(ema_full, (1,0,2))  # (batch, seq, d)
-
-         mean_last = jnp.mean(ema, axis=-1, keepdims=True)
-         denom = jnp.maximum(mean_last, self.eps)
-         score_norm = ema / denom
-         score_clipped = jnp.clip(score_norm, -self.clip_value, self.clip_value)
-
-         x_comb = score_clipped * v
-         out = x_comb + residual
-         out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(out)
-         out = SwiGLU(self.d_model)(out.astype(x.dtype))
-         return out.astype(x.dtype)

  class Lo(nn.Module):
-     d_model: int
-
      @nn.compact
-     def __call__(self, x):
-         h = nn.Dense(64, dtype=jnp.float32)(x)
-         h = nn.silu(h)
-         h = nn.Dense(self.d_model, dtype=jnp.float32)(h)
-         out = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)(h) + x
-         return out.astype(x.dtype)

  class Block(nn.Module):
-     d_model: int
-
      @nn.compact
-     def __call__(self, x):
-         x = LoU(self.d_model)(x)
-         x = Lo(self.d_model)(x)
          return x

  class ReLM(nn.Module):
-     vocab_size: int
-     max_seq_len: int
-     d_model: int
-     n_layers: int
-     dtype: Any = jnp.float32
-
      def setup(self):
-         self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
-         self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
-         self.blocks = [Block(self.d_model) for _ in range(self.n_layers)]
-         self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=jnp.float32)
-
-     def __call__(self, x, deterministic=True):
-         # x: (batch, seq)
-         b, seq = x.shape
-         positions = jnp.arange(seq)[None, :]
-         x = self.token_embed(x) + self.pos_embed(positions)
-         for blk in self.blocks:
-             x = blk(x)
-         x = self.ln_f(x)
-         # tie weights: token embedding matrix
-         embedding_matrix = self.token_embed.embedding  # (vocab, d)
-         logits = jnp.einsum("bld,vd->blv", x, embedding_matrix)
-         return logits.astype(jnp.float32)

  # ------------------
  # Loss & metrics
  # ------------------
- def smoothed_cross_entropy(logits, targets, pad_id, eps=0.1):
-     # logits: (b, seq, v)
-     # targets: (b, seq) int32
-     vocab = logits.shape[-1]
-     logits = logits.reshape(-1, vocab)
-     targets = targets.reshape(-1)
-     mask = (targets != pad_id).astype(jnp.float32)
-     # one-hot smoothed
-     one_hot = jax.nn.one_hot(targets, vocab)
-     smooth = (1.0 - eps) * one_hot + eps / float(vocab)
-     log_probs = jax.nn.log_softmax(logits, axis=-1)
-     loss_per_token = -jnp.sum(smooth * log_probs, axis=-1)
-     loss_per_token = loss_per_token * mask
-     denom = jnp.sum(mask) + 1e-8
-     loss = jnp.sum(loss_per_token) / denom
-     return loss
-
- def masked_perplexity_from_logits(logits, targets, pad_id, eps=0.1):
-     vocab = logits.shape[-1]
-     logits = logits.reshape(-1, vocab)
-     targets = targets.reshape(-1)
-     mask = (targets != pad_id).astype(jnp.float32)
-     one_hot = jax.nn.one_hot(targets, vocab)
-     smooth = (1.0 - eps) * one_hot + eps / float(vocab)
-     log_probs = jax.nn.log_softmax(logits, axis=-1)
-     loss_per_token = -jnp.sum(smooth * log_probs, axis=-1) * mask
-     mean_loss = jnp.sum(loss_per_token) / (jnp.sum(mask) + 1e-8)
-     return jnp.exp(mean_loss)

  # ------------------
- # Training state
  # ------------------
- class TrainState(train_state.TrainState):
-     pass
-
- def create_train_state(rng, model, learning_rate):
-     params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
-     tx = optax.chain(
-         optax.clip_by_global_norm(1.0),
-         optax.adamw(learning_rate=learning_rate, b1=0.9, b2=0.95, eps=1e-8)
-     )
-     return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

  # ------------------
- # pmap'd step functions
  # ------------------
  @partial(jax.pmap, axis_name="batch")
- def train_step(state, batch_x, batch_y, rng):
      def loss_fn(params):
-         logits = state.apply_fn({"params": params}, batch_x, deterministic=False)
-         loss = smoothed_cross_entropy(logits, batch_y, pad_id)
-         return loss, logits
-
-     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-     (loss, logits), grads = grad_fn(state.params)
-     grads = jax.lax.pmean(grads, axis_name="batch")
-     new_state = state.apply_gradients(grads=grads)
-     # metrics
-     ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
-     metrics = {"loss": loss, "ppl": ppl}
-     metrics = jax.lax.pmean(metrics, axis_name="batch")
-     return new_state, metrics
-
- @partial(jax.pmap, axis_name="batch")
- def eval_step(state, batch_x, batch_y):
-     logits = state.apply_fn({"params": state.params}, batch_x, deterministic=True)
-     loss = smoothed_cross_entropy(logits, batch_y, pad_id)
-     ppl = masked_perplexity_from_logits(logits, batch_y, pad_id)
-     metrics = {"loss": loss, "ppl": ppl}
-     metrics = jax.lax.pmean(metrics, axis_name="batch")
-     return metrics

  # ------------------
- # Training loop
  # ------------------
- rng = random.PRNGKey(SEED)
- rng, init_rng = random.split(rng)
- model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
- state = create_train_state(init_rng, model, LEARNING_RATE)
-
- # replicate to devices
- state = jax.device_put_replicated(state, jax.local_devices())

- print("Starting training...")

- global_step = 0
  for epoch in range(EPOCHS):
      print(f"Epoch {epoch+1}/{EPOCHS}")
-     np_rng = np.random.default_rng(SEED + epoch)
-     batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
-     pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))
-
-     for batch_x, batch_y in pbar:
-         # shard
-         batch_x = shard(batch_x)
-         batch_y = shard(batch_y)
-         rng, step_rng = random.split(rng)
-         # make per-device rngs
-         step_rngs = random.split(step_rng, NUM_DEVICES)
-         state, metrics = train_step(state, batch_x, batch_y, step_rngs)
-         # metrics are per-device; take first replica
-         m = jax.tree_util.tree_map(lambda x: x[0], metrics)
-         pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
-         global_step += 1

  # ------------------
- # Save params
  # ------------------
- save_dir = "./checkpoints"
- os.makedirs(save_dir, exist_ok=True)
- # save using flax.serialization via checkpoints
- checkpoints.save_checkpoint(save_dir, jax.tree_map(lambda x: np.array(x), state), step=global_step, keep=3)
- print("Saved checkpoint to", save_dir)

  # ------------------
- # Sampling (top-p) - single-device (CPU) sampling for simplicity
  # ------------------
- import math
-
- def top_p_sample_logits(rng, logits, p=0.9, temperature=1.0):
-     # logits: (vocab,)
-     probs = jax.nn.softmax(logits / temperature)
-     # convert to numpy for sorting (ok for single token)
-     probs_np = np.array(probs)
-     sorted_idx = np.argsort(probs_np)[::-1]
-     sorted_probs = probs_np[sorted_idx]
-     cum = np.cumsum(sorted_probs)
-     cutoff = np.searchsorted(cum, p)
-     top_idx = sorted_idx[: cutoff + 1]
-     top_probs = sorted_probs[: cutoff + 1]
-     top_probs = top_probs / top_probs.sum()
-     # sample
-     next_token = np.random.choice(top_idx, p=top_probs)
-     return int(next_token)
-
- def generate_text(state, prompt: str, max_gen=256, p=0.9, temperature=0.8, min_len=20):
-     # load params from replicated state (take first replica)
-     params = jax.tree_map(lambda x: np.array(x[0]), state.params)
-     tokens = sp.encode("<start> " + prompt, out_type=int)
-     generated = tokens.copy()
-     for step in range(max_gen):
-         cur = generated[-SEQ_LEN:]
-         if len(cur) < SEQ_LEN:
-             cur = cur + [pad_id] * (SEQ_LEN - len(cur))
-         x = np.array([cur], dtype=np.int32)
-         logits = model.apply({"params": params}, x, deterministic=True)  # (1, seq, vocab)
-         logits = np.array(logits[0, len(generated)-1 if len(generated)-1 < SEQ_LEN else SEQ_LEN-1])
-         # penalize end/pad a bit
-         logits[end_id] -= 5.0
-         logits[pad_id] -= 10.0
-         next_id = top_p_sample_logits(None, logits, p=p, temperature=temperature)
-         generated.append(next_id)
-         if next_id == end_id and len(generated) >= min_len:
-             break
-     return sp.decode(generated)
-
- # quick generate
  print("\n\n===== 생성 결과 =====")
- print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))
-

+ # TPU-optimized Flax + JAX ReLM
+ import os, math, numpy as np, sentencepiece as spm, requests, tqdm
  from functools import partial
+ from typing import Any
+ import jax, jax.numpy as jnp
  from jax import random
  from flax import linen as nn
  from flax.training import train_state, checkpoints
  import optax
+ import requests

  def download_file(url, save_path):
      r = requests.get(url, stream=True)

  # Config
  # ------------------
  SEQ_LEN = 512
  GLOBAL_BATCH = 256
+ LIMIT = 200_000
  VOCAB_MODEL = "ko_unigram.model"
  CORPUS_PATH = "corpus.txt"
  SEED = 42
  LEARNING_RATE = 1e-4
  EPOCHS = 1

          VOCAB_MODEL
      )

+ DTYPE = jnp.bfloat16 if jax.local_devices()[0].platform == "tpu" else jnp.float32
  NUM_DEVICES = jax.device_count()
  PER_DEVICE_BATCH = GLOBAL_BATCH // NUM_DEVICES
+ print("devices:", jax.devices(), "dtype:", DTYPE)

  # ------------------
+ # Tokenizer
  # ------------------
  sp = spm.SentencePieceProcessor()
  sp.load(VOCAB_MODEL)
+ pad_id = sp.piece_to_id("<pad>") if sp.piece_to_id("<pad>") != -1 else 0
  start_id = sp.piece_to_id("<start>")
  end_id = sp.piece_to_id("<end>")
  vocab_size = sp.get_piece_size()
  print("vocab_size:", vocab_size, "pad_id:", pad_id, "start_id:", start_id, "end_id:", end_id)

  # ------------------
+ # Data pipeline
  # ------------------
+ def line_to_ids(line, max_len=SEQ_LEN):
      ids = sp.encode(line.strip(), out_type=int)
+     if len(ids) > max_len - 1:
+         ids = ids[:max_len - 1]
+     ids += [end_id] + [pad_id] * (max_len - len(ids) - 1)
      return np.array(ids, dtype=np.int32)

+ def build_dataset(corpus_path, limit=LIMIT):
      arr = []
      with open(corpus_path, "r", encoding="utf-8") as f:
          for i, line in enumerate(f):
+             if i >= limit:
+                 break
+             line = line.strip()
+             if not line:
+                 continue
              arr.append(line_to_ids(line))
+     data = np.stack(arr, axis=0)
+     print("Loaded dataset:", data.shape)
      return data

  data_np = build_dataset(CORPUS_PATH, LIMIT)
  inputs = data_np
+ targets = np.concatenate([data_np[:, 1:], np.full((data_np.shape[0], 1), pad_id, np.int32)], axis=1)

+ def create_batch_iter(inputs, targets, batch_size, rng):
+     idx = np.arange(inputs.shape[0])
+     rng.shuffle(idx)
+     for i in range(0, len(idx) - batch_size + 1, batch_size):
          batch_idx = idx[i:i+batch_size]
+         yield inputs[batch_idx], targets[batch_idx]

+ def shard(xs):
+     return xs.reshape(NUM_DEVICES, -1, xs.shape[1])

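For reference, a minimal sketch of the shapes this pipeline produces, assuming the config above (GLOBAL_BATCH = 256, SEQ_LEN = 512) and a hypothetical 8 local devices; shard() only folds the leading batch axis, so GLOBAL_BATCH must stay divisible by the device count, as the assert removed above used to enforce.

    import numpy as np
    demo = np.zeros((256, 512), dtype=np.int32)        # one batch from create_batch_iter
    demo_sharded = demo.reshape(8, -1, demo.shape[1])  # what shard() does when NUM_DEVICES = 8
    assert demo_sharded.shape == (8, 32, 512)          # (NUM_DEVICES, PER_DEVICE_BATCH, SEQ_LEN)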
  # ------------------
+ # Model
  # ------------------
  class SwiGLU(nn.Module):
      d_model: int
+     dtype: Any = DTYPE
      @nn.compact
+     def __call__(self, x):
+         proj = nn.Dense(self.d_model * 2, dtype=self.dtype)(x)
+         x_val, x_gate = jnp.split(proj, 2, -1)
          out = x_val * nn.silu(x_gate)
+         return nn.Dense(self.d_model, dtype=self.dtype)(out)

  class LoU(nn.Module):
+     d_model: int
+     dtype: Any = DTYPE
      @nn.compact
+     def __call__(self, x):
+         residual = x
+         x_norm = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(x)
+         Q = nn.Dense(self.d_model, dtype=self.dtype)
+         K = nn.Dense(self.d_model, dtype=self.dtype)
+         V = nn.Dense(self.d_model, dtype=self.dtype)
+         q, k, v = Q(x_norm), K(x_norm), V(x_norm)
+         g_q = (jnp.tanh(q) + 1) / 2
+         g_k = (jnp.tanh(k) + 1) / 2
+         score = g_q * g_k
+         alpha_dynamic = nn.Dense(1, dtype=self.dtype)(x_norm)
+         # EMA scan along seq axis
+         score_t = jnp.transpose(score, (1, 0, 2))
+         alpha_t = jnp.transpose(alpha_dynamic, (1, 0, 2))
+         def step(prev, cur):
+             s, a = cur
+             new = a * s + (1 - a) * prev
+             return new, new
+         init = score_t[0]
+         _, ema_seq = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
+         ema_full = jnp.concatenate([init[None, ...], ema_seq], 0)
+         ema = jnp.transpose(ema_full, (1, 0, 2))
+         out = v * ema + residual
+         out = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(out)
+         return SwiGLU(self.d_model, self.dtype)(out)

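The core of LoU is a data-dependent exponential moving average taken along the sequence axis. A standalone sketch of that same recurrence, on made-up array sizes (illustrative only, not part of the model):

    import jax
    import jax.numpy as jnp

    # ema[t] = a[t] * s[t] + (1 - a[t]) * ema[t-1], scanned over the time axis
    def ema_over_time(score, alpha):                 # score: (batch, seq, d), alpha: (batch, seq, 1)
        score_t = jnp.transpose(score, (1, 0, 2))    # scan wants time on the leading axis
        alpha_t = jnp.transpose(alpha, (1, 0, 2))
        def step(prev, cur):
            s, a = cur
            new = a * s + (1.0 - a) * prev
            return new, new
        init = score_t[0]
        _, rest = jax.lax.scan(step, init, (score_t[1:], alpha_t[1:]))
        full = jnp.concatenate([init[None, ...], rest], axis=0)
        return jnp.transpose(full, (1, 0, 2))        # back to (batch, seq, d)

    s = jnp.ones((2, 4, 8))
    a = jnp.full((2, 4, 1), 0.5)
    print(ema_over_time(s, a).shape)                 # (2, 4, 8)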
  class Lo(nn.Module):
+     d_model: int
+     dtype: Any = DTYPE
      @nn.compact
+     def __call__(self, x):
+         h = nn.Dense(64, dtype=self.dtype)(x)
+         h = nn.silu(h)
+         h = nn.Dense(self.d_model, dtype=self.dtype)(h)
+         return nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(h) + x

  class Block(nn.Module):
+     d_model: int
+     dtype: Any = DTYPE
      @nn.compact
+     def __call__(self, x):
+         x = LoU(self.d_model, self.dtype)(x)
+         x = Lo(self.d_model, self.dtype)(x)
          return x

  class ReLM(nn.Module):
+     vocab_size: int
+     max_seq_len: int
+     d_model: int
+     n_layers: int
+     dtype: Any = DTYPE
      def setup(self):
+         self.token_embed = nn.Embed(self.vocab_size, self.d_model, dtype=self.dtype)
+         self.pos_embed = nn.Embed(self.max_seq_len, self.d_model, dtype=self.dtype)
+         self.blocks = [Block(self.d_model, self.dtype) for _ in range(self.n_layers)]
+         self.ln_f = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+     def __call__(self, x, deterministic=True):
+         b, seq = x.shape
+         pos = jnp.arange(seq)[None, :]
+         x = self.token_embed(x) + self.pos_embed(pos)
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.ln_f(x)
+         logits = jnp.einsum("bld,vd->blv", x, self.token_embed.embedding)
+         return logits

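The einsum above ties the output projection to the input embedding: logits are scored against the same (vocab, d_model) table used for token lookup, so the model needs no separate lm_head. A toy shape check (sizes are made up):

    import jax.numpy as jnp

    vocab, d = 11, 8                   # toy sizes, not the real config
    hidden = jnp.ones((2, 5, d))       # (batch, seq, d_model) output of the blocks
    embedding = jnp.ones((vocab, d))   # the nn.Embed table
    logits = jnp.einsum("bld,vd->blv", hidden, embedding)
    print(logits.shape)                # (2, 5, 11): one score per vocabulary entry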
  # ------------------
  # Loss & metrics
  # ------------------
+ def smoothed_ce(logits, targets, pad_id, eps=0.1):
+     vocab = logits.shape[-1]
+     logits = logits.reshape(-1, vocab)
+     targets = targets.reshape(-1)
+     mask = (targets != pad_id).astype(jnp.float32)
+     one_hot = jax.nn.one_hot(targets, vocab)
+     smooth = (1 - eps) * one_hot + eps / vocab
+     log_probs = jax.nn.log_softmax(logits)
+     loss = -jnp.sum(smooth * log_probs, axis=-1) * mask
+     return jnp.sum(loss) / (jnp.sum(mask) + 1e-8)
+
+ def masked_ppl(logits, targets, pad_id, eps=0.1):
+     vocab = logits.shape[-1]
+     logits = logits.reshape(-1, vocab)
+     targets = targets.reshape(-1)
+     mask = (targets != pad_id).astype(jnp.float32)
+     one_hot = jax.nn.one_hot(targets, vocab)
+     smooth = (1 - eps) * one_hot + eps / vocab
+     loss = -jnp.sum(smooth * jax.nn.log_softmax(logits), axis=-1) * mask
+     return jnp.exp(jnp.sum(loss) / (jnp.sum(mask) + 1e-8))

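smoothed_ce mixes the one-hot target with a uniform distribution: the gold token gets 1 - eps + eps/vocab, every other token gets eps/vocab, and padded positions are masked out of the average. A quick numeric check on a toy vocabulary:

    import jax
    import jax.numpy as jnp

    eps, vocab = 0.1, 4
    one_hot = jax.nn.one_hot(jnp.array([2]), vocab)   # gold token = 2
    smooth = (1 - eps) * one_hot + eps / vocab
    print(smooth)                                     # [[0.025 0.025 0.925 0.025]], row still sums to 1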
  # ------------------
+ # Train state
  # ------------------
+ class TrainState(train_state.TrainState):
+     pass
+
+ def create_train_state(rng, model, lr):
+     params = model.init(rng, jnp.zeros((1, SEQ_LEN), dtype=jnp.int32))["params"]
+     tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(lr, b1=0.9, b2=0.95, eps=1e-8))
+     return TrainState.create(apply_fn=model.apply, params=params, tx=tx)

  # ------------------
+ # pmap step
  # ------------------
  @partial(jax.pmap, axis_name="batch")
+ def train_step(state, bx, by, rngs):
      def loss_fn(params):
+         logits = state.apply_fn({"params": params}, bx, deterministic=False)
+         return smoothed_ce(logits, by, pad_id), logits
+     (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
+     grads = jax.lax.pmean(grads, "batch")
+     state = state.apply_gradients(grads=grads)
+     metrics = {"loss": loss, "ppl": masked_ppl(logits, by, pad_id)}
+     metrics = jax.lax.pmean(metrics, "batch")
+     return state, metrics

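train_step relies on axis_name="batch": gradients and metrics are averaged across devices with jax.lax.pmean, so every replica holds identical values after the step. A tiny self-contained demonstration of that collective (runs on however many local devices are present):

    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()
    per_device = jnp.arange(n, dtype=jnp.float32)   # pretend per-device losses
    averaged = jax.pmap(lambda x: jax.lax.pmean(x, "batch"), axis_name="batch")(per_device)
    print(averaged)                                 # every entry equals the cross-device mean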
  # ------------------
+ # Top-p sampling (JAX-native)
  # ------------------
+ def top_p_sample(rng, logits, p=0.9, temperature=1.0):
+     probs = jax.nn.softmax(logits / temperature)
+     sorted_probs, sorted_idx = jax.lax.top_k(probs, logits.shape[-1])
+     cum_probs = jnp.cumsum(sorted_probs)
+     mask = cum_probs <= p
+     mask = mask.at[0].set(True)  # always keep the most likely token, even when it alone exceeds p
+     top_probs = jnp.where(mask, sorted_probs, 0.0)
+     top_probs = top_probs / jnp.sum(top_probs)
+     return int(sorted_idx[jax.random.categorical(rng, jnp.log(top_probs))])
+
+ def generate_text(state, prompt, max_gen=256, p=0.9, temperature=0.8, min_len=20):
+     params = jax.tree_map(lambda x: np.array(x[0]), state.params)  # take replica 0
+     tokens = sp.encode("<start> " + prompt, out_type=int)
+     generated = tokens.copy()
+     rng = random.PRNGKey(SEED)
+     for step in range(max_gen):
+         cur = generated[-SEQ_LEN:]
+         if len(cur) < SEQ_LEN:
+             cur = cur + [pad_id] * (SEQ_LEN - len(cur))
+         x = jnp.array([cur], dtype=jnp.int32)
+         pos = min(len(generated), SEQ_LEN) - 1  # clamp to the window once the context overflows
+         logits = model.apply({"params": params}, x, deterministic=True)[0, pos]
+         logits = logits.at[end_id].add(-5.0).at[pad_id].add(-10.0)
+         rng, sample_rng = random.split(rng)  # fresh key per step
+         next_id = top_p_sample(sample_rng, logits, p, temperature)
+         generated.append(next_id)
+         if next_id == end_id and len(generated) >= min_len:
+             break
+     return sp.decode(generated)

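A quick sanity check of the sampler on toy probabilities (this assumes the keep-at-least-one-token guard added above, which matters when the most likely token already exceeds p on its own):

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    toy_logits = jnp.log(jnp.array([0.70, 0.15, 0.10, 0.05]))     # already normalized probabilities
    print(top_p_sample(key, toy_logits, p=0.5, temperature=1.0))  # nucleus is just token 0 -> prints 0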
+ # ------------------
+ # Training
+ # ------------------
+ rng = random.PRNGKey(SEED)
+ rng, init_rng = random.split(rng)
+ model = ReLM(vocab_size=vocab_size, max_seq_len=SEQ_LEN, d_model=512, n_layers=9, dtype=DTYPE)
+ state = create_train_state(init_rng, model, LEARNING_RATE)
+ state = jax.device_put_replicated(state, jax.local_devices())

+ global_step = 0
  for epoch in range(EPOCHS):
      print(f"Epoch {epoch+1}/{EPOCHS}")
+     np_rng = np.random.default_rng(SEED + epoch)
+     batch_iter = create_batch_iter(inputs, targets, GLOBAL_BATCH, np_rng)
+     pbar = tqdm.tqdm(batch_iter, total=max(1, inputs.shape[0] // GLOBAL_BATCH))
+     for bx, by in pbar:
+         bx_sh, by_sh = shard(bx), shard(by)
+         rng, step_rng = random.split(rng)  # advance the key so per-device rngs differ each step
+         state, metrics = train_step(state, bx_sh, by_sh, jax.random.split(step_rng, NUM_DEVICES))
+         m = jax.tree_util.tree_map(lambda x: x[0], metrics)
+         pbar.set_postfix(loss=float(m["loss"]), ppl=float(m["ppl"]))
+         global_step += 1

  # ------------------
+ # Save
  # ------------------
+ save_dir = "./checkpoints"
+ os.makedirs(save_dir, exist_ok=True)
+ checkpoints.save_checkpoint(save_dir, jax.tree_map(lambda x: np.array(x), state), step=global_step, keep=3)
+ print("Saved checkpoint to", save_dir)
 
  # ------------------
+ # Generate
  # ------------------
  print("\n\n===== 생성 결과 =====")
+ print(generate_text(state, "지난 2년 동안 출연연이 국가가 필요한 연구를", p=0.9))