# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
# TODO: 16bit depth map download
# TODO: change to gradio-dualvision (update it with the Examples thumbs first)
import os

import PIL
import pandas
import requests
import spaces  # Hugging Face Spaces ZeroGPU support
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import maximum_filter
from huggingface_hub import login

from marigold_dc import MarigoldDepthCompletionPipeline

# Skip the heavy diffusion computation when no GPU-backed accelerator is configured.
DRY_RUN = os.environ.get("ACCELERATOR", "cpu") not in ("zero", "gpu")

DEFAULT_DENOISE_STEPS = 10
DEFAULT_LR_LATENT = 0.05
DEFAULT_LR_SCALE_SHIFT = 0.005
TILE_CHAR = "██"
TAB10_COLORS = [
    (31, 119, 180),   # blue
    (255, 127, 14),   # orange
    (44, 160, 44),    # green
    (214, 39, 40),    # red
    (148, 103, 189),  # purple
    (140, 86, 75),    # brown
    (227, 119, 194),  # pink
    (127, 127, 127),  # gray
    (188, 189, 34),   # olive
    (23, 190, 207),   # cyan
]
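
# Scale an RGB tuple by `factor`, clamping each channel to [0, 255].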
def adjust_brightness(color, factor):
    return tuple(max(0, min(255, int(c * factor))) for c in color)
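
# Choose a color for the index-th annotation: cycle through the tab10 palette
# and, on each pass, apply an alternating darken/brighten factor so repeated
# palette entries remain distinguishable.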
def get_wrapped_color(index):
    base_index = index % len(TAB10_COLORS)
    wrap_count = index // len(TAB10_COLORS)
    base_color = TAB10_COLORS[base_index]
    # Alternates darkening (0.85, 0.70, ...) and brightening (1.15, 1.30, ...) per pass.
    factor = 1.0 + 0.15 * ((wrap_count % 2) * 2 - 1) * (wrap_count // 2 + 1)
    return adjust_brightness(base_color, factor)
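
# Register one depth annotation: draw a colored marker at the clicked (x, y)
# on the image and append a [tile, value, x, y, hex color] row to the table.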
def process_click_data(img: Image.Image, state_orig_img: gr.State, table, x: int, y: int, value: str = ""):
    if isinstance(img, str):
        img = Image.open(img)
    if state_orig_img is None:
        state_orig_img = img.copy()
    if isinstance(table, pandas.DataFrame):
        table = table.values.tolist()
    color = get_wrapped_color(len(table))
    color_hex = "#%02x%02x%02x" % color
    img = img.convert("RGB")
    draw = ImageDraw.Draw(img)
    width, _ = img.size
    r = int(width * 0.015)
    draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
    draw.ellipse((x - r, y - r, x + r, y + r), fill=None, outline=(255, 255, 255), width=max(1, r // 4))
    table = table + [[TILE_CHAR, value, x, y, color_hex]]
    return img, state_orig_img, table
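
# Gradio select handler: convert a click on the input image into a new
# annotation row and reveal the depth table.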
def on_click(img: Image.Image, state_orig_img: gr.State, evt: gr.SelectData, table):
    x, y = evt.index
    img, state_orig_img, table = process_click_data(img, state_orig_img, table, x, y)
    return img, state_orig_img, gr.Dataframe(table, visible=True)
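
# Dilate each RGB channel with a maximum filter so that isolated sparse depth
# pixels become visible blobs in the visualization.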
def dilate_rgb_image(image, kernel_size):
    r_channel, g_channel, b_channel = image[..., 0], image[..., 1], image[..., 2]
    r_dilated = maximum_filter(r_channel, size=kernel_size)
    g_dilated = maximum_filter(g_channel, size=kernel_size)
    b_dilated = maximum_filter(b_channel, size=kernel_size)
    return np.stack([r_dilated, g_dilated, b_dilated], axis=-1)
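
# Build a log-scale RMSE-vs-steps line plot; tick spacing adapts to the number
# of denoising steps.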
def generate_rmse_plot(steps, metrics, denoise_steps):
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=steps,
            y=metrics,
            mode="lines+markers",
            line=dict(color="#af2928"),
            name="RMSE",
        )
    )
    x_dtick = 1 if denoise_steps < 20 else 5
    fig.update_layout(
        autosize=True,
        height=300,
        margin=dict(l=20, r=20, t=20, b=20),
        xaxis_title="Steps",
        xaxis_range=[0, denoise_steps + 1],
        xaxis=dict(
            scaleanchor="y",
            scaleratio=1.5,
            dtick=x_dtick,
        ),
        yaxis_title="RMSE",
        yaxis=dict(type="log"),
        hovermode="x unified",
        template="plotly_white",
    )
    return fig
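
# Main inference generator: build a sparse depth map either from an uploaded
# numpy file (LiDAR) or from the clicked annotations, run the guided diffusion
# pipeline, and yield an updated (image slider, RMSE plot) pair after every
# denoising step. A rough usage sketch, with the Matterhorn example clicks:
#   clicks = [[TILE_CHAR, "3", 106, 276, "#%02x%02x%02x" % get_wrapped_color(0)],
#             [TILE_CHAR, "2", 527, 600, "#%02x%02x%02x" % get_wrapped_color(1)]]
#   for slider, plot in process(image, None, clicks, None):
#       ...  # stream intermediate results to the UI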
def process(
    image,
    state_orig_img,
    table,
    path_sparse,
    denoise_steps=DEFAULT_DENOISE_STEPS,
    lr_latent=DEFAULT_LR_LATENT,
    lr_scale_shift=DEFAULT_LR_SCALE_SHIFT,
    override_shift=None,
    override_scale=None,
):
    # NaN in an override field means "not set"; normalize to None or float.
    if override_shift is not None:
        override_shift = None if np.isnan(override_shift) else float(override_shift)
    if override_scale is not None:
        override_scale = None if np.isnan(override_scale) else float(override_scale)
    # Prefer the pristine image kept in the state over the annotated copy.
    if isinstance(state_orig_img, str):
        image = Image.open(state_orig_img)
    elif isinstance(state_orig_img, PIL.Image.Image):
        image = state_orig_img
    elif isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, PIL.Image.Image):
        pass
    else:
        raise TypeError(f"Unknown image type: {type(image)}")
    if isinstance(table, pandas.DataFrame):
        table = table.values.tolist()
    if path_sparse is not None and os.path.exists(path_sparse):
        # Numpy file given (LiDAR); zeros mark missing measurements.
        sparse_depth = np.load(path_sparse)
        sparse_depth_valid = sparse_depth[sparse_depth > 0]
        sparse_depth_min = np.min(sparse_depth_valid)
        sparse_depth_max = np.max(sparse_depth_valid)
        kernel_size = 5
    elif table is not None and len(table) >= 2:
        # Click annotations: rasterize the table entries into a sparse map.
        sparse_depth = np.full((image.height, image.width), np.nan, dtype=np.float32)
        for entry in table:
            try:
                sparse_depth[entry[3], entry[2]] = float(entry[1])
            except Exception:
                pass  # skip rows with an empty or non-numeric depth value
        sparse_depth_valid_mask = ~np.isnan(sparse_depth)
        sparse_depth_valid = sparse_depth[sparse_depth_valid_mask]
        sparse_depth_valid_num = np.sum(sparse_depth_valid_mask)
        if sparse_depth_valid_num >= 2:
            sparse_depth_min = np.min(sparse_depth_valid)
            sparse_depth_max = np.max(sparse_depth_valid)
            sparse_depth[~sparse_depth_valid_mask] = 0
            kernel_size = 10
        else:
            sparse_depth = None
            sparse_depth_min = 0
            sparse_depth_max = 1
            kernel_size = 5
    else:
        sparse_depth = None
        sparse_depth_min = 0
        sparse_depth_max = 1
        kernel_size = 5
    width, height = image.size
    max_dim = max(width, height)
    # Cap the processing resolution at 768 px on the longer side; 0 keeps the native size.
    processing_resolution = 768 if max_dim > 768 else 0
    metrics = []
    steps = []
    for step, (pred, rmse) in enumerate(
        pipe(
            image=image,
            sparse_depth=sparse_depth,
            num_inference_steps=denoise_steps + 1,
            processing_resolution=processing_resolution,
            lr_latent=lr_latent,
            lr_scale_shift=lr_scale_shift,
            override_shift=override_shift,
            override_scale=override_scale,
            dry_run=DRY_RUN,
        )
    ):
        min_both = pred.min().item()
        max_both = pred.max().item()
        if sparse_depth is not None:
            # Use one shared value range covering both the sparse input and the prediction.
            min_both = min(sparse_depth_min, min_both)
            max_both = max(sparse_depth_max, max_both)
        metrics.append(rmse)
        steps.append(step)
        vis_pred = pipe.image_processor.visualize_depth(pred, val_min=min_both, val_max=max_both)[0]
        if sparse_depth is not None:
            vis_sparse = pipe.image_processor.visualize_depth(sparse_depth, val_min=min_both, val_max=max_both)[0]
            vis_sparse = np.array(vis_sparse)
            vis_sparse[sparse_depth <= 0] = (0, 0, 0)
            vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=kernel_size)
        else:
            vis_sparse = np.full_like(vis_pred, 0)
        vis_sparse = Image.fromarray(vis_sparse)
        plot = generate_rmse_plot(steps, metrics, denoise_steps)
        plot = gr.Plot(plot, visible=True)
        slider = gr.ImageSlider([vis_sparse, vis_pred], visible=True)
        yield slider, plot
| os.system("pip freeze") | |
| print("Environment:\n" + "\n".join(f"{k}: {os.environ[k]}" for k in sorted(os.environ.keys()))) | |
| if "HF_TOKEN_LOGIN" in os.environ: | |
| login(token=os.environ["HF_TOKEN_LOGIN"]) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| pipe = MarigoldDepthCompletionPipeline.from_pretrained( | |
| "prs-eth/marigold-depth-v1-1", | |
| torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, | |
| ) | |
| try: | |
| import xformers | |
| pipe.enable_xformers_memory_efficient_attention() | |
| except: | |
| print("Running without xformers") | |
| pipe = pipe.to(device) | |
| os.environ["GRADIO_ALLOW_FLAGGING"] = "never" | |
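
# Gradio UI: a themed Blocks layout. The custom CSS styles the image slider
# handle and the plot viewport; the injected <script> watches the DOM and
# colors each table row's tile glyph to match its marker color on the image.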
with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue=gr.themes.colors.red,
        spacing_size=gr.themes.sizes.spacing_sm,
        radius_size="none",
        text_size="md",
    ).set(
        button_secondary_background_fill="black",
        button_secondary_text_color="white",
        body_background_fill="linear-gradient(to right, #FFE0D0, #E0F0FF)",
    ),
    analytics_enabled=False,
    title="Marigold Depth Completion",
    css="""
        .slider .inner {
            width: 4px;
            background: #FFF;
        }
        .slider .icon-wrap {
            fill: #FFF;
            background-color: #FFF;
            stroke: #FFF;
            stroke-width: 3px;
        }
        #viewport {
            aspect-ratio: 4/3;
        }
        h1, h2, h3 {
            text-align: center;
            display: block;
        }
    """,
    head="""
        <script>
            function applyColorToTiles() {
                const rows = document.querySelectorAll("table tbody tr");
                if (rows.length === 0) return;
                rows.forEach(row => {
                    const tileCell = row.children[0];
                    const colorCell = row.children[4];
                    const span = tileCell?.querySelector('span.svelte-1y3tas2.text');
                    if (span && colorCell?.innerText) {
                        span.style.color = colorCell.innerText.trim();
                    }
                });
            }
            let observer = new MutationObserver((mutationsList) => {
                applyColorToTiles();
            });
            observer.observe(document.body, { childList: true, subtree: true });
        </script>
    """,
) as demo:
    gr.HTML(
        """
        <h1>⇆ Marigold-DC: Zero-Shot Monocular Depth Completion with Guided Diffusion</h1>
        <p align="center">
            <a title="Website" href="https://MarigoldDepthCompletion.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/%F0%9F%A4%8D%20Project%20-Website-blue" alt="Website Badge">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2412.13389" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-af2928" alt="arXiv Badge">
            </a>
            <a title="Github" href="https://github.com/prs-eth/marigold-dc" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/prs-eth/marigold-dc?label=GitHub&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a><br>
            Upload any image, annotate with a few clicks, and compute dense metric depth!<br>
            Alternatively, explore advanced LiDAR functionality and examples at the bottom.
        </p>
        """
    )

    state_orig_img = gr.State()

    with gr.Row():
        with gr.Column():
            thumb = gr.Image(
                label="Thumb Image",
                type="filepath",
                visible=False,
            )
            input_image = gr.Image(
                label="Input image (click to enter depth)",
                type="filepath",
                interactive=True,
            )
            table = gr.Dataframe(
                headers=["Color", "Enter depth estimates (any unit)", "x", "y", "_color"],
                datatype=["str", "number", "number", "number", "str"],
                column_widths=["30px", "120px", "0px", "0px", "0px"],
                static_columns=[0, 2, 3, 4],
                show_fullscreen_button=False,
                show_copy_button=False,
                show_row_numbers=False,
                show_search="none",
                row_count=0,
                interactive=True,
                visible=False,
            )
            with gr.Accordion("Advanced options", open=False):
                with gr.Row():
                    with gr.Column():
                        denoise_steps = gr.Slider(
                            label="Number of denoising steps",
                            minimum=4,
                            maximum=50,
                            step=1,
                            value=15,
                        )
                        lr_latent = gr.Number(
                            DEFAULT_LR_LATENT,
                            interactive=True,
                            label="Latent LR",
                            step=0.001,
                        )
                        with gr.Row():
                            lr_scale_shift = gr.Number(
                                DEFAULT_LR_SCALE_SHIFT,
                                interactive=True,
                                label="Scale-and-shift LR",
                                step=0.001,
                                min_width=90,
                            )
                            override_shift = gr.Number(
                                value=float("NaN"),
                                label="Shift override",
                                min_width=90,
                            )
                            override_scale = gr.Number(
                                value=float("NaN"),
                                label="Scale override",
                                min_width=90,
                            )
                    with gr.Column():
                        input_sparse = gr.File(
                            label="Input sparse depth (numpy file)",
                        )
            with gr.Row():
                submit_btn = gr.Button(value="Compute Depth", variant="primary")
                clear_btn = gr.Button(value="Clear")
        with gr.Column():
            output_slider = gr.ImageSlider(
                label="Completed depth (red-near, blue-far)",
                type="filepath",
                show_download_button=True,
                interactive=False,
                elem_classes="slider",
                slider_position=25,
            )
            plot = gr.Plot(
                label="RMSE between sparse measurements and densified depth",
                elem_id="viewport",
            )

    input_image.select(
        on_click,
        inputs=[
            input_image,
            state_orig_img,
            table,
        ],
        outputs=[
            input_image,
            state_orig_img,
            table,
        ],
    )
    input_image.upload(
        lambda: gr.update(label="Click and provide depth estimates in the table below"),
        outputs=input_image,
    )
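
    # Thin wrapper so the button click streams every intermediate
    # (slider, plot) update yielded by process().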
    def submit_depth_fn(
        image,
        state_orig_img,
        table,
        path_sparse,
        denoise_steps,
        lr_latent,
        lr_scale_shift,
        override_shift,
        override_scale,
    ):
        yield from process(
            image,
            state_orig_img,
            table,
            path_sparse,
            denoise_steps,
            lr_latent,
            lr_scale_shift,
            override_shift,
            override_scale,
        )
    submit_btn.click(
        fn=submit_depth_fn,
        inputs=[
            input_image,
            state_orig_img,
            table,
            input_sparse,
            denoise_steps,
            lr_latent,
            lr_scale_shift,
            override_shift,
            override_scale,
        ],
        outputs=[
            output_slider,
            plot,
        ],
    )
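
    # Examples handler: map a clicked thumbnail to its full-size image, optional
    # sparse depth (.npy), preset click annotations, and step count; download the
    # assets from the Space repo, replay the clicks, and stream the results just
    # like a manual submission.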
    def examples_depth_lidar_fn(path_thumb):
        def real_url(fname):
            return f"https://huggingface.co/spaces/obukhovai/marigold-dc-metric/resolve/main/files/{fname}"

        l_thumb = os.path.basename(path_thumb)
        d_thumb = os.path.dirname(path_thumb)
        l_image, l_sparse, clicks, nsteps = {
            "thumb_matterhorn_clicks.jpg": ["matterhorn.jpg", None, [
                [TILE_CHAR, "3", 106, 276, "#%02x%02x%02x" % get_wrapped_color(0)],
                [TILE_CHAR, "2", 527, 600, "#%02x%02x%02x" % get_wrapped_color(1)],
            ], 15],
            "thumb_kitti_1.jpg": ["kitti_1.png", "kitti_1.npy", [], 25],
            "thumb_kitti_2.jpg": ["kitti_2.png", "kitti_2.npy", [], 25],
            "thumb_teaser_10.jpg": ["teaser.png", "teaser_10.npy", [], 25],
            "thumb_teaser_100.jpg": ["teaser.png", "teaser_100.npy", [], 25],
            "thumb_teaser_1000.jpg": ["teaser.png", "teaser_1000.npy", [], 25],
        }[l_thumb]
        u_image = real_url(l_image)
        l_down_image = os.path.join(d_thumb, l_image)
        response = requests.get(u_image)
        response.raise_for_status()
        with open(l_down_image, "wb") as f:
            f.write(response.content)
        table_visible = len(clicks) > 0
        l_down_sparse = None
        if l_sparse is not None:
            u_sparse = real_url(l_sparse)
            l_down_sparse = os.path.join(d_thumb, l_sparse)
            response = requests.get(u_sparse)
            response.raise_for_status()
            with open(l_down_sparse, "wb") as f:
                f.write(response.content)
        state_orig_img = None
        table = []
        for click in clicks:
            _, value, x, y, _ = click
            l_down_image, state_orig_img, table = process_click_data(l_down_image, state_orig_img, table, x, y, value)
        for outputs in process(l_down_image, state_orig_img, clicks, l_down_sparse, denoise_steps=nsteps):
            yield l_down_image, l_down_sparse, state_orig_img, gr.Dataframe(table, visible=table_visible), *outputs
    examples = gr.Examples(
        fn=examples_depth_lidar_fn,
        examples=[
            "files/thumb_matterhorn_clicks.jpg",
            "files/thumb_kitti_1.jpg",
            "files/thumb_kitti_2.jpg",
            "files/thumb_teaser_10.jpg",
            "files/thumb_teaser_100.jpg",
            "files/thumb_teaser_1000.jpg",
        ],
        inputs=[
            thumb,
        ],
        outputs=[
            input_image,
            input_sparse,
            state_orig_img,
            table,
            output_slider,
            plot,
        ],
        cache_mode="lazy",
        cache_examples=False,
        run_on_click=True,
    )
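
    # Reset handler: the returned values map one-to-one onto the outputs of
    # clear_btn.click below (image, file, slider, plot, table, state, button).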
    def clear_fn():
        return [
            gr.update(value=None, interactive=True, label="Input image"),
            gr.File(None, interactive=True),
            None,
            None,
            gr.Dataframe([[]], visible=False),
            None,
            gr.update(interactive=True),
        ]
    clear_btn.click(
        fn=clear_fn,
        inputs=[],
        outputs=[
            input_image,
            input_sparse,
            output_slider,
            plot,
            table,
            state_orig_img,
            submit_btn,
        ],
    )
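
# Queue requests (needed for the streaming generator handlers) and serve on all
# interfaces at the standard Hugging Face Spaces port.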
demo.queue(
    api_open=False,
).launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
)