Nyamdavaa Amar committed
Commit · cf49f13
Parent(s): 3d4d40d

Edit presets
Browse files
- adaptive_schedule.py +25 -50
- app.py +18 -16
- description1.md +4 -0
- description2.md +1 -0
- description3.md +0 -0
- interleaved_variant.py +9 -14
- schedule1f1bv.py +12 -12
adaptive_schedule.py
CHANGED
@@ -46,9 +46,9 @@ def transform_schedule(schedule, f, b, w, c):
             time = 0
             if (stage, type, mb) in local_prev:
                 time = get_time(stage, *local_prev[(stage, type, mb)])
-            if type in
+            if type in "FB" and stage > 0:
                 time = max(time, get_time(stage - 1, type, mb) + c)
-            if type in
+            if type in "fb" and stage + 1 < len(schedule):
                 time = max(time, get_time(stage + 1, type, mb) + c)
             # print(f'{stage} {type}:{mb}', time + cost[type])
             time_map[(stage, type, mb)] = time + cost[type]
@@ -63,7 +63,7 @@ def transform_schedule(schedule, f, b, w, c):
         for p, mb in stage:
             result_stage.append(ScheduledNode(
                 p.upper(),
-                p in
+                p in "fBW",
                 sid,
                 mb,
                 get_time(sid, p, mb) - cost[p],
@@ -110,9 +110,9 @@ def evaluate_schedule(schedule, f, b, w, c):
             time = 0
             if (stage, type, mb) in local_prev:
                 time = get_time(stage, *local_prev[(stage, type, mb)])
-            if type in
+            if type in "FB" and stage > 0:
                 time = max(time, get_time(stage - 1, type, mb) + c)
-            if type in
+            if type in "fb" and stage + 1 < len(schedule):
                 time = max(time, get_time(stage + 1, type, mb) + c)
             # print(f'{stage} {type}:{mb}', time + cost[type])
             time_map[(stage, type, mb)] = time + cost[type]
@@ -153,16 +153,6 @@ def get_peak_mem(schedules, return_all=False):
         return all_peak
     return max_peak
 
-debug = False
-def print_schedules(schedules):
-    if not debug:
-        return
-    for seq in schedules:
-        _str = ""
-        for v in seq:
-            _str += v
-        print(_str)
-
 
 def calc_bubble(schedules):
     stage_bubbles = []
@@ -199,8 +189,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
 def clear_invalid_index(repeated, m):
     p = len(repeated)
     index = pattern_size
-    for identifier in
-        if identifier in
+    for identifier in "FfBb":
+        if identifier in "FB":
             _iter = range(p)
         else:
             _iter = range(p - 1, -1, -1)
@@ -210,7 +200,7 @@ def clear_invalid_index(repeated, m):
             clear_invalid(repeated, i, index - pattern_size, offset=-1)
             clear_invalid(repeated, i, index + pattern_size * m, offset=1)
             index += 1
-            if identifier in
+            if identifier in "Bb":
                 w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                 for k in range(pattern_size):
                     if repeated[i][index + k] == w_identifier:
@@ -386,6 +376,17 @@ def squeeze_without_change_order(schedules, m):
             identifier_cnt[i][identifier] += 1
             identifier_index[_cnt * p + i][identifier] = index
             stage_index[i] = index + 1
+    while True:
+        if(len(squeezed[0]) == 1):
+            break
+        allempty = True
+        for x in squeezed:
+            if x[-1] != ' ':
+                allempty = False
+        if allempty == False:
+            break
+        for x in squeezed:
+            del x[-1]
     return squeezed
 
 
@@ -485,12 +486,11 @@ def process_cooldown(schedules, m):
     return schedules
 
 
-def schedule_by_pattern(p, m, patterns):
+def schedule_by_pattern(p, m, patterns, max_mem):
     schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
     schedules = clear_invalid_index(schedules, max(m, 2 * p))
-    print_schedules(schedules)
     init_peak_mem = get_peak_mem(schedules)
-    if init_peak_mem >
+    if init_peak_mem > max_mem:
         return None, init_peak_mem, [6 * max(m, 2 * p)] * p
     schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
 
@@ -503,20 +503,16 @@ def schedule_by_pattern(p, m, patterns):
                 schedules[sid][i] = ' '
             else:
                 cnt[schedules[sid][i]] += 1
-    print_schedules(schedules)
     peak_mem = get_peak_mem(schedules)
     if peak_mem > init_peak_mem:
         return None, init_peak_mem, [6 * m] * p
 
     schedules = squeeze_without_change_order(schedules, m)
-    print_schedules(schedules)
 
     schedules = process_cooldown(schedules, m)
-    print_schedules(schedules)
     peak_mem = get_peak_mem(schedules)
     if peak_mem > init_peak_mem:
         return None, init_peak_mem, [6 * m] * p
-
     stage_bubbles = calc_bubble(schedules)
     return schedules, peak_mem, stage_bubbles
 
@@ -572,25 +568,8 @@ def schedule(p, m, cost, max_mem):
             pattern = [0, ff_i, b_i, bb_i, -1, -1]
             pattern = fill_w_in_pattern(pattern)
             available_patterns.append(pattern)
-    available_offsets = []
-    for f_o in range(1, pattern_size + 1):
-        for ff_o in range(1, pattern_size + 1):
-            for b_o in range(1, pattern_size + 1):
-                if f_o != b_o:
-                    continue
-                bb_o = ff_o + b_o - f_o
-                if bb_o < 1 or bb_o > pattern_size:
-                    continue
-                if bb_o + ff_o + b_o + f_o > 2 * pattern_size:
-                    continue
-                # if bb_o + ff_o + b_o + f_o != 6:
-                #     continue
-                offset = [f_o, - ff_o, b_o, - bb_o]
-                if min(ff_o, bb_o) > 1:
-                    continue
-                available_offsets.append(offset)
 
-    print(
+    print(len(available_patterns))
     available_offsets = [
         [1, -1, 1, -1],
         [2, -1, 2, -1],
@@ -601,7 +580,6 @@ def schedule(p, m, cost, max_mem):
 
     best_schedule = None
     best_bubble = None
-    peak_mem2min_bubble = {}
     for pattern_0 in available_patterns:
         for i_0 in range(len(available_offsets)):
             for i_1 in range(i_0 + 1):
@@ -611,13 +589,10 @@ def schedule(p, m, cost, max_mem):
                 whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
                 if whole_pattern is None:
                     continue
-
-                # print(get_pattern_str(pattern))
-                # print(offset)
-                s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern)
-                if s is None:
-                    continue
+                s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
                 if peak_mem > 2 * p or peak_mem > max_mem:
+                    break
+                if s is None:
                     continue
                 max_bubble = max(bubbles)
                 max_bubble = evaluate_schedule(s, *cost)
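The boundary guards added to `transform_schedule` and `evaluate_schedule` encode the cross-stage dependencies used to time a chunked schedule: capital-letter passes (F, B) wait on the same pass at the previous stage plus one communication c, lower-case passes (f, b) wait on the next stage, and the new `stage > 0` / `stage + 1 < len(schedule)` checks simply skip those lookups at the pipeline ends. A minimal standalone sketch of this timing rule (illustrative only; `toy_finish_times` and its toy inputs are not from the repo):

    # Illustrative sketch of the dependency rule above, not the repo's implementation.
    # schedule[stage] lists (type, microbatch) pairs in execution order for that stage.
    def toy_finish_times(schedule, cost, c):
        local_prev = {}
        for stage, seq in enumerate(schedule):
            for i in range(1, len(seq)):
                local_prev[(stage,) + seq[i]] = seq[i - 1]

        finish = {}
        def get_time(stage, t, mb):
            if (stage, t, mb) in finish:
                return finish[(stage, t, mb)]
            time = 0
            if (stage, t, mb) in local_prev:
                time = get_time(stage, *local_prev[(stage, t, mb)])
            if t in "FB" and stage > 0:                    # capital passes depend on the previous stage
                time = max(time, get_time(stage - 1, t, mb) + c)
            if t in "fb" and stage + 1 < len(schedule):    # lower-case passes depend on the next stage
                time = max(time, get_time(stage + 1, t, mb) + c)
            finish[(stage, t, mb)] = time + cost[t]
            return finish[(stage, t, mb)]

        for stage, seq in enumerate(schedule):
            for t, mb in seq:
                get_time(stage, t, mb)
        return finish

    # Two stages, one microbatch, each pass costing 1, communication cost 1.
    toy = [[('F', 0), ('f', 0), ('B', 0), ('b', 0)], [('F', 0), ('f', 0), ('B', 0), ('b', 0)]]
    print(toy_finish_times(toy, {'F': 1, 'f': 1, 'B': 1, 'b': 1}, 1))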
app.py
CHANGED
@@ -136,9 +136,10 @@ with gr.Blocks() as demo:
     gr.Markdown(open("description1.md").read())
     gr.Markdown("# Pipeline Scheduler Playground")
     presets = {
-        '
-        'Ideal Case': (
-        '
+        'Default Case': (4, 10, 100, 110, 90, 5, 'V-Half (1/2)'),
+        'Ideal Case': (4, 10, 20, 20, 20, 0, 'V-Min (1/3)'),
+        'Real Case': (4, 10, 1049, 1122, 903, 79, 'V-Half (1/2)'),
+        'Zero Bubble Case': (4, 10, 1049, 1122, 903, 79, 'V-ZB (1)')
     }
     preset_buttons = {}
 
@@ -153,30 +154,30 @@ with gr.Blocks() as demo:
         with gr.Group():
             gr.Markdown("Basic Parameters")
             with gr.Row():
-                p=gr.Number(label="Number of stages (p)", value=
-                m=gr.Number(label="Number of microbatches (m)", value=
+                p=gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
+                m=gr.Number(label="Number of microbatches (m)", value=10, interactive=True, precision=0)
         with gr.Column(scale=2):
             with gr.Group():
                 gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
-                with gr.Row():
-                    f=gr.Number(label="Time of F", value=
-                    b=gr.Number(label="Time of B", value=
-                    w=gr.Number(label="Time of W", value=
-                    c=gr.Number(label="Time of one P2P communication", value=
+                with gr.Row():
+                    f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
+                    b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
+                    w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
+                    c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
         with gr.Group():
             gr.Markdown("Activation memory limit.")
             def update_mem(p, s, mem):
                 print("update")
                 if s == "custom":
                     return mem
-                if s == "V-Min":
+                if s == "V-Min (1/3)":
                     return (p + 4) // 3
-                if s == "V-Half":
+                if s == "V-Half (1/2)":
                     return (p + 2) // 2
-                if s == "V-ZB":
+                if s == "V-ZB (1)":
                     return p
                 assert False
-            memsel=gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
+            memsel=gr.Radio(choices=["V-Min (1/3)", "V-Half (1/2)", "V-ZB (1)", "custom"], value="V-Half (1/2)")
             mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
             memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
             p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
@@ -212,7 +213,7 @@ with gr.Blocks() as demo:
         with gr.Column(scale=4):
             schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
     with gr.Group():
-        gr.Markdown("
+        gr.Markdown("Zero bubble schedule with 2/3 1F1B memory")
         with gr.Row():
             with gr.Column(scale=1):
                 type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
@@ -221,7 +222,7 @@ with gr.Blocks() as demo:
         with gr.Column(scale=4):
             type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
     with gr.Group():
-        gr.Markdown("Interleaved 1F1B Schedule")
+        gr.Markdown("Variation of Interleaved 1F1B Schedule")
         with gr.Row():
             with gr.Column(scale=1):
                 interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
@@ -234,6 +235,7 @@ with gr.Blocks() as demo:
         schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
         type2_acceleration, type2_mem, type2_bubble, type2_image,
         interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
+    gr.Markdown(open("description3.md").read())
 
     for (k, v) in presets.items():
         def update_preset(pb, p, m, f, b, w, c, mem):
description1.md
CHANGED
@@ -1,5 +1,9 @@
 # Pipeline Parallellism with Controllable Memory
 
+Pipeline Parallelism with Controllable Memory provides a framework for designing pipeline schedules and uses it to find efficient, memory-optimal schedules.
+
+From our findings, we need approximately 1/3 memory under ideal conditions (F, B and W have the same runtime), and 1/2 memory to create a zero-bubble schedule in realistic scenarios (with the necessary condition being W + 2B ≥ 2F and W + 2F ≥ 2B).
+
 Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
 
 Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
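As a quick check, the "Real Case" preset in app.py (F = 1049, B = 1122, W = 903) satisfies the necessary condition stated above:

    # Zero-bubble necessary condition from description1.md, checked against the
    # "Real Case" preset costs in app.py.
    F, B, W = 1049, 1122, 903
    print(W + 2 * B >= 2 * F)  # 3147 >= 2098 -> True
    print(W + 2 * F >= 2 * B)  # 3001 >= 2244 -> True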
description2.md
CHANGED
@@ -1,6 +1,7 @@
 ## Alternative schedules
 
 By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate a few of them below:
+
 * 1F1B-V schedule without doing any B-W split.
 * Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
 * Variation of interleaved 1F1B with lower memory
description3.md
ADDED
File without changes
interleaved_variant.py
CHANGED
@@ -65,11 +65,6 @@ def get_interleaved_variation(_p, _n, cost):
         'B': _b+_w,
         'b': _b+_w
     }
-    pred = {
-        'f': 'F',
-        'B': 'f',
-        'b': 'B'
-    }
 
     time_map = {}
     def get_time(stage, type, minibatch):
@@ -78,16 +73,16 @@ def get_interleaved_variation(_p, _n, cost):
         time = 0
         if (stage, type, minibatch) in local_prev:
            time = get_time(*local_prev[(stage, type, minibatch)])
-        if stage > 0 and type in
+        if stage > 0 and type in "Ff":
            time = max(time, get_time(stage - 1, type, minibatch) + _c)
-        if stage == 0 and type
-            time = max(time, get_time(_p - 1,
-        if stage != _p - 1 and type in
+        if stage == 0 and type == 'f':
+            time = max(time, get_time(_p - 1, 'F', minibatch) + _c)
+        if stage != _p - 1 and type in "Bb":
            time = max(time, get_time(stage + 1, type, minibatch) + _c)
-        if stage == _p - 1 and type
-            time = max(time, get_time(0,
-        if stage == _p - 1 and type
-            time = max(time, get_time(stage,
+        if stage == _p - 1 and type == 'b':
+            time = max(time, get_time(0, 'B', minibatch) + _c)
+        if stage == _p - 1 and type == 'B':
+            time = max(time, get_time(stage, 'f', minibatch))
 
         time_map[(stage, type, minibatch)] = time + cost[type]
         return time_map[(stage, type, minibatch)]
@@ -97,7 +92,7 @@ def get_interleaved_variation(_p, _n, cost):
         for type, minibatch in stage:
             result_stage.append(ScheduledNode(
                 type.upper(),
-                type in
+                type in "fBW",
                 sid,
                 minibatch,
                 get_time(sid, type, minibatch) - cost[type],
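The dependencies restored in `get_time` above wrap around the pipeline: `f` on stage 0 additionally waits for `F` on the last stage, `b` on the last stage waits for `B` on stage 0 (each plus one communication `_c`), and `B` on the last stage waits for `f` on that same stage. A tiny illustration of these wrap-around lookups (hypothetical helper, not from the repo):

    # Hypothetical helper mirroring the wrap-around lookups in get_time above:
    # returns the extra (stage, type) a pass waits for in a p-stage pipeline,
    # or None when only the usual neighbour-stage dependency applies.
    def extra_dependency(stage, type, p):
        if stage == 0 and type == 'f':
            return (p - 1, 'F')      # f restarts at stage 0 after F leaves the last stage
        if stage == p - 1 and type == 'b':
            return (0, 'B')          # b starts at the last stage after B reaches stage 0
        if stage == p - 1 and type == 'B':
            return (p - 1, 'f')      # B turns around once f finishes on the last stage
        return None

    print(extra_dependency(0, 'f', 4))   # (3, 'F')
    print(extra_dependency(3, 'b', 4))   # (0, 'B')
    print(extra_dependency(3, 'B', 4))   # (3, 'f')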
schedule1f1bv.py
CHANGED
@@ -44,9 +44,9 @@ def transform_schedule(schedule, f, b, w, c):
             time = 0
             if (stage, type, mb) in local_prev:
                 time = get_time(stage, *local_prev[(stage, type, mb)])
-            if type in
+            if type in "FB" and stage > 0:
                 time = max(time, get_time(stage - 1, type, mb) + c)
-            if type in
+            if type in "fb" and stage + 1 < len(schedule):
                 time = max(time, get_time(stage + 1, type, mb) + c)
             time_map[(stage, type, mb)] = time + cost[type]
             return time_map[(stage, type, mb)]
@@ -59,7 +59,7 @@ def transform_schedule(schedule, f, b, w, c):
         for p, mb in stage:
             result_stage.append(ScheduledNode(
                 p.upper(),
-                p in
+                p in "fBW",
                 sid,
                 mb,
                 get_time(sid, p, mb) - cost[p],
@@ -104,8 +104,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
 def clear_invalid_index(repeated, m):
     p = len(repeated)
     index = pattern_size
-    for identifier in
-        if identifier in
+    for identifier in "FfBb":
+        if identifier in "FB":
             _iter = range(p)
         else:
             _iter = range(p - 1, -1, -1)
@@ -115,7 +115,7 @@ def clear_invalid_index(repeated, m):
             clear_invalid(repeated, i, index - pattern_size, offset=-1)
             clear_invalid(repeated, i, index + pattern_size * m, offset=1)
             index += 1
-            if identifier in
+            if identifier in "Bb":
                 w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                 for k in range(pattern_size):
                     if repeated[i][index + k] == w_identifier:
@@ -135,9 +135,9 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
     for sid in range(len(schedules)):
         cur = 0
         for i in range(len(schedules[sid])):
-            if schedules[sid][i] in
+            if schedules[sid][i] in "Ff":
                 cur += 1
-            if schedules[sid][i] in
+            if schedules[sid][i] in "Ww":
                 cur -= 1
             mem[sid][i] = cur
             peak_mem = max(peak_mem, cur)
@@ -177,16 +177,16 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
                 pos += 1
             while schedules[sid][pos] != ' ' and pos < i:
                 pos += 1
-            if schedules[sid][i] in
+            if schedules[sid][i] in "Bb":
                 while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
                     pos += 1
             if pos == i:
                 loc[sid][cnt][schedules[sid][i]] = i
                 continue
-            if schedules[sid][i] in
+            if schedules[sid][i] in "BbWw":
                 schedules[sid][pos] = schedules[sid][i]
                 schedules[sid][i] = ' '
-            if schedules[sid][pos] in
+            if schedules[sid][pos] in "Ww":
                 for j in range(pos, i):
                     mem[sid][j] -= 1
             loc[sid][cnt][schedules[sid][pos]] = pos
@@ -265,7 +265,7 @@ def schedule(p, m, cost):
     s = schedule_by_pattern(p, m, whole_pattern)
     for sid in range(len(s)):
         for i in range(len(s[sid])):
-            if s[sid][i] in
+            if s[sid][i] in "Ww":
                 s[sid][i] = ' '
     res = transform_schedule(s, *cost)
     return res
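The counting rule visible in `process_warmup_without_increasing_peak_mem` treats every F or f as allocating one unit of activation memory and every W or w as releasing it, so peak memory is the maximum number of pending F passes on a stage. A minimal sketch of that bookkeeping over a single stage's schedule string (illustrative only, not the repo's function):

    # Illustrative bookkeeping for one stage's schedule string:
    # F/f allocate one unit of activation memory, W/w release it.
    def peak_pending_f(stage_schedule):
        cur = peak = 0
        for ch in stage_schedule:
            if ch in "Ff":
                cur += 1
            if ch in "Ww":
                cur -= 1
            peak = max(peak, cur)
        return peak

    print(peak_pending_f("FfBWbw"))   # 2: both F and f are in flight before the first W
    print(peak_pending_f("FBW fbw"))  # 1: each W/w releases before the next forward pass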