| import gradio as gr | |
| import hand_schedule | |
| import adaptive_schedule | |
| import interleaved_variant | |
| import type2 | |
| import schedule1f1bv | |
| from PIL import Image | |
| from svg_event import render_manual_graph | |
| import pathlib | |
def percentage(x):
    """Format the fraction ``x`` as a percentage string with two decimals.

    Example: ``0.5`` -> ``"50.00%"``.
    """
    scaled = x * 100
    return format(scaled, ".2f") + "%"
def get_schedule_time(result):
    """Return the longest per-stage span of F/B/W work in a schedule.

    ``result`` is an iterable of stages; each stage is an iterable of nodes
    exposing ``type``, ``start_time`` and ``completion_time``. Nodes whose
    ``type`` is not 'F', 'B' or 'W' are ignored. A stage's span is its latest
    completion time minus its earliest start time; the maximum span over all
    stages is returned. Raises ValueError if there are no stages or a stage
    has no F/B/W node.
    """
    compute_types = {'F', 'B', 'W'}
    spans = []
    for stage in result:
        nodes = [node for node in stage if node.type in compute_types]
        latest_end = max(node.completion_time for node in nodes)
        earliest_start = min(node.start_time for node in nodes)
        spans.append(latest_end - earliest_start)
    return max(spans)
def get_memory_usage(result):
    """Return the peak count of outstanding forward passes on any stage.

    Each stage's nodes are walked in order with a counter starting at 0:
    'F'/'f' adds one and 'W'/'w' subtracts one. If no node of type 'W'/'w'
    exists anywhere in ``result``, 'B'/'b' subtracts one instead. The
    returned value is the largest counter value seen on any stage, and is
    never below 0.
    """
    # Fully evaluated list (not a short-circuiting generator): every node of
    # every stage is visited before the per-stage pass below.
    releases_on_w = any([node.type in ('W', 'w') for stage in result for node in stage])
    peak = 0
    for stage in result:
        live = 0
        for node in stage:
            kind = node.type
            if kind in ('F', 'f'):
                live += 1
            if kind in ('W', 'w'):
                live -= 1
            if not releases_on_w and kind in ('B', 'b'):
                live -= 1
            peak = max(peak, live)
    return peak
# Paths of recently rendered SVG files, oldest first. get_schedule_image keeps
# at most 32 of them and deletes the oldest file once that limit is exceeded.
img_queue = []


def get_schedule_image(result, max_time):
    """Render a schedule to an SVG file and return its path.

    Args:
        result: per-stage node lists; only nodes of type 'F', 'B' or 'W' are
            passed to the renderer.
        max_time: forwarded to ``render_manual_graph`` (presumably the extent
            of the time axis; callers pass the longest schedule time).

    Returns:
        ``pathlib.Path`` of the SVG produced by ``render_manual_graph``.
    """
    result = [
        list(filter(lambda x: x.type in {'F', 'B', 'W'}, r)) for r in result
    ]
    # Third argument is a flag derived from the first stage's node count
    # (<= 72); its meaning is defined in svg_event — NOTE(review): confirm.
    svg = render_manual_graph(result, max_time, len(result[0]) <= 72)
    img_queue.append(svg)
    if len(img_queue) > 32:
        stale = img_queue.pop(0)
        # The file may already have been removed by something else; a missing
        # file must not make the current request fail.
        pathlib.Path(stale).unlink(missing_ok=True)
    return pathlib.Path(svg)
def _schedule_stats(result, m, f, b, w, chunks=1):
    """Summarise one schedule as ``(time, memory, bubble)``.

    Returns ``(None, None, None)`` when the scheduler produced no schedule
    (``result is None``), so the UI shows empty fields instead of crashing.

    ``time`` is ``get_schedule_time(result)``. ``bubble`` is that time divided
    by the ideal ``m * (f + b + w)``, minus one, formatted by ``percentage``.
    ``memory`` is ``get_memory_usage(result)``, divided by ``chunks`` for
    schedules built from costs split across ``chunks`` virtual stages.
    """
    if result is None:
        return None, None, None
    total_time = get_schedule_time(result)
    memory = get_memory_usage(result)
    if chunks != 1:
        # Only divide for chunked schedules so unchunked ones keep an int.
        memory = memory / chunks
    bubble = percentage(total_time / (f + b + w) / m - 1)
    return total_time, memory, bubble


def _acceleration(baseline_time, time):
    """Return the speed-up of ``time`` over ``baseline_time`` as a percentage
    string, or ``None`` if either time is missing."""
    if baseline_time is None or time is None:
        return None
    return percentage(baseline_time / time - 1)


def calculate(p, m, f, b, w, c, mem):
    """Run every scheduler for one configuration and build the UI outputs.

    Args:
        p: number of pipeline stages.
        m: number of microbatches.
        f, b, w: time of the F, B and W passes.
        c: time of one P2P communication.
        mem: memory limit (pending F per stage) for the adaptive scheduler.

    Returns:
        A flat list of 20 values: (acceleration, memory, bubble, image) for
        the 1F1B baseline, adaptive, 1F1B-V, type2 and interleaved-variant
        schedules, in that order. A schedule whose scheduler returned ``None``
        contributes four ``None`` values (the baseline keeps "0.00%").
    """
    # 1F1B baseline: B and W are merged into one B pass; only F/B nodes kept.
    baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
    if baseline_result is not None:
        baseline_result = [
            list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
        ]
    baseline_time, baseline_mem, baseline_bubble = _schedule_stats(baseline_result, m, f, b, w)
    # The baseline is the reference, so its acceleration is 0 by definition.
    baseline_acceleration = percentage(0)

    # Chunked schedules get halved F/B/W costs and report halved memory.
    # Each scheduler receives its own fresh cost list.
    adapt_result = adaptive_schedule.schedule(
        p,
        m,
        [f / 2, b / 2, w / 2, c],
        max_mem=mem * 2,
    )
    adapt_time, adapt_mem, adapt_bubble = _schedule_stats(adapt_result, m, f, b, w, chunks=2)
    adapt_acceleration = _acceleration(baseline_time, adapt_time)

    schedule1f1bv_result = schedule1f1bv.schedule(p, m, [f / 2, b / 2, w / 2, c])
    schedule1f1bv_time, schedule1f1bv_mem, schedule1f1bv_bubble = _schedule_stats(
        schedule1f1bv_result, m, f, b, w, chunks=2
    )
    schedule1f1bv_acceleration = _acceleration(baseline_time, schedule1f1bv_time)

    type2_result = type2.schedule(p, m, [f, b, w, c])
    type2_time, type2_mem, type2_bubble = _schedule_stats(type2_result, m, f, b, w)
    type2_acceleration = _acceleration(baseline_time, type2_time)

    interleaved_result = interleaved_variant.get_interleaved_variation(
        p, m, [f / 2, b / 2, w / 2, c]
    )
    interleaved_time, interleaved_mem, interleaved_bubble = _schedule_stats(
        interleaved_result, m, f, b, w, chunks=2
    )
    interleaved_acceleration = _acceleration(baseline_time, interleaved_time)

    # Every image is rendered with the same max_time: the longest schedule.
    times = [baseline_time, adapt_time, interleaved_time, type2_time, schedule1f1bv_time]
    max_time = max((t for t in times if t is not None), default=None)
    print(max_time)

    def _render(result):
        return None if result is None else get_schedule_image(result, max_time)

    baseline_image = _render(baseline_result)
    adapt_image = _render(adapt_result)
    interleaved_image = _render(interleaved_result)
    type2_image = _render(type2_result)
    schedule1f1bv_image = _render(schedule1f1bv_result)

    return [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
            adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
            schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
            type2_acceleration, type2_mem, type2_bubble, type2_image,
            interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
def _read_markdown(path):
    """Return the text of the Markdown file at ``path``, decoded as UTF-8.

    ``path`` is resolved against the current working directory. The file
    handle is closed immediately, and decoding does not depend on the locale.
    """
    return pathlib.Path(path).read_text(encoding="utf-8")


# Preset configurations: (p, m, F, B, W, P2P comm, memory-limit option).
presets = {
    'Default Case': (4, 10, 100, 110, 90, 5, 'V-Half (1/2)'),
    'Ideal Case': (4, 10, 20, 20, 20, 0, 'V-Min (1/3)'),
    'Real Case': (4, 10, 1049, 1122, 903, 79, 'V-Half (1/2)'),
    'Zero Bubble Case': (4, 10, 1049, 1122, 903, 79, 'V-ZB (1)'),
}

# Matches the formula used by calculate(): time/(F+B+W)/m - 1.
_BUBBLE_FORMULA_LABEL = "Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1)."


def update_mem(p, s, mem):
    """Return the memory limit for memory option ``s`` with ``p`` stages.

    ``"custom"`` returns ``mem`` unchanged. Raises ValueError for an
    unknown option (rather than ``assert``, which ``python -O`` strips).
    """
    print("update")
    if s == "custom":
        return mem
    if s == "V-Min (1/3)":
        return (p + 4) // 3
    if s == "V-Half (1/2)":
        return (p + 2) // 2
    if s == "V-ZB (1)":
        return p
    raise ValueError(f"Unknown memory limit option: {s!r}")


def update_preset(pb, p, m, f, b, w, c, mem):
    """Apply preset ``pb`` (the clicked button's label) and recompute.

    Returns the preset's seven values (p, m, F, B, W, comm, memory option),
    the memory limit derived from them, then the 20 ``calculate`` outputs.
    The current UI values passed in are not used.
    """
    print(pb)
    print(presets[pb])
    print(presets[pb][-1])
    preset = presets[pb]
    # Derive the limit from the preset's own stage count, not the stage count
    # currently shown in the UI, and send it to the mem field so the displayed
    # limit matches the one used.
    preset_mem = update_mem(preset[0], preset[-1], -1)
    return *preset, preset_mem, *calculate(*preset[:-1], preset_mem)


def _result_panel(title, bubble_label="Bubble Rate"):
    """Build one result panel in the current Blocks context.

    Returns its (acceleration, memory, bubble, image) components.
    """
    with gr.Group():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column(scale=1):
                acceleration = gr.Textbox("", label="Acceleration compared to 1F1B")
                memory = gr.Textbox("", label="Maximum memory usage")
                bubble = gr.Textbox("", label=bubble_label)
            with gr.Column(scale=4):
                image = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
    return acceleration, memory, bubble, image


with gr.Blocks() as demo:
    gr.Markdown(_read_markdown("description1.md"))
    gr.Markdown("# Pipeline Scheduler Playground")
    preset_buttons = {}
    with gr.Group():
        gr.Markdown("Preset Setups")
        with gr.Row():
            for preset_name in presets:
                preset_buttons[preset_name] = gr.Button(preset_name, variant="secondary")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=10, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=100, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=110, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=90, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
    with gr.Group():
        gr.Markdown("Activation memory limit.")
        memsel = gr.Radio(choices=["V-Min (1/3)", "V-Half (1/2)", "V-ZB (1)", "custom"], value="V-Half (1/2)")
        mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
        memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
        p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
    button = gr.Button("Calculate", variant="primary")

    baseline_acceleration, baseline_mem, baseline_bubble, baseline_image = _result_panel("1F1B")
    adapt_acceleration, adapt_mem, adapt_bubble, adapt_image = _result_panel("Adaptive Scheduler")
    gr.Markdown(_read_markdown("description2.md"))
    schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image = _result_panel(
        "1F1B-V Schedule"
    )
    type2_acceleration, type2_mem, type2_bubble, type2_image = _result_panel(
        "Zero bubble schedule with 2/3 1F1B memory", bubble_label=_BUBBLE_FORMULA_LABEL
    )
    interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image = _result_panel(
        "Variation of Interleaved 1F1B Schedule", bubble_label=_BUBBLE_FORMULA_LABEL
    )

    # Same order as the list returned by calculate().
    result_outputs = [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
                      adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
                      schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
                      type2_acceleration, type2_mem, type2_bubble, type2_image,
                      interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=result_outputs)
    gr.Markdown(_read_markdown("description3.md"))
    for preset_name, preset_button in preset_buttons.items():
        preset_button.click(
            update_preset,
            inputs=[preset_button, p, m, f, b, w, c, mem],
            outputs=[p, m, f, b, w, c, memsel, mem, *result_outputs],
        )
demo.launch()