# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
# TODO: 16bit depth map download
# TODO: change to gradio-dualvision (update it with the Examples thumbs first)
import os
import PIL
import pandas
import requests
import spaces
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from PIL import Image, ImageDraw
from scipy.ndimage import maximum_filter
from huggingface_hub import login
from marigold_dc import MarigoldDepthCompletionPipeline
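# Without a ZeroGPU/GPU accelerator, run the pipeline in dry-run mode (presumably
# skipping the expensive diffusion sampling inside MarigoldDepthCompletionPipeline).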
DRY_RUN = os.environ.get("ACCELERATOR", "cpu") not in ("zero", "gpu")
DEFAULT_denoise_steps = 10
DEFAULT_lr_latent = 0.05
DEFAULT_lr_scale_shift = 0.005
TILE_CHAR = "██"
TAB10_COLORS = [
(31, 119, 180), # blue
(255, 127, 14), # orange
(44, 160, 44), # green
(214, 39, 40), # red
(148, 103, 189), # purple
(140, 86, 75), # brown
(227, 119, 194), # pink
(127, 127, 127), # gray
(188, 189, 34), # olive
    (23, 190, 207),  # cyan
]
def adjust_brightness(color, factor):
return tuple(
max(0, min(255, int(c * factor)))
for c in color
)
def get_wrapped_color(index):
base_index = index % len(TAB10_COLORS)
wrap_count = index // len(TAB10_COLORS)
base_color = TAB10_COLORS[base_index]
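    # Darken (even wrap counts) or lighten (odd) the base color, increasing the
    # strength every second wrap; note the first cycle is already darkened by 15%.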
factor = 1.0 + 0.15 * ((wrap_count % 2) * 2 - 1) * (wrap_count // 2 + 1)
return adjust_brightness(base_color, factor)
def process_click_data(img, state_orig_img, table, x: int, y: int, value: str = ""):
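    # Draws a colored click marker at (x, y), keeps a pristine copy of the image
    # in state, and appends [tile, depth value, x, y, hex color] to the table.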
if isinstance(img, str):
img = Image.open(img)
if state_orig_img is None:
state_orig_img = img.copy()
if isinstance(table, pandas.DataFrame):
table = table.values.tolist()
color = get_wrapped_color(len(table))
color_hex = '#%02x%02x%02x' % color
img = img.convert("RGB")
draw = ImageDraw.Draw(img)
width, _ = img.size
r = int(width * 0.015)
draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)
draw.ellipse((x - r, y - r, x + r, y + r), fill=None, outline=(255, 255, 255), width=max(1, r//4))
    table = table + [[TILE_CHAR, value, x, y, color_hex]]
return img, state_orig_img, table
def on_click(img, state_orig_img, evt: gr.SelectData, table):
x, y = evt.index
img, state_orig_img, table = process_click_data(img, state_orig_img, table, x, y)
return img, state_orig_img, gr.Dataframe(table, visible=True)
def dilate_rgb_image(image, kernel_size):
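    # Per-channel grayscale dilation enlarges the sparse depth dots so they remain
    # visible next to the dense prediction.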
r_channel, g_channel, b_channel = image[..., 0], image[..., 1], image[..., 2]
r_dilated = maximum_filter(r_channel, size=kernel_size)
g_dilated = maximum_filter(g_channel, size=kernel_size)
b_dilated = maximum_filter(b_channel, size=kernel_size)
dilated_image = np.stack([r_dilated, g_dilated, b_dilated], axis=-1)
return dilated_image
def generate_rmse_plot(steps, metrics, denoise_steps):
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=steps,
y=metrics,
mode="lines+markers",
line=dict(color="#af2928"),
name="RMSE",
)
)
if denoise_steps < 20:
x_dtick = 1
else:
x_dtick = 5
fig.update_layout(
autosize=True,
height=300,
margin=dict(l=20, r=20, t=20, b=20),
xaxis_title="Steps",
xaxis_range=[0, denoise_steps + 1],
xaxis=dict(
scaleanchor="y",
scaleratio=1.5,
dtick=x_dtick,
),
yaxis_title="RMSE",
yaxis=dict(
type="log",
),
hovermode="x unified",
template="plotly_white",
)
return fig
@spaces.GPU
def process(
image,
state_orig_img,
table,
path_sparse,
denoise_steps=DEFAULT_denoise_steps,
lr_latent=DEFAULT_lr_latent,
lr_scale_shift=DEFAULT_lr_scale_shift,
override_shift=None,
override_scale=None,
):
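    # NaN in the override number fields means "not set"; normalize to None,
    # presumably letting the pipeline optimize scale and shift itself.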
if override_shift is None:
pass
elif np.isnan(override_shift):
override_shift = None
else:
override_shift = float(override_shift)
if override_scale is None:
pass
elif np.isnan(override_scale):
override_scale = None
else:
override_scale = float(override_scale)
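    # Prefer the pristine image kept in state: the visible input image has click
    # markers drawn on it.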
if isinstance(state_orig_img, str):
image = Image.open(state_orig_img)
elif isinstance(state_orig_img, PIL.Image.Image):
image = state_orig_img
elif isinstance(image, str):
image = Image.open(image)
elif isinstance(image, PIL.Image.Image):
pass
else:
raise TypeError(f"Unknown image type: {type(image)}")
if isinstance(table, pandas.DataFrame):
table = table.values.tolist()
if path_sparse is not None and os.path.exists(path_sparse):
        # Sparse depth given as a numpy file (e.g., LiDAR measurements)
sparse_depth = np.load(path_sparse)
sparse_depth_valid = sparse_depth[sparse_depth > 0]
sparse_depth_min = np.min(sparse_depth_valid)
sparse_depth_max = np.max(sparse_depth_valid)
kernel_size = 5
elif table is not None and len(table) >= 2:
        # Sparse depth from the click annotations table
sparse_depth = np.full((image.height, image.width), np.nan, dtype=np.float32)
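        # Table rows are [tile, depth value, x, y, hex color]; note the (row=y, col=x)
        # indexing into the depth map.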
for entry in table:
            try:
                sparse_depth[entry[3], entry[2]] = float(entry[1])
            except Exception:
                # Skip rows whose depth value is empty or not numeric
                pass
        sparse_depth_valid_mask = ~np.isnan(sparse_depth)
sparse_depth_valid = sparse_depth[sparse_depth_valid_mask]
sparse_depth_valid_num = np.sum(sparse_depth_valid_mask)
if sparse_depth_valid_num >= 2:
sparse_depth_min = np.min(sparse_depth_valid)
sparse_depth_max = np.max(sparse_depth_valid)
sparse_depth[~sparse_depth_valid_mask] = 0
kernel_size = 10
else:
sparse_depth = None
sparse_depth_min = 0
sparse_depth_max = 1
kernel_size = 5
else:
sparse_depth = None
sparse_depth_min = 0
sparse_depth_max = 1
kernel_size = 5
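    # Cap processing at 768 px on the longer side; 0 tells the pipeline to run at
    # the native input resolution.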
width, height = image.size
max_dim = max(width, height)
processing_resolution = 0
if max_dim > 768:
processing_resolution = 768
metrics = []
steps = []
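    # The pipeline is a generator: it yields an intermediate (prediction, RMSE) pair
    # after every denoising step, which is streamed to the UI below.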
for step, (pred, rmse) in enumerate(
pipe(
image=image,
sparse_depth=sparse_depth,
num_inference_steps=denoise_steps + 1,
processing_resolution=processing_resolution,
lr_latent=lr_latent,
lr_scale_shift=lr_scale_shift,
override_shift=override_shift,
override_scale=override_scale,
dry_run=DRY_RUN,
)
):
        # Use a shared value range so the prediction and sparse visualizations match
        min_both = pred.min().item()
        max_both = pred.max().item()
        if sparse_depth is not None:
            min_both = min(sparse_depth_min, min_both)
            max_both = max(sparse_depth_max, max_both)
metrics.append(rmse)
steps.append(step)
vis_pred = pipe.image_processor.visualize_depth(pred, val_min=min_both, val_max=max_both)[0]
if sparse_depth is not None:
vis_sparse = pipe.image_processor.visualize_depth(sparse_depth, val_min=min_both, val_max=max_both)[0]
vis_sparse = np.array(vis_sparse)
vis_sparse[sparse_depth <= 0] = (0, 0, 0)
vis_sparse = dilate_rgb_image(vis_sparse, kernel_size=kernel_size)
else:
vis_sparse = np.full_like(vis_pred, 0)
vis_sparse = Image.fromarray(vis_sparse)
plot = generate_rmse_plot(steps, metrics, denoise_steps)
plot = gr.Plot(plot, visible=True)
slider = gr.ImageSlider([vis_sparse, vis_pred], visible=True)
yield slider, plot
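# Startup diagnostics: log installed packages and environment variables (secrets
# such as tokens are filtered out of the dump below).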
os.system("pip freeze")
print("Environment:\n" + "\n".join(f"{k}: {os.environ[k]}" for k in sorted(os.environ.keys()) if "TOKEN" not in k and "SECRET" not in k))
if "HF_TOKEN_LOGIN" in os.environ:
login(token=os.environ["HF_TOKEN_LOGIN"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = MarigoldDepthCompletionPipeline.from_pretrained(
"prs-eth/marigold-depth-v1-1",
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
try:
import xformers
pipe.enable_xformers_memory_efficient_attention()
except Exception:
    print("Running without xformers")
pipe = pipe.to(device)
os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
with gr.Blocks(
theme=gr.themes.Default(
primary_hue=gr.themes.colors.red,
spacing_size=gr.themes.sizes.spacing_sm,
radius_size="none",
text_size="md",
).set(
button_secondary_background_fill="black",
button_secondary_text_color="white",
body_background_fill="linear-gradient(to right, #FFE0D0, #E0F0FF)"
),
analytics_enabled=False,
title="Marigold Depth Completion",
css="""
.slider .inner {
width: 4px;
background: #FFF;
}
.slider .icon-wrap {
fill: #FFF;
background-color: #FFF;
stroke: #FFF;
stroke-width: 3px;
}
.viewport {
aspect-ratio: 4/3;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
""",
head="""
<script>
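        // Gradio exposes no per-row cell styling from Python, so this observer mirrors
        // the hidden "_color" column into the tile glyph's CSS color. The svelte-*
        // class below is tied to the bundled Gradio build and may break on upgrades.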
function applyColorToTiles() {
const rows = document.querySelectorAll("table tbody tr");
if (rows.length === 0) return;
rows.forEach(row => {
const tileCell = row.children[0];
const colorCell = row.children[4];
const span = tileCell?.querySelector('span.svelte-1y3tas2.text');
if (span && colorCell?.innerText) {
span.style.color = colorCell.innerText.trim();
}
});
}
let observer = new MutationObserver((mutationsList) => {
applyColorToTiles();
})
observer.observe(document.body, { childList: true, subtree: true });
</script>
"""
) as demo:
gr.HTML(
"""
<h1>⇆ Marigold-DC: Zero-Shot Monocular Depth Completion with Guided Diffusion</h1>
<p align="center">
<a title="Website" href="https://MarigoldDepthCompletion.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%A4%8D%20Project%20-Website-blue" alt="Website Badge">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2412.13389" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/%F0%9F%93%84%20Read%20-Paper-af2928" alt="arXiv Badge">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold-dc" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold-dc?label=GitHub&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a><br>
Upload any image, annotate with a few clicks, and compute dense metric depth!<br>
Alternatively, explore advanced LiDAR functionality and examples at the bottom.
</p>
"""
)
state_orig_img = gr.State()
with gr.Row():
with gr.Column():
thumb = gr.Image(
label="Thumb Image",
type="filepath",
visible=False,
)
input_image = gr.Image(
label="Input image (click to enter depth)",
type="filepath",
interactive=True,
)
table = gr.Dataframe(
headers=["Color", "Enter depth estimates (any unit)", "x", "y", "_color"],
datatype=["str", "number", "number", "number", "str"],
column_widths=["30px", "120px", "0px", "0px", "0px"],
static_columns=[0, 2, 3, 4],
show_fullscreen_button=False,
show_copy_button=False,
show_row_numbers=False,
show_search="none",
row_count=0,
interactive=True,
visible=False,
)
with gr.Accordion("Advanced options", open=False):
with gr.Row():
with gr.Column():
denoise_steps = gr.Slider(
label="Number of denoising steps",
minimum=4,
maximum=50,
step=1,
value=15,
)
lr_latent = gr.Number(
DEFAULT_lr_latent,
interactive=True,
label="Latent LR",
step=0.001,
)
with gr.Row():
lr_scale_shift = gr.Number(
DEFAULT_lr_scale_shift,
interactive=True,
label="Scale-and-shift LR",
step=0.001,
min_width=90,
)
override_shift = gr.Number(
value=float("NaN"),
label="Shift override",
min_width=90,
)
override_scale = gr.Number(
value=float("NaN"),
label="Scale override",
min_width=90,
)
with gr.Column():
input_sparse = gr.File(
label="Input sparse depth (numpy file)",
)
with gr.Row():
submit_btn = gr.Button(value="Compute Depth", variant="primary")
clear_btn = gr.Button(value="Clear")
with gr.Column():
output_slider = gr.ImageSlider(
label="Completed depth (red-near, blue-far)",
type="filepath",
show_download_button=True,
interactive=False,
elem_classes="slider",
slider_position=25,
)
plot = gr.Plot(
label="RMSE between sparse measurements and densified depth",
                elem_classes="viewport",
)
input_image.select(
on_click,
inputs=[
input_image,
state_orig_img,
table,
],
outputs=[
input_image,
state_orig_img,
table,
],
)
input_image.upload(
        lambda: gr.update(label="Click and provide depth estimates in the table below"),
outputs=input_image,
)
    def submit_depth_fn(
        image,
        state_orig_img,
        table,
        path_sparse,
        denoise_steps,
        lr_latent,
        lr_scale_shift,
        override_shift,
        override_scale,
    ):
        yield from process(
            image,
            state_orig_img,
            table,
            path_sparse,
            denoise_steps,
            lr_latent,
            lr_scale_shift,
            override_shift,
            override_scale,
        )
submit_btn.click(
fn=submit_depth_fn,
inputs=[
input_image,
state_orig_img,
table,
input_sparse,
denoise_steps,
lr_latent,
lr_scale_shift,
override_shift,
override_scale,
],
outputs=[
output_slider,
plot,
],
)
def examples_depth_lidar_fn(path_thumb):
        def real_url(fname):
            return f"https://huggingface.co/spaces/obukhovai/marigold-dc-metric/resolve/main/files/{fname}"
l_thumb = os.path.basename(path_thumb)
d_thumb = os.path.dirname(path_thumb)
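        # Map each example thumbnail to (full image, sparse depth .npy or None,
        # preset click annotations, number of denoising steps).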
l_image, l_sparse, clicks, nsteps = {
"thumb_matterhorn_clicks.jpg": ["matterhorn.jpg", None, [
[TILE_CHAR, "3", 106, 276, '#%02x%02x%02x' % get_wrapped_color(0)],
[TILE_CHAR, "2", 527, 600, '#%02x%02x%02x' % get_wrapped_color(1)],
], 15],
"thumb_kitti_1.jpg": ["kitti_1.png", "kitti_1.npy", [], 25],
"thumb_kitti_2.jpg": ["kitti_2.png", "kitti_2.npy", [], 25],
"thumb_teaser_10.jpg": ["teaser.png", "teaser_10.npy", [], 25],
"thumb_teaser_100.jpg": ["teaser.png", "teaser_100.npy", [], 25],
"thumb_teaser_1000.jpg": ["teaser.png", "teaser_1000.npy", [], 25],
}[l_thumb]
u_image = real_url(l_image)
l_down_image = os.path.join(d_thumb, l_image)
        response = requests.get(u_image, timeout=60)
response.raise_for_status()
with open(l_down_image, "wb") as f:
f.write(response.content)
table_visible = len(clicks) > 0
l_down_sparse = None
if l_sparse is not None:
u_sparse = real_url(l_sparse)
l_down_sparse = os.path.join(d_thumb, l_sparse)
            response = requests.get(u_sparse, timeout=60)
response.raise_for_status()
with open(l_down_sparse, "wb") as f:
f.write(response.content)
state_orig_img = None
table = []
if len(clicks) > 0:
for click in clicks:
_, value, x, y, _ = click
l_down_image, state_orig_img, table = process_click_data(l_down_image, state_orig_img, table, x, y, value)
for outputs in process(l_down_image, state_orig_img, clicks, l_down_sparse, denoise_steps=nsteps):
yield l_down_image, l_down_sparse, state_orig_img, gr.Dataframe(table, visible=table_visible), *outputs
examples = gr.Examples(
fn=examples_depth_lidar_fn,
examples=[
"files/thumb_matterhorn_clicks.jpg",
"files/thumb_kitti_1.jpg",
"files/thumb_kitti_2.jpg",
"files/thumb_teaser_10.jpg",
"files/thumb_teaser_100.jpg",
"files/thumb_teaser_1000.jpg",
],
inputs=[
thumb,
],
outputs=[
input_image,
input_sparse,
state_orig_img,
table,
output_slider,
plot,
],
cache_mode="lazy",
cache_examples=False,
run_on_click=True,
)
def clear_fn():
return [
gr.update(value=None, interactive=True, label="Input image"),
gr.File(None, interactive=True),
None,
None,
gr.Dataframe([[]], visible=False),
None,
gr.update(interactive=True),
]
clear_btn.click(
fn=clear_fn,
inputs=[],
outputs=[
input_image,
input_sparse,
output_slider,
plot,
table,
state_orig_img,
submit_btn,
],
)
demo.queue(
api_open=False,
).launch(
server_name="0.0.0.0",
server_port=7860,
ssr_mode=False,
)