nroggendorff committed
Commit 84f4c93 · verified · 1 Parent(s): c446fbc

Update train.py

Files changed (1):
  1. train.py +33 -4
train.py CHANGED
@@ -166,8 +166,30 @@ def main():
     model_name = "datalab-to/chandra"
     batch_size = 20
 
+    init_flag = os.environ.get("INIT", "0")
+    is_first_run = init_flag == "0"
+    is_second_run = init_flag == "1"
+
     if not os.path.exists(preprocessed_dataset):
-        run_preprocessing(input_dataset, preprocessed_dataset)
+        print(f"[{'First' if is_first_run else 'Second'} Run] Running preprocessing...")
+        ds_full = datasets.load_dataset(input_dataset, split="train")
+        total_size = len(ds_full)
+        midpoint = total_size // 2
+
+        if is_first_run:
+            ds_to_process = ds_full.select(range(0, midpoint))
+        else:
+            ds_to_process = ds_full.select(range(midpoint, total_size))
+
+        print(
+            f"[{'First' if is_first_run else 'Second'} Run] Saving selected shard to disk..."
+        )
+        ds_to_process.save_to_disk("temp_input_shard")
+
+        run_preprocessing("temp_input_shard", preprocessed_dataset)
+
+        # Clean up temp input shard
+        shutil.rmtree("temp_input_shard")
 
     print("Loading preprocessed dataset...")
     ds = datasets.load_from_disk(preprocessed_dataset)
@@ -220,9 +242,16 @@ def main():
     shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
     final_ds = datasets.concatenate_datasets(shards)
 
-    print(f"Final dataset size: {len(final_ds)}")
-    print("Pushing to hub...")
-    final_ds.push_to_hub(output_dataset, create_pr=False)
+    if is_first_run:
+        print("First run: pushing first half to hub...")
+        final_ds.push_to_hub(output_dataset, create_pr=False)
+    else:
+        print("Second run: loading first half and merging...")
+        first_half_ds = datasets.load_dataset(output_dataset, split="train")
+        merged_ds = datasets.concatenate_datasets([first_half_ds, final_ds])
+        print(f"Final merged dataset size: {len(merged_ds)}")
+        print("Pushing full dataset with create_pr=True...")
+        merged_ds.push_to_hub(output_dataset, create_pr=True)
 
     print("Cleaning up temporary files...")
     for f in temp_files:
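
As the diff reads, train.py now expects to be launched twice, with the INIT environment variable selecting which half of the dataset to process. A hypothetical driver for that sequence (the loop and invocation here are assumptions, not code from this repo):

# Hypothetical driver (assumption, not repo code): runs train.py twice,
# toggling INIT so each run handles one half of the dataset.
import os
import subprocess

for init in ("0", "1"):  # "0" = first half, "1" = second half + merge/push
    subprocess.run(
        ["python", "train.py"],
        env={**os.environ, "INIT": init},
        check=True,
    )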
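
The second run's merge step depends on a simple invariant: selecting the two halves by midpoint and concatenating them reproduces the original dataset in order. A minimal self-contained sketch, with toy data standing in for input_dataset:

# Minimal sketch (not part of this commit): checks the split-and-merge
# invariant the two-run flow relies on.
from datasets import Dataset, concatenate_datasets

ds_full = Dataset.from_dict({"id": list(range(10))})  # toy stand-in for input_dataset
midpoint = len(ds_full) // 2

first_half = ds_full.select(range(0, midpoint))              # what the INIT=0 run processes
second_half = ds_full.select(range(midpoint, len(ds_full)))  # what the INIT=1 run processes

merged = concatenate_datasets([first_half, second_half])
assert merged["id"] == ds_full["id"]  # row order preserved end to end

Note that push_to_hub(..., create_pr=True) opens a pull request on the Hub rather than committing directly to main, which is why only the final merged push in the diff uses it.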