Update train.py
train.py CHANGED
@@ -166,8 +166,30 @@ def main():
     model_name = "datalab-to/chandra"
     batch_size = 20
 
+    init_flag = os.environ.get("INIT", "0")
+    is_first_run = init_flag == "0"
+    is_second_run = init_flag == "1"
+
     if not os.path.exists(preprocessed_dataset):
-
+        print(f"[{'First' if is_first_run else 'Second'} Run] Running preprocessing...")
+        ds_full = datasets.load_dataset(input_dataset, split="train")
+        total_size = len(ds_full)
+        midpoint = total_size // 2
+
+        if is_first_run:
+            ds_to_process = ds_full.select(range(0, midpoint))
+        else:
+            ds_to_process = ds_full.select(range(midpoint, total_size))
+
+        print(
+            f"[{'First' if is_first_run else 'Second'} Run] Saving selected shard to disk..."
+        )
+        ds_to_process.save_to_disk("temp_input_shard")
+
+        run_preprocessing("temp_input_shard", preprocessed_dataset)
+
+        # Clean up temp input shard
+        shutil.rmtree("temp_input_shard")
 
     print("Loading preprocessed dataset...")
     ds = datasets.load_from_disk(preprocessed_dataset)
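The preprocessing hunk above splits the input dataset in half with integer division and select(range(...)). A minimal standalone check of that arithmetic (illustrative only; the size is made up, not taken from the real dataset) shows the two ranges partition the index space with no gap and no overlap:

# Illustrative check of the half-split logic above; an odd total_size
# exercises the integer-division midpoint (the first run gets the
# smaller half, the second run gets the remainder).
total_size = 7
midpoint = total_size // 2  # 3

first_half = list(range(0, midpoint))            # [0, 1, 2]
second_half = list(range(midpoint, total_size))  # [3, 4, 5, 6]

# The two selections cover every index exactly once.
assert first_half + second_half == list(range(total_size))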
@@ -220,9 +242,16 @@ def main():
     shards = [cast(Dataset, datasets.load_from_disk(f)) for f in temp_files]
     final_ds = datasets.concatenate_datasets(shards)
 
-
-
-
+    if is_first_run:
+        print("First run: pushing first half to hub...")
+        final_ds.push_to_hub(output_dataset, create_pr=False)
+    else:
+        print("Second run: loading first half and merging...")
+        first_half_ds = datasets.load_dataset(output_dataset, split="train")
+        merged_ds = datasets.concatenate_datasets([first_half_ds, final_ds])
+        print(f"Final merged dataset size: {len(merged_ds)}")
+        print("Pushing full dataset with create_pr=True...")
+        merged_ds.push_to_hub(output_dataset, create_pr=True)
 
     print("Cleaning up temporary files...")
     for f in temp_files:
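Taken together, the two hunks make train.py resumable in two phases keyed off the INIT environment variable. The commit does not show how the script is launched, so the driver below is an assumption (the invocation is hypothetical, inferred from os.environ.get("INIT", "0")): run once with INIT=0 to push the first half directly, then again with INIT=1 to process the second half, merge in the first half from the Hub, and push the full dataset as a Hub pull request (create_pr=True opens a PR on the dataset repo rather than committing to main).

# Hypothetical driver (not part of this commit): invoke train.py twice,
# toggling INIT between the two runs.
import os
import subprocess

# First run: INIT="0" (also the default) -> preprocess and push the
# first half of the dataset directly to the output repo.
subprocess.run(["python", "train.py"], env={**os.environ, "INIT": "0"}, check=True)

# Second run: INIT="1" -> preprocess the second half, reload the first
# half from the Hub, concatenate, and push the merged dataset as a PR.
subprocess.run(["python", "train.py"], env={**os.environ, "INIT": "1"}, check=True)

On a Space, the same toggle would presumably be set through the Space's environment variables, restarting the Space between the two runs.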