Spaces · Running on Zero

Commit c28dddb · Parent(s): 93b8c4c
xinjie.wang committed: init commit
Browse files

- app.py +133 -0
- configs/config.yaml +93 -0
- dataset/__init__.py +13 -0
- dataset/base_dataset.py +404 -0
- dataset/data_module.py +82 -0
- dataset/mydataset.py +282 -0
- dataset/utils.py +194 -0
- inference.py +450 -0
- lightning_logs/version_0/hparams.yaml +1 -0
- lightning_logs/version_1/hparams.yaml +1 -0
- lightning_logs/version_2/hparams.yaml +1 -0
- lightning_logs/version_3/hparams.yaml +1 -0
- lightning_logs/version_4/hparams.yaml +1 -0
- lightning_logs/version_5/hparams.yaml +1 -0
- lightning_logs/version_6/hparams.yaml +111 -0
- lightning_logs/version_6/metrics.csv +4 -0
- metrics/__init__.py +0 -0
- metrics/aor.py +44 -0
- metrics/cd.py +284 -0
- metrics/giou.py +142 -0
- metrics/iou.py +220 -0
- metrics/iou_cdist.py +227 -0
- models/__init__.py +19 -0
- models/denoiser.py +415 -0
- models/utils.py +199 -0
- my_utils/__init__.py +0 -0
- my_utils/callbacks.py +36 -0
- my_utils/lr_schedulers.py +104 -0
- my_utils/misc.py +35 -0
- my_utils/plot.py +122 -0
- my_utils/refs.py +122 -0
- my_utils/render.py +482 -0
- my_utils/savermixins.py +55 -0
- objects/__init__.py +0 -0
- objects/dict_utils.py +299 -0
- objects/motions.py +99 -0
- requirements.txt +21 -0
- retrieval/__init__.py +0 -0
- retrieval/obj_retrieval.py +509 -0
- retrieval/retrieval_hash_acd.json +329 -0
- retrieval/retrieval_hash_no_handles.json +722 -0
- scripts/graph_pred/api.py +210 -0
- scripts/graph_pred/eval.py +62 -0
- scripts/graph_pred/prompt_workflow_new.py +363 -0
- scripts/json2urdf.py +160 -0
- scripts/mesh_retrieval/retrieve.py +97 -0
- scripts/mesh_retrieval/retrieve_gpt.py +29 -0
- scripts/mesh_retrieval/run_retrieve.py +68 -0
app.py
ADDED
@@ -0,0 +1,133 @@
import gradio as gr
import os
import shutil
import zipfile
from types import SimpleNamespace
from inference import run_demo, load_config
import random
import string
from gradio.themes import Soft
from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc

custom_theme = Soft(
    primary_hue=stone,
    secondary_hue=gray,
    radius_size="md",
    text_size="sm",
    spacing_size="sm",
)


def inference_ui(img1, img2, omega, n_denoise_steps):
    tmpdir = 'results'
    random_str = ''.join(random.choices(string.ascii_letters, k=16))
    tmpdir = tmpdir + "_" + random_str

    # Delete all existing "results" directories
    for dir in os.listdir('.'):
        if dir.startswith('results') and os.path.isdir(dir):
            shutil.rmtree(dir)
    os.makedirs(os.path.join(tmpdir, "0"), exist_ok=True)

    args = SimpleNamespace(
        img_path_1=img1,
        img_path_2=img2,
        ckpt_path='ckpts/dipo.ckpt',
        config_path='configs/config.yaml',
        use_example_graph=False,
        save_dir=tmpdir,
        gt_data_root='./data/PartnetMobility',
        n_samples=3,
        omega=omega,
        n_denoise_steps=n_denoise_steps,
    )
    args.config = load_config(args.config_path)
    run_demo(args)

    gif_path = os.path.join(tmpdir, "0", "animation.gif")
    ply_path = os.path.join(tmpdir, "0", "object.ply")
    glb_path = os.path.join(tmpdir, "0", "object.glb")

    # Compress the results into a ZIP archive
    zip_path = os.path.join(tmpdir, "output.zip")
    folder_to_zip = os.path.join(tmpdir, "0")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_to_zip):
            for file in files:
                abs_path = os.path.join(root, file)
                rel_path = os.path.relpath(abs_path, folder_to_zip)
                zipf.write(abs_path, arcname=rel_path)

    return (
        gif_path if os.path.exists(gif_path) else None,
        zip_path if os.path.exists(zip_path) else None
    )

def prepare_data():
    if not os.path.exists("data") or not os.path.exists("saved_model"):
        print("Downloading data.tar from Hugging Face Datasets...")
        os.system("wget https://huggingface.co/datasets/wuruiqi0722/DIPO_data/resolve/main/data/data.tar -O data.tar")
        os.system("tar -xf data.tar")

with gr.Blocks(theme=custom_theme) as demo:
    gr.Markdown("## DIPO: Dual-State Images Controlled Articulated Object Generation Powered by Diverse Data")
    gr.Markdown(
        """
        <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
            <a href="https://rq-wu.github.io/projects/DIPO">
                <img alt="📖 Project Page" src="https://img.shields.io/badge/📖-Project_Page-blue">
            </a>
            <a href="https://arxiv.org/abs/2505.20460">
                <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
            </a>
            <a href="https://github.com/RQ-Wu/DIPO">
                <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
            </a>
        </p>
        """
    )
    gr.Markdown("Currently, only articulated objects in the following categories are supported: Table, Dishwasher, StorageFurniture, Refrigerator, WashingMachine, Microwave, Oven.")

    with gr.Row():
        with gr.Column(scale=1):
            img1_input = gr.Image(label="Image: Closed State", type="filepath", height=250)
            img2_input = gr.Image(label="Image: Opened State", type="filepath", height=250)
            omega = gr.Slider(0.0, 1.0, step=0.1, value=0.5, label="Omega (CFG Guidance)")
            n_denoise = gr.Slider(10, 200, step=10, value=100, label="Denoising Steps")
            run_button = gr.Button("🚀 Run Generation (~2mins)")

        with gr.Column(scale=1):
            output_gif = gr.Image(label="GIF Animation", type="filepath", height=678, width=10000)
            zip_download_btn = gr.DownloadButton(label="📦 Download URDF folder", interactive=False)

    gr.Examples(
        examples=[
            ["examples/1.png", "examples/1_open_1.png"],
            ["examples/1.png", "examples/1_open_2.png"],
            ["examples/close1.png", "examples/open1.png"],
            # ["examples/close2.png", "examples/open2.png"],
            ["examples/close3.png", "examples/open3.png"],
            # ["examples/close4.png", "examples/open4.png"],
            ["examples/close5.png", "examples/open5.png"],
            ["examples/close6.png", "examples/open6.png"],
            ["examples/close7.png", "examples/open7.png"],
            ["examples/close8.png", "examples/open8.png"],
            ["examples/close9.jpg", "examples/open9.jpg"],
            ["examples/close10.png", "examples/open10.png"],
        ],
        inputs=[img1_input, img2_input],
        label="📂 Example Inputs"
    )

    run_button.click(
        fn=inference_ui,
        inputs=[img1_input, img2_input, omega, n_denoise],
        outputs=[output_gif, zip_download_btn]
    ).success(
        lambda: gr.DownloadButton(interactive=True),
        outputs=[zip_download_btn]
    )

if __name__ == "__main__":
    prepare_data()
    demo.launch()
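
The `.click(...).success(...)` chain above is what keeps the download button disabled until generation finishes. A minimal standalone sketch of the same Gradio pattern (the handler and file name are hypothetical, not part of this commit):

import gradio as gr

def slow_job():
    return "output.zip"  # hypothetical artifact path returned by the handler

with gr.Blocks() as mini:
    btn = gr.Button("Run")
    dl = gr.DownloadButton("Download", interactive=False)
    # .click() returns an event reference; .success() chains a second callback
    # that fires only if the first handler completed without raising.
    btn.click(fn=slow_job, outputs=[dl]).success(
        lambda: gr.DownloadButton(interactive=True), outputs=[dl]
    )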
configs/config.yaml
ADDED
@@ -0,0 +1,93 @@
name: dipo
version: denoiser

data:
  name: dm_dipo
  json_root: data_path
  root: data_path # root directory of the dataset
  batch_size: 20 # batch size for training
  num_workers: 8 # number of workers for data loading
  K: 32 # maximum number of nodes (parts) in the graph (object)
  split_file: split_file_path
  n_views_per_model: 20
  frame_mode: last_frame
  test_which: pm
  mode_num: 5

system:
  name: sys_origin
  exp_dir: ./exps/${name}/${version}
  data_root: ${data.root}
  n_time_samples: 16
  loss_fg_weight: 0.01
  img_drop_prob: 0.1 # image dropout probability, for classifier-free training
  guidance_scaler: 0.5 # scaling factor for guidance on the image during inference
  graph_drop_prob: 0.5 # graph dropout probability, for classifier-free training

model:
  name: denoiser
  in_ch: 6
  attn_dim: 128
  n_head: 4
  n_layers: 6
  dropout: 0.1
  K: ${data.K}
  mode_num: 5
  img_emb_dims: [768, 128]
  cat_drop_prob: 0.5 # object category dropout probability, for classifier-free training

scheduler: # scheduler for the diffusion model
  name: ddpm
  config:
    num_train_timesteps: 1000
    beta_schedule: linear
    prediction_type: epsilon

lr_scheduler_adapter: # lr scheduler for the new modules on top of the base model
  name: LinearWarmupCosineAnnealingLR
  warmup_epochs: 3
  max_epochs: ${trainer.max_epochs}
  warmup_start_lr: 1e-6
  eta_min: 1e-5

optimizer_adapter: # optimizer for the new modules on top of the base model
  name: AdamW
  args:
    lr: 5e-4
    betas: [0.9, 0.99]
    eps: 1.e-15

lr_scheduler_cage: # lr scheduler for modules in the base model
  name: LinearWarmupCosineAnnealingLR
  warmup_epochs: 3
  max_epochs: ${trainer.max_epochs}
  warmup_start_lr: 1e-6
  eta_min: 1e-5

optimizer_cage: # optimizer for modules in the base model
  name: AdamW
  args:
    lr: 5e-5
    betas: [0.9, 0.99]
    eps: 1.e-15

checkpoint:
  dirpath: ${system.exp_dir}/ckpts
  save_top_k: -1
  every_n_epochs: 50

logger: # wandb logger
  save_dir: ${system.exp_dir}/logs # directory to save logs
  name: ${name}_${version}
  project: SINGAPO

trainer:
  max_epochs: 200
  log_every_n_steps: 100
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  check_val_every_n_epoch: 10
  precision: 16-mixed
  profiler: simple
  num_sanity_val_steps: -1
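
The `${...}` references above (`${data.K}`, `${trainer.max_epochs}`, `${system.exp_dir}`) are OmegaConf-style interpolations. Since the body of inference.py is not shown in this diff, the following is only a sketch of how its `load_config` plausibly resolves them, assuming OmegaConf:

from omegaconf import OmegaConf

def load_config(path):
    # OmegaConf resolves ${name}, ${data.K}, ${trainer.max_epochs}, ... on access.
    return OmegaConf.load(path)

cfg = load_config("configs/config.yaml")
print(cfg.system.exp_dir)         # ./exps/dipo/denoiser
print(cfg.model.K == cfg.data.K)  # True, via ${data.K}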
dataset/__init__.py
ADDED
@@ -0,0 +1,13 @@
datamodules = {}

def register(name):
    def decorator(cls):
        datamodules[name] = cls
        return cls
    return decorator

def make(name, config):
    dm = datamodules[name](config)
    return dm

from . import data_module
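
This module is a small decorator registry: `register` maps a name to a class at import time, and `make` instantiates by name. A short usage sketch (`"dm_dipo"` matches `data.name` in configs/config.yaml; `config` is a placeholder here):

import dataset

# Importing `dataset` executes `from . import data_module` above, so the
# @dataset.register("dm_dipo") decorator has already populated the registry.
config = {}  # placeholder: the `data` block of configs/config.yaml
dm = dataset.make("dm_dipo", config)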
dataset/base_dataset.py
ADDED
@@ -0,0 +1,404 @@
import os, sys
import json
import numpy as np
# import collections.abc
# sys.modules['collections'].Mapping = collections.abc.Mapping

import networkx as nx
from torch.utils.data import Dataset
from my_utils.refs import cat_ref, sem_ref, joint_ref, data_mode_ref
from collections import deque

def build_graph(tree, K=32):
    '''
    Function to build a graph from the node list.

    Args:
        tree: list of nodes
        K: the maximum number of nodes in the graph
    Returns:
        adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes
        parents: array of parent ids per node
        tree_list: list of {id, parent_id} records, for traversal
    '''
    adj = np.zeros((K, K), dtype=np.float32)
    parents = []
    tree_list = []
    for node in tree:
        tree_list.append(
            {
                'id': node['id'],
                'parent_id': node['parent'],
            }
        )
        # 1-ring relationship
        if node['parent'] != -1:
            adj[node['id'], node['parent']] = 1
            parents.append(node['parent'])
        else:
            adj[node['id'], node['id']] = 1
            parents.append(-1)
        for child_id in node['children']:
            adj[node['id'], child_id] = 1
    return {
        'adj': adj,
        'parents': np.array(parents, dtype=np.int8),
        'tree_list': tree_list
    }

from collections import defaultdict
from functools import cmp_to_key

def bfs_tree_simple(tree_list):
    order = [0] * len(tree_list)
    queue = []
    current_node_idx = 0
    for node_idx, node in enumerate(tree_list):
        if node['parent_id'] == -1:
            queue.append(node['id'])
            order[node_idx] = current_node_idx
            current_node_idx += 1
            break
    while len(queue) > 0:
        current_node = queue.pop(0)
        for node_idx, node in enumerate(tree_list):
            if node['parent_id'] == current_node:
                queue.append(node['id'])
                order[node_idx] = current_node_idx
                current_node_idx += 1

    return order

def bfs_tree(tree_list, aabb_list, epsilon=1e-3):
    # Initialize the traversal-order list
    order = [0] * len(tree_list)
    current_order = 0

    # Build a mapping from parent id to child node indices
    parent_map = defaultdict(list)
    for idx, node in enumerate(tree_list):
        parent_map[node['parent_id']].append(idx)

    # Find the root node
    root_indices = [idx for idx, node in enumerate(tree_list) if node['parent_id'] == -1]
    if not root_indices:
        return order

    # Initialize the queue (stores node indices)
    queue = [root_indices[0]]
    order[root_indices[0]] = current_order
    current_order += 1

    # Comparison function: sort by AABB center coordinates
    def compare_centers(a, b):
        # Get the center coordinates of the two nodes
        center_a = [(aabb_list[a][i] + aabb_list[a][i+3])/2 for i in range(3)]
        center_b = [(aabb_list[b][i] + aabb_list[b][i+3])/2 for i in range(3)]

        # Compare coordinates axis by axis (with an epsilon threshold)
        for coord in range(3):
            delta = abs(center_a[coord] - center_b[coord])
            if delta > epsilon:
                return -1 if center_a[coord] < center_b[coord] else 1
        return 0  # keep the original order when all coordinate deltas are below the threshold

    # BFS traversal
    while queue:
        current_idx = queue.pop(0)
        current_id = tree_list[current_idx]['id']

        # Get the child node indices and sort them
        children = parent_map.get(current_id, [])
        sorted_children = sorted(children, key=cmp_to_key(compare_centers))

        # Process the children
        for child_idx in sorted_children:
            order[child_idx] = current_order
            current_order += 1
            queue.append(child_idx)

    return order

class BaseDataset(Dataset):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

    def _filter_models(self, models_ids):
        '''
        Filter out models that have more than K nodes.
        '''
        json_data_root = self.hparams.json_root
        filtered = []
        for i, model_id in enumerate(models_ids):
            if i % 100 == 0:
                print(f'Checking model {i}/{len(models_ids)}')
            path = os.path.join(json_data_root, model_id, self.json_name)
            with open(path, 'r') as f:
                json_file = json.load(f)
            if len(json_file['diffuse_tree']) <= self.hparams.K:
                filtered.append(model_id)
        return filtered

    def get_acd_mapping(self):
        self.category_mapping = {
            'armoire': 'StorageFurniture',
            'bookcase': 'StorageFurniture',
            'chest_of_drawers': 'StorageFurniture',
            'desk': 'Table',
            'dishwasher': 'Dishwasher',
            'hanging_cabinet': 'StorageFurniture',
            'kitchen_cabinet': 'StorageFurniture',
            'microwave': 'Microwave',
            'nightstand': 'StorageFurniture',
            'oven': 'Oven',
            'refrigerator': 'Refrigerator',
            'sink_cabinet': 'StorageFurniture',
            'tv_stand': 'StorageFurniture',
            'washer': 'WashingMachine',
            'table': 'Table',
            'cabinet': 'StorageFurniture',
        }

    def _random_permute(self, graph, nodes):
        '''
        Function to randomly permute the nodes and update the graph and node attribute info.

        Args:
            graph: a dictionary containing the adjacency matrix and parent array
            nodes: a list of nodes
        Returns:
            graph_permuted: a dictionary containing the updated adjacency matrix and parent array
            nodes_permuted: a list of permuted nodes
        '''
        N = len(nodes)
        order = np.random.permutation(N)
        graph_permuted = self._reorder_nodes(graph, order)
        exchange = [0] * len(order)
        for i in range(len(order)):
            exchange[order[i]] = i
        nodes_permuted = nodes[exchange, :]
        return graph_permuted, nodes_permuted

    def _permute_by_order(self, graph, nodes, order):
        '''
        Function to permute the nodes and update the graph and node attribute info by a given order.

        Args:
            graph: a dictionary containing the adjacency matrix and parent array
            nodes: a list of nodes
            order: a list of indices for reordering
        Returns:
            graph_permuted: a dictionary containing the updated adjacency matrix and parent array
            nodes_permuted: a list of permuted nodes
        '''
        graph_permuted = self._reorder_nodes(graph, order)
        if nodes is None:
            return graph_permuted, None
        else:
            exchange = [0] * len(order)
            for i in range(len(order)):
                exchange[order[i]] = i
            nodes_permuted = nodes[exchange, :]
            return graph_permuted, nodes_permuted

    def _prepare_node_data(self, node):
        # semantic label
        label = np.array([sem_ref['fwd'][node['name']]], dtype=np.float32) / 5. - 0.8  # (1,), range from -0.8 to 0.8
        # joint type
        joint_type = np.array([joint_ref['fwd'][node['joint']['type']] / 5.], dtype=np.float32) - 0.5  # (1,), normalized joint type code
        # aabb
        aabb_center = np.array(node['aabb']['center'], dtype=np.float32)  # (3,), range from -1 to 1
        aabb_size = np.array(node['aabb']['size'], dtype=np.float32)  # (3,), range from -1 to 1
        aabb_max = aabb_center + aabb_size / 2
        aabb_min = aabb_center - aabb_size / 2
        # joint axis and range
        if node['joint']['type'] == 'fixed':
            axis_dir = np.zeros((3,), dtype=np.float32)
            axis_ori = aabb_center
            joint_range = np.zeros((2,), dtype=np.float32)
        else:
            if node['joint']['type'] == 'revolute' or node['joint']['type'] == 'continuous':
                joint_range = np.array([node['joint']['range'][1]], dtype=np.float32) / 360.
                joint_range = np.concatenate([joint_range, np.zeros((1,), dtype=np.float32)], axis=0)  # (2,)
            elif node['joint']['type'] == 'prismatic' or node['joint']['type'] == 'screw':
                joint_range = np.array([node['joint']['range'][1]], dtype=np.float32)
                joint_range = np.concatenate([np.zeros((1,), dtype=np.float32), joint_range], axis=0)  # (2,)
            axis_dir = np.array(node['joint']['axis']['direction'], dtype=np.float32) * 0.7  # (3,), range from -0.7 to 0.7
            # make sure the axis is pointing in the positive direction
            if np.sum(axis_dir > 0) < np.sum(-axis_dir > 0):
                axis_dir = -axis_dir
                joint_range = -joint_range
            axis_ori = np.array(node['joint']['axis']['origin'], dtype=np.float32)  # (3,), range from -1 to 1
            if (node['joint']['type'] == 'prismatic' or node['joint']['type'] == 'screw') and node['name'] != 'door':
                axis_ori = aabb_center
        # prepare node data by the given mode name
        # aabb = np.concatenate([aabb_max, aabb_min], axis=0)
        # axis = np.concatenate([axis_dir, axis_ori], axis=0)
        # node_data_all = [aabb, joint_type.repeat(6), axis, joint_range.repeat(3), label.repeat(6)]
        # node_data_list = [node_data_all[data_mode_ref[mod_name]] for mod_name in self.hparams.data_mode]
        # node_data = np.concatenate(node_data_list, axis=0)
        node_label = np.ones(6, dtype=np.float32)

        node_data = np.concatenate([aabb_max, aabb_min, joint_type.repeat(6), axis_dir, axis_ori, joint_range.repeat(3), label.repeat(6), node_label], axis=0)
        if self.hparams.mode_num == 5:
            node_data = np.concatenate([aabb_max, aabb_min, joint_type.repeat(6), axis_dir, axis_ori, joint_range.repeat(3), label.repeat(6)], axis=0)
        return node_data


    def _reorder_nodes(self, graph, order):
        '''
        Function to reorder nodes in the graph and
        update the adjacency matrix and parent array.

        Args:
            graph: a dictionary containing the adjacency matrix and parent array
            order: a list of indices for reordering
        Returns:
            new_graph: a dictionary containing the updated adjacency matrix and parent array
        '''
        N = len(order)
        mapping = {i: order[i] for i in range(N)}
        mapping.update({i: i for i in range(N, self.hparams.K)})
        G = nx.from_numpy_array(graph['adj'], create_using=nx.Graph)
        G_ = nx.relabel_nodes(G, mapping)
        new_adj = nx.adjacency_matrix(G_, G.nodes).todense()

        exchange = [0] * len(order)
        for i in range(len(order)):
            exchange[order[i]] = i
        return {
            'adj': new_adj.astype(np.float32),
            'parents': graph['parents'][exchange]
        }


    def _prepare_input_GT(self, file, model_id):
        '''
        Function to parse an input item from a json file for CAGE training.
        '''
        tree = file['diffuse_tree']
        K = self.hparams.K  # max number of nodes
        cond = {}  # conditional information and auxiliary data
        cond['parents'] = np.zeros(K, dtype=np.int8)

        # prepare node data
        nodes = []
        for node in tree:
            node_data = self._prepare_node_data(node)  # (36,)
            nodes.append(node_data)
        nodes = np.array(nodes, dtype=np.float32)
        n_nodes = len(nodes)

        # prepare graph
        graph = build_graph(tree, self.hparams.K)
        if self.mode == 'train':  # perturb the node order for training
            graph, nodes = self._random_permute(graph, nodes)

        # pad the nodes to K with empty nodes
        if n_nodes < K:
            empty_node = np.zeros((nodes[0].shape[0],))
            data = np.concatenate([nodes, [empty_node] * (K - n_nodes)], axis=0, dtype=np.float32)  # (K, 36)
        else:
            data = nodes
        mode_num = data.shape[1] // 6
        data = data.reshape(K*mode_num, 6)  # (K * n_attr, 6)

        # attr mask (for Local Attention)
        attr_mask = np.eye(K, K, dtype=bool)
        attr_mask = attr_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
        cond['attr_mask'] = attr_mask

        # key padding mask (for Global Attention)
        pad_mask = np.zeros((K*mode_num, K*mode_num), dtype=bool)
        pad_mask[:, :n_nodes*mode_num] = 1
        cond['key_pad_mask'] = pad_mask

        # adj mask (for Graph Relation Attention)
        adj_mask = graph['adj'][:].astype(bool)
        adj_mask = adj_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
        adj_mask[n_nodes*mode_num:, :] = 1
        cond['adj_mask'] = adj_mask

        # object category
        if self.map_cat:  # for ACD dataset
            category = file['meta']['obj_cat']
            category = self.category_mapping[category]
            cond['cat'] = cat_ref[category]
        else:
            cond['cat'] = cat_ref.get(file['meta']['obj_cat'], None)
            if cond['cat'] is None:
                cond['cat'] = self.category_mapping.get(file['meta']['obj_cat'], None)
                if cond['cat'] is None:
                    cond['cat'] = 2
                else:
                    cond['cat'] = cat_ref.get(cond['cat'], None)
                    # cond['cat'] = cat_ref[file['meta']['obj_cat']]
                    if cond['cat'] is None:
                        cond['cat'] = 2
        # auxiliary info
        cond['name'] = model_id
        cond['adj'] = graph['adj']
        cond['parents'][:n_nodes] = graph['parents']
        cond['n_nodes'] = n_nodes
        cond['obj_cat'] = file['meta']['obj_cat']

        return data, cond

    def _prepare_input(self, model_id, pred_file, gt_file=None):
        '''
        Function to parse an input item from pred_file, and parse GT from gt_file if available.
        '''
        K = self.hparams.K  # max number of nodes
        cond = {}  # conditional information and auxiliary data
        # prepare node data
        n_nodes = len(pred_file['diffuse_tree'])
        # prepare graph
        pred_graph = build_graph(pred_file['diffuse_tree'], K)
        # dummy GT data
        data = np.zeros((K*5, 6), dtype=np.float32)

        # attr mask (for Local Attention)
        attr_mask = np.eye(K, K, dtype=bool)
        attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
        cond['attr_mask'] = attr_mask

        # key padding mask (for Global Attention)
        pad_mask = np.zeros((K*5, K*5), dtype=bool)
        pad_mask[:, :n_nodes*5] = 1
        cond['key_pad_mask'] = pad_mask

        # adj mask (for Graph Relation Attention)
        adj_mask = pred_graph['adj'][:].astype(bool)
        adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1)
        adj_mask[n_nodes*5:, :] = 1
        cond['adj_mask'] = adj_mask

        # placeholder category, won't be used if the category is given (below)
        cond['cat'] = cat_ref['StorageFurniture']
        cond['obj_cat'] = 'StorageFurniture'
        # if the object category is given as input
        if not self.hparams.get('test_label_free', False):
            assert 'meta' in pred_file, 'meta not found in the json file.'
            assert 'obj_cat' in pred_file['meta'], 'obj_cat not found in the metadata of the json file.'
            category = pred_file['meta']['obj_cat']
            if self.map_cat:  # for ACD dataset
                category = self.category_mapping[category]
            cond['cat'] = cat_ref[category]
            cond['obj_cat'] = category

        # auxiliary info
        cond['name'] = model_id
        cond['adj'] = pred_graph['adj']
        cond['parents'] = np.zeros(K, dtype=np.int8)
        cond['parents'][:n_nodes] = pred_graph['parents']
        cond['n_nodes'] = n_nodes

        return data, cond

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
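
`build_graph` turns a `diffuse_tree` node list into a K x K adjacency matrix with a self-loop on the root, plus a parent array. A small sketch of its behavior on a hypothetical two-part object (only the keys `build_graph` reads are shown; real nodes also carry `name`, `joint`, and `aabb`):

import numpy as np
from dataset.base_dataset import build_graph

tree = [
    {'id': 0, 'parent': -1, 'children': [1]},  # root: self-loop at (0, 0), edge to child (0, 1)
    {'id': 1, 'parent': 0, 'children': []},    # child: edge back to parent (1, 0)
]
g = build_graph(tree, K=4)
print(g['adj'][:2, :2])  # [[1. 1.] [1. 0.]]
print(g['parents'])      # [-1  0]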
dataset/data_module.py
ADDED
@@ -0,0 +1,82 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import dataset
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from dataset.mydataset import MyDataset

@dataset.register("dm_dipo")
class DIPODataModule(pl.LightningDataModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)

    def _prepare_split(self):
        with open(self.hparams.split_file, "r") as f:
            splits = json.load(f)

        train_ids = splits["train"]
        val_ids = [i for i in train_ids if "data" not in i]
        return train_ids, val_ids

    def _prepare_test_ids(self):
        if "acd" in self.hparams.get('test_which'):
            with open("/home/users/ruiqi.wu/singapo/data/data_acd.json", "r") as f:
                file = json.load(f)
        elif 'pm' in self.hparams.get('test_which'):
            with open(self.hparams.split_file, "r") as f:
                file = json.load(f)
        else:
            raise NotImplementedError(f"Dataset {self.hparams.get('test_which')} not implemented for DIPODataModule")
        ids = file['test']
        return ids

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            train_ids, val_ids = self._prepare_split()
            self.train_dataset = MyDataset(self.hparams, model_ids=train_ids[:10], mode="train")
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids[:50], mode="val")
        elif stage == "validate":
            val_ids = self._prepare_test_ids()
            self.val_dataset = MyDataset(self.hparams, model_ids=val_ids, mode="val")
        elif stage == "test":
            test_ids = self._prepare_test_ids()
            self.test_dataset = MyDataset(self.hparams, model_ids=test_ids, mode="test")
        else:
            raise NotImplementedError(f"Stage {stage} not implemented for DIPODataModule")


    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=True,
            persistent_workers=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=128,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True
        )
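
A sketch of driving this data module for testing via the registry (the paths and values below are hypothetical stand-ins for the resolved `data` block of configs/config.yaml):

import dataset

data_cfg = {
    "name": "dm_dipo", "json_root": "data", "root": "data",  # hypothetical paths
    "split_file": "data/data_split.json", "batch_size": 20,
    "num_workers": 8, "K": 32, "n_views_per_model": 20,
    "frame_mode": "last_frame", "test_which": "pm", "mode_num": 5,
}
dm = dataset.make("dm_dipo", data_cfg)
dm.setup(stage="test")           # builds test_dataset from the split file's "test" ids
loader = dm.test_dataloader()    # batch_size=1, shuffle=False
data, cond, feat = next(iter(loader))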
dataset/mydataset.py
ADDED
@@ -0,0 +1,282 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import json
import numpy as np
from PIL import Image
import torchvision.transforms as T
from dataset.base_dataset import BaseDataset
import random
from tqdm import tqdm
import imageio
import torch

def make_white_background(src_img):
    '''Composite the input RGBA image onto a white background.'''
    src_img.load()
    background = Image.new("RGB", src_img.size, (255, 255, 255))
    background.paste(src_img, mask=src_img.split()[3])  # 3 is the alpha channel
    return background

class MyDataset(BaseDataset):

    """
    Dataset for training and testing on the PartNet-Mobility and ACD datasets (with our preprocessing).
    The GT graph is given.
    """

    def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
        self.hparams = hparams
        self.json_name = json_name
        self.model_ids = self._filter_models(model_ids)
        self.mode = mode
        self.map_cat = False
        self.get_acd_mapping()

        self.no_GT = (
            True if self.hparams.get("test_no_GT", False) and self.hparams.get("test_pred_G", False)
            else False
        )
        self.pred_G = (
            False
            if mode in ["train", "val"]
            else self.hparams.get("test_pred_G", False)
        )

        if mode == 'test':
            if "acd" in hparams.test_which:
                self.map_cat = True

        self.files = self._cache_data()
        print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")

    def _cache_data_train(self):
        json_data_root = self.hparams.json_root
        data_root = self.hparams.root
        # number of views per model and in total
        n_views_per_model = self.hparams.n_views_per_model
        n_views = n_views_per_model * len(self.model_ids)
        # json files for each model
        json_files = []
        # mapping to the index of the corresponding model in json_files
        model_mappings = []
        # space for dinov2 patch features
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        # space for object masks on image patches
        obj_masks = np.empty((n_views, 256), dtype=bool)
        # input images (not required in training)
        imgs = None
        # load data for non-aug views
        i = 0  # index for views
        for j, model_id in enumerate(self.model_ids):
            print(model_id)
            # if j % 10 == 0 and torch.distributed.get_rank() == 0:
            #     print(f"\rLoading training data: {j}/{len(self.model_ids)}")
            # 3D data
            with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
                json_file = json.load(f)
            json_files.append(json_file)
            filenames = os.listdir(os.path.join(data_root, model_id, 'features'))
            filenames = [f for f in filenames if 'high_res' not in f]
            filenames = filenames[:self.hparams.n_views_per_model]
            for filename in filenames:
                view_feat = np.load(os.path.join(data_root, model_id, 'features', filename))
                first_frame_feat = view_feat[0]
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError("Please provide a correct frame mode: last_frame | random_state_frame")
                feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
                feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)
                i = i + 1
            model_mappings += [j] * n_views_per_model
            # object masks for all views
            # all_obj_masks = np.load(
            #     os.path.join(json_data_root, model_id, "features/patch_obj_masks.npy")
            # )  # (20, Np)
            # obj_masks[i : i + n_views_per_model] = all_obj_masks[:n_views_per_model]
        return {
            "len": n_views,
            "gt_files": json_files,
            "features": feats,
            "obj_masks": None,
            "model_mappings": model_mappings,
            "imgs": imgs,
        }

    def _cache_data_non_train(self):
        # number of views per model and in total
        n_views_per_model = 2
        n_views = n_views_per_model * len(self.model_ids)
        # json files for each model
        gt_files = []
        pred_files = []  # for predicted graphs
        # mapping to the index of the corresponding model in json_files
        model_mappings = []
        # space for dinov2 patch features
        feats = np.empty((n_views, 512, 768), dtype=np.float16)
        # space for input images
        first_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        second_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
        # transformation for input images
        transform = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(224),
                T.Resize(128, interpolation=T.InterpolationMode.BICUBIC),
            ]
        )

        i = 0  # index for views
        desc = f'Loading {self.mode} data'
        for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
            with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
                json_file = json.load(f)
            gt_files.append(json_file)
            # filename_dir = os.path.join(self.hparams.root, model_id, 'features')
            for filename in ['18.npy', '19.npy']:
                view_feat = np.load(os.path.join(self.hparams.root, model_id, 'features', filename))
                first_frame_feat = view_feat[0]
                if self.hparams.frame_mode == 'last_frame':
                    second_frame_feat = view_feat[-2]
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_frame_feat = view_feat[-1]
                else:
                    raise NotImplementedError("Please provide a correct frame mode: last_frame | random_state_frame")
                feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
                feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)

                video_path = os.path.join(self.hparams.root, model_id, 'imgs', 'animation_' + filename.replace('.npy', '.mp4'))
                reader = imageio.get_reader(video_path)
                frames = []
                for frame in reader:
                    frames.append(frame)
                reader.close()

                first_img = Image.fromarray(frames[0])
                if first_img.mode == 'RGBA':
                    first_img = make_white_background(first_img)

                first_img = np.asarray(transform(first_img), dtype=np.int8)
                first_imgs[i] = first_img

                if self.hparams.frame_mode == 'last_frame':
                    second_img = Image.fromarray(frames[-1])
                elif self.hparams.frame_mode == 'random_state_frame':
                    second_img_path = video_path.replace('animation', 'random').replace('.mp4', '.png')
                    second_img = Image.open(second_img_path)
                if second_img.mode == 'RGBA':
                    second_img = make_white_background(second_img)
                second_img = np.asarray(transform(second_img), dtype=np.int8)
                second_imgs[i] = second_img

                i = i + 1
            # mapping to json file
            model_mappings += [j] * n_views_per_model

        return {
            "len": n_views,
            "gt_files": gt_files,
            "pred_files": pred_files,
            "features": feats,
            "model_mappings": model_mappings,
            "imgs": [first_imgs, second_imgs],
        }

    def _cache_data(self):
        """
        Function to cache data from disk.
        """
        if self.mode == "train":
            return self._cache_data_train()
        else:
            return self._cache_data_non_train()

    def _get_item_train_val(self, index):
        model_i = self.files["model_mappings"][index]
        gt_file = self.files["gt_files"][model_i]
        data, cond = self._prepare_input_GT(
            file=gt_file, model_id=self.model_ids[model_i]
        )
        if self.mode == "val":
            # input image for visualization
            img_first = self.files["imgs"][0][index]
            img_last = self.files["imgs"][1][index]
            cond["img"] = np.concatenate([img_first, img_last], axis=1)
        # else:
        #     # object masks on patches
        #     # obj_mask = self.files["obj_masks"][index][None, ...].repeat(self.hparams.K * 5, axis=0)
        #     cond["img_obj_mask"] = [None]
        return data, cond

    def _get_item_test(self, index):
        model_i = self.files["model_mappings"][index]

        gt_file = None if self.no_GT else self.files["gt_files"][model_i]

        if self.hparams.get('G_dir', None) is None:
            data, cond = self._prepare_input_GT(file=gt_file, model_id=self.model_ids[model_i])
        else:
            if index % 2 == 0:
                filename = '18.json'
            else:
                filename = '19.json'
            pred_file_path = os.path.join(self.hparams.G_dir, self.model_ids[model_i], filename)
            with open(pred_file_path, "r") as f:
                pred_file = json.load(f)
            data, cond = self._prepare_input(model_id=self.model_ids[model_i], pred_file=pred_file, gt_file=gt_file)
        # input image for visualization
        img_first = self.files["imgs"][0][index]
        img_last = self.files["imgs"][1][index]
        cond["img"] = np.concatenate([img_first, img_last], axis=1)
        return data, cond

    def __getitem__(self, index):
        # input image features
        feat = self.files["features"][index]

        # prepare input, GT data and other auxiliary info
        if self.mode == "test":
            data, cond = self._get_item_test(index)
        else:
            data, cond = self._get_item_train_val(index)

        return data, cond, feat

    def __len__(self):
        return self.files["len"]

if __name__ == '__main__':
    from types import SimpleNamespace

    class EnhancedNamespace(SimpleNamespace):
        def get(self, key, default=None):
            return getattr(self, key, default)

    hparams = {
        "name": "dm_singapo",
        "json_root": "/home/users/ruiqi.wu/singapo/",  # root directory of the dataset
        "batch_size": 20,  # batch size for training
        "num_workers": 8,  # number of workers for data loading
        "K": 32,  # maximum number of nodes (parts) in the graph (object)
        "split_file": "/home/users/ruiqi.wu/singapo/data/data_split.json",
        "n_views_per_model": 5,
        "root": "/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
        "frame_mode": "last_frame"
    }
    hparams = EnhancedNamespace(**hparams)
    with open(hparams.split_file, "r") as f:
        splits = json.load(f)

    train_ids = splits["train"]
    val_ids = [i for i in train_ids if "augmented" not in i]

    val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]

    dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="val")
    for i in range(20):
        data, cond, feat = dataset.__getitem__(i)
        import ipdb
        ipdb.set_trace()
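
Each cached view packs the DINOv2 patch features of both object states into a single (512, 768) array: rows 0-255 hold the first (closed-state) frame and rows 256-511 the second frame, exactly as `_cache_data_train` and `_cache_data_non_train` write them. A sketch of unpacking the two states downstream (hypothetical consumer code):

import numpy as np

feat = np.zeros((512, 768), dtype=np.float16)  # one cached view, laid out as above
first_state = feat[:256]    # 256 patch tokens of the closed-state image
second_state = feat[256:]   # 256 patch tokens of the second-state image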
dataset/utils.py
ADDED
@@ -0,0 +1,194 @@
| 1 |
+
import os, sys
|
| 2 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from my_utils.refs import joint_ref, sem_ref
|
| 6 |
+
|
| 7 |
+
def rescale_axis(jtype, axis_d, axis_o, box_center):
|
| 8 |
+
'''
|
| 9 |
+
Function to rescale the axis for rendering
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
- jtype (int): joint type
|
| 13 |
+
- axis_d (np.array): axis direction
|
| 14 |
+
- axis_o (np.array): axis origin
|
| 15 |
+
- box_center (np.array): bounding box center
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
- center (np.array): rescaled axis origin
|
| 19 |
+
- axis_d (np.array): rescaled axis direction
|
| 20 |
+
'''
|
| 21 |
+
if jtype == 0 or jtype == 1:
|
| 22 |
+
return [0., 0., 0.], [0., 0., 0.]
|
| 23 |
+
if jtype == 3 or jtype == 4:
|
| 24 |
+
center = box_center
|
| 25 |
+
else:
|
| 26 |
+
center = axis_o + np.dot(axis_d, box_center-axis_o) * axis_d
|
| 27 |
+
return center.tolist(), axis_d.tolist()
|
| 28 |
+
|
| 29 |
+
def make_white_background(src_img):
|
| 30 |
+
'''Make the white background for the input RGBA image.'''
|
| 31 |
+
src_img.load()
|
| 32 |
+
background = Image.new("RGB", src_img.size, (255, 255, 255))
|
| 33 |
+
background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
|
| 34 |
+
return background
|
| 35 |
+
|
| 36 |
+
def build_graph(tree, K=32):
|
| 37 |
+
'''
|
| 38 |
+
Function to build graph from the node list.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
nodes: list of nodes
|
| 42 |
+
K: the maximum number of nodes in the graph
|
| 43 |
+
Returns:
|
| 44 |
+
adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes
|
| 45 |
+
edge_list: list of edges, for visualization
|
| 46 |
+
'''
|
| 47 |
+
adj = np.zeros((K, K), dtype=np.float32)
|
| 48 |
+
parents = []
|
| 49 |
+
    for node in tree:
        # 1-ring relationship
        if node['parent'] != -1:
            adj[node['id'], node['parent']] = 1
            parents.append(node['parent'])
        else:
            adj[node['id'], node['id']] = 1
            parents.append(-1)
        for child_id in node['children']:
            adj[node['id'], child_id] = 1

    return {
        'adj': adj,
        'parents': np.array(parents, dtype=np.int8)
    }

def load_input_from(pred_file, K=32):
    '''
    Function to parse an input item from a file containing the predicted graph
    '''

    cond = {}  # conditional information and auxiliary data
    # prepare node data
    n_nodes = len(pred_file['diffuse_tree'])
    # prepare graph
    pred_graph = build_graph(pred_file['diffuse_tree'], K)

    # attr mask (for Local Attention)
    attr_mask = np.eye(K, K, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
    cond['attr_mask'] = attr_mask

    # key padding mask (for Global Attention)
    pad_mask = np.zeros((K*5, K*5), dtype=bool)
    pad_mask[:, :n_nodes*5] = 1
    cond['key_pad_mask'] = pad_mask

    # adj mask (for Graph Relation Attention)
    adj_mask = pred_graph['adj'][:].astype(bool)
    adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1)
    adj_mask[n_nodes*5:, :] = 1
    cond['adj_mask'] = adj_mask

    # placeholder
    data = np.zeros((K*5, 6), dtype=bool)
    cond['cat'] = 2

    # auxiliary info
    cond['adj'] = pred_graph['adj']
    cond['parents'] = np.zeros(K, dtype=np.int8)
    cond['parents'][:n_nodes] = pred_graph['parents']
    cond['n_nodes'] = n_nodes

    return data, cond

def convert_data_range(x):
    '''postprocessing: convert the raw model output to the original range, following CAGE'''
    x = x.reshape(-1, 30)  # (K, 30)
    aabb_max = x[:, 0:3]
    aabb_min = x[:, 3:6]
    center = (aabb_max + aabb_min) / 2.0
    size = (aabb_max - aabb_min).clip(min=5e-3)

    j_type = np.mean(x[:, 6:12], axis=1)
    j_type = ((j_type + 0.5) * 5).clip(min=1.0, max=5.0).round()

    axis_d = x[:, 12:15]
    axis_d = axis_d / (
        np.linalg.norm(axis_d, axis=1, keepdims=True) + np.finfo(float).eps
    )
    axis_o = x[:, 15:18]

    j_range = (x[:, 18:20] + x[:, 20:22] + x[:, 22:24]) / 3
    j_range = j_range.clip(min=-1.0, max=1.0)
    j_range[:, 0] = j_range[:, 0] * 360
    j_range[:, 1] = j_range[:, 1]

    label = np.mean(x[:, 24:30], axis=1)
    label = ((label + 0.8) * 5).clip(min=0.0, max=7.0).round()
    return {
        "center": center,
        "size": size,
        "type": j_type,
        "axis_d": axis_d,
        "axis_o": axis_o,
        "range": j_range,
        "label": label,
    }

def parse_tree(data, n_nodes, par, adj):
    tree = []
    # convert to json format
    for i in range(n_nodes):
        node = {"id": i}
        node["name"] = sem_ref["bwd"][int(data["label"][i].item())]
        node["parent"] = int(par[i])
        node["children"] = [
            int(child) for child in np.where(adj[i] == 1)[0] if child != par[i]
        ]
        node["aabb"] = {}
        node["aabb"]["center"] = data["center"][i].tolist()
        node["aabb"]["size"] = data["size"][i].tolist()
        node["joint"] = {}
        if node['name'] == 'base':
            node["joint"]["type"] = 'fixed'
        else:
            node["joint"]["type"] = joint_ref["bwd"][int(data["type"][i].item())]
        if node["joint"]["type"] == "fixed":
            node["joint"]["range"] = [0.0, 0.0]
        elif node["joint"]["type"] == "revolute":
            node["joint"]["range"] = [0.0, float(data["range"][i][0])]
        elif node["joint"]["type"] == "continuous":
            node["joint"]["range"] = [0.0, 360.0]
        elif (
            node["joint"]["type"] == "prismatic" or node["joint"]["type"] == "screw"
        ):
            node["joint"]["range"] = [0.0, float(data["range"][i][1])]
        node["joint"]["axis"] = {}
        # relocate the axis to visualize well
        axis_o, axis_d = rescale_axis(
            int(data["type"][i].item()),
            data["axis_d"][i],
            data["axis_o"][i],
            data["center"][i],
        )
        node["joint"]["axis"]["direction"] = axis_d
        node["joint"]["axis"]["origin"] = axis_o
        # append node to the tree
        tree.append(node)
    return tree

def convert_json(x, c, prefix=''):
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c[f"{prefix}n_nodes"][0].item()
    par = c[f"{prefix}parents"][0].cpu().numpy().tolist()
    adj = c[f"{prefix}adj"][0].cpu().numpy()
    np.fill_diagonal(adj, 0)  # remove self-loop for the root node
    if f"{prefix}obj_cat" in c:
        out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][0]

    # convert the data to original range
    data = convert_data_range(x)
    # parse the tree
    tree = parse_tree(data, n_nodes, par, adj)
    out["diffuse_tree"] = tree
    return out
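For orientation, the helpers above chain together: load_input_from builds the conditioning dict from a predicted graph, and convert_data_range plus parse_tree map a raw (K*5, 6) sample back to a part tree. A minimal sketch, not part of the commit; the pred_graph.json path and the random stand-in for the denoised model output are assumptions:

import json
import numpy as np

with open("pred_graph.json") as f:               # hypothetical predicted-graph file
    pred_file = json.load(f)                      # must contain a 'diffuse_tree' list

data, cond = load_input_from(pred_file, K=32)     # placeholder data + conditioning masks
raw = np.random.randn(32 * 5, 6).astype(np.float32)  # stand-in for one denoised sample
attrs = convert_data_range(raw)                   # back to centers/sizes/joints/labels
adj = cond['adj'].copy()
np.fill_diagonal(adj, 0)                          # drop the root self-loop, as convert_json does
tree = parse_tree(attrs, cond['n_nodes'], cond['parents'], adj)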
inference.py
ADDED
@@ -0,0 +1,450 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import json
import torch
import argparse
import numpy as np
from PIL import Image, ImageOps
import imageio
# from my_utils.plot import viz_graph
from my_utils.misc import load_config
import torchvision.transforms as T
from diffusers import DDPMScheduler
from models.denoiser import Denoiser
from scripts.json2urdf import create_urdf_from_json, pybullet_render
from dataset.utils import make_white_background, load_input_from, convert_data_range, parse_tree
import models
import torch.nn.functional as F
from io import BytesIO
import base64
from scripts.graph_pred.api import predict_graph_twomode, gpt_infer_image_category
import subprocess
import spaces
import time


cat_ref = {
    "Table": 0,
    "Dishwasher": 1,
    "StorageFurniture": 2,
    "Refrigerator": 3,
    "WashingMachine": 4,
    "Microwave": 5,
    "Oven": 6,
}

def run_retrieve(src_dir, json_name, data_root):
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
    try:
        subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')

def make_white_background(src_img):
    '''Make a white background for the input RGBA image.'''
    src_img.load()
    background = Image.new("RGB", src_img.size, (255, 255, 255))
    background.paste(src_img, mask=src_img.split()[3])  # 3 is the alpha channel
    return background

def pad_to_square(img, fill=0):
    """Pad image to square with given fill value (default: 0 = black)."""
    width, height = img.size
    if width == height:
        return img
    max_side = max(width, height)
    delta_w = max_side - width
    delta_h = max_side - height
    padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
    return ImageOps.expand(img, padding, fill=fill)

def load_img(img_path):
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    with Image.open(img_path) as img:
        if img.mode == 'RGBA':
            img = make_white_background(img)
        img = img.convert('RGB')  # Ensure it's 3-channel for normalization
        img = pad_to_square(img, fill=0)
        img = transform(img)
        img_batch = img.unsqueeze(0).cuda()

    return img_batch


def load_frame_with_imageio(frame):
    """
    Process a single frame into the input format expected by the DINO model.
    """
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    img = Image.fromarray(frame)  # convert to a PIL image
    if img.mode == 'RGBA':
        img = make_white_background(img)
    img = transform(img)  # apply the preprocessing
    return img.unsqueeze(0).cuda()  # add the batch dimension

def read_video_as_batch_with_imageio(video_path):
    """
    Read a video with imageio and process all frames into a batch of shape (B, C, H, W).
    """
    reader = imageio.get_reader(video_path)
    batch_frames = []

    try:
        for frame in reader:
            # load the frame and process it into (1, C, H, W)
            processed_frame = load_frame_with_imageio(frame)
            batch_frames.append(processed_frame)

        reader.close()
        if batch_frames:
            return torch.cat(batch_frames, dim=0).cuda()  # stack along the batch dimension and move to the GPU
        else:
            print("The video has no valid frames")
            return None
    except Exception as e:
        print(f"Error while processing the video: {e}")
        return None

def extract_dino_feature(img_path_1, img_path_2):
    print('Extracting DINO feature...')
    feat_1 = load_img(img_path_1)
    feat_2 = load_img(img_path_2)
    frames = torch.cat([feat_1, feat_2], dim=0)
    dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg', pretrained=True).cuda()
    with torch.no_grad():
        feat = dinov2_vitb14_reg.forward_features(frames)["x_norm_patchtokens"]
    # release the GPU memory of the model
    feat_input = torch.cat([feat[0], feat[-1]], dim=0).unsqueeze(0)
    print('Extracting DINO feature done')
    torch.cuda.empty_cache()
    return feat_input

def set_scheduler(n_steps=100):
    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='linear', prediction_type='epsilon')
    scheduler.set_timesteps(n_steps)
    return scheduler

def prepare_model_input(data, cond, feat, n_samples):
    # attention masks
    attr_mask = torch.from_numpy(cond['attr_mask']).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond['key_pad_mask'])
    graph_mask = torch.from_numpy(cond['adj_mask'])
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise
    B, C = data.shape
    noise = torch.randn([n_samples, B, C], dtype=torch.float32)
    # dummy image feature (used for guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        'cat': cat.cuda(),
        'f': f.cuda(),
    }

def prepare_model_input_nocond(feat, n_samples):
    # attention masks
    cond_example = np.zeros((32*5, 32*5), dtype=bool)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
    attr_mask = torch.from_numpy(attr_mask).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    graph_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise
    data = np.zeros((32*5, 6), dtype=bool)
    noise = torch.randn(data.shape, dtype=torch.float32).repeat(n_samples, 1, 1)
    # dummy image feature (used for guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        'cat': cat.cuda(),
        'f': f.cuda(),
    }

def save_graph(pred_graph, save_dir):
    print(f'Saving the predicted graph to {save_dir}/pred_graph.json')
    # save the response
    with open(os.path.join(save_dir, "pred_graph.json"), "w") as f:
        json.dump(pred_graph, f, indent=4)
    # Visualize the graph
    # img_graph = Image.fromarray(viz_graph(pred_graph))
    # img_graph.save(os.path.join(save_dir, "pred_graph.png"))

def forward(model, scheduler, inputs, omega=0.5):
    print('Running inference...')
    noisy_x = inputs['noise']
    for t in scheduler.timesteps:
        timesteps = torch.tensor([t], device=inputs['noise'].device)
        outputs_cond = model(
            x=noisy_x,
            cat=inputs['cat'],
            timesteps=timesteps,
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
            label_free=True,
        )  # take the conditional image as input
        if omega != 0:
            outputs_free = model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['dummy_f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=True,
            )  # take the dummy DINO features for the condition-free mode
            noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
        else:
            noise_pred = outputs_cond['noise_pred']
        noisy_x = scheduler.step(noise_pred, t, noisy_x).prev_sample
    return noisy_x

def _convert_json(x, c):
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c["n_nodes"]
    par = c["parents"].tolist()
    adj = c["adj"]
    np.fill_diagonal(adj, 0)  # remove self-loop for the root node
    if "obj_cat" in c:
        out["meta"]["obj_cat"] = c["obj_cat"]

    # convert the data to original range
    data = convert_data_range(x)
    # parse the tree
    out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
    return out

def post_process(output, cond, save_root, gt_data_root, visualize=False):
    print('Post-processing...')
    N = output.shape[0]
    for i in range(N):
        cond_n = {}
        cond_n['n_nodes'] = cond['n_nodes'][i]
        cond_n['parents'] = cond['parents'][i]
        cond_n['adj'] = cond['adj'][i]
        cond_n['obj_cat'] = cond['cat']
        # convert the raw model output of this sample to the json format
        # (index the i-th sample so each save_dir gets its own result)
        out_json = _convert_json(output[i], cond_n)
        save_dir = os.path.join(save_root, str(i))
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "object.json"), "w") as f:
            json.dump(out_json, f, indent=4)

        # retrieve part meshes (call python script)
        # print(f"Retrieving part meshes for the object {i}...")
        # os.system(f"python scripts/mesh_retrieval/retrieve.py --src_dir {save_dir} --json_name object.json --gt_data_root {gt_data_root}")


def load_model(ckpt_path, config):
    print('Loading model from checkpoint...')
    model = models.make(config.name, config)
    state_dict = torch.load(ckpt_path)
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    return model.cuda()

def convert_pred_graph(pred_graph):
    cond = {}
    B, K = pred_graph.shape[:2]
    adj = np.zeros((B, K, K), dtype=np.float32)
    padding = np.zeros((B, 5 * K, 5 * K), dtype=bool)
    parents = np.zeros((B, K), dtype=np.int32)
    n_nodes = np.zeros((B,), dtype=np.int32)
    for b in range(B):
        node_len = 0
        for k in range(K):
            if pred_graph[b, k] == k and k > 0:
                node_len = k
                break
            node = pred_graph[b, k]
            adj[b, k, node] = 1
            adj[b, node, k] = 1
            parents[b, k] = node
        adj[b, node_len:] = 1
        padding[b, :, :5 * node_len] = 1
        parents[b, 0] = -1
        n_nodes[b] = node_len
    adj_mask = adj.astype(bool).repeat(5, axis=1).repeat(5, axis=2)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)

    cond['adj_mask'] = adj_mask
    cond['attr_mask'] = attr_mask
    cond['key_pad_mask'] = padding

    cond['adj'] = adj
    cond['parents'] = parents
    cond['n_nodes'] = n_nodes
    cond['cat'] = 'StorageFurniture'

    data = np.zeros((32*5, 6), dtype=bool)

    return data, cond

def bfs_tree_simple(tree_list):
    order = [0] * len(tree_list)
    queue = []
    current_node_idx = 0
    for node_idx, node in enumerate(tree_list):
        if node['parent'] == -1:
            queue.append(node['id'])
            order[current_node_idx] = node_idx
            current_node_idx += 1
            break
    while len(queue) > 0:
        current_node = queue.pop(0)
        for node_idx, node in enumerate(tree_list):
            if node['parent'] == current_node:
                queue.append(node['id'])
                order[current_node_idx] = node_idx
                current_node_idx += 1

    return order

def get_graph_from_gpt(img_path_1, img_path_2):
    first_img = Image.open(img_path_1)
    first_img_data = first_img.resize((1024, 1024))
    buffer = BytesIO()
    first_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    first_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")

    second_img = Image.open(img_path_2)
    second_img_data = second_img.resize((1024, 1024))
    buffer = BytesIO()
    second_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    second_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")

    pred_gpt = predict_graph_twomode('', first_img_data=first_encoded_image, second_img_data=second_encoded_image)
    print(pred_gpt)
    pred_graph = pred_gpt['diffuse_tree']
    # order = bfs_tree_simple(pred_graph)
    # pred_graph = [pred_graph[i] for i in order]

    # initialize the parent array as [0, 1, 2, ..., 31]
    graph_array = np.array([i for i in range(32)])
    for node_idx, node in enumerate(pred_graph):
        if node['parent'] == -1:
            graph_array[node_idx] = node_idx
        else:
            graph_array[node_idx] = node['parent']

    # new axis for batch
    graph_array = np.expand_dims(graph_array, axis=0)

    cat_str = gpt_infer_image_category(first_encoded_image, second_encoded_image)

    return torch.from_numpy(graph_array).cuda().repeat(3, 1), cat_str

@spaces.GPU
def run_demo(args):
    # extract DINOv2 features from the input images
    t1 = time.time()
    feat = extract_dino_feature(args.img_path_1, args.img_path_2)
    t2 = time.time()
    print(f'Extracted DINO feature in {t2 - t1:.2f} seconds')
    scheduler = set_scheduler(args.n_denoise_steps)
    # load the checkpoint of the model
    model = load_model(args.ckpt_path, args.config.system.model)

    # inference
    with torch.no_grad():
        t3 = time.time()
        pred_graph, cat_str = get_graph_from_gpt(args.img_path_1, args.img_path_2)
        t4 = time.time()
        print(f'Got the predicted graph in {t4 - t3:.2f} seconds')
        print(pred_graph)
        data, cond = convert_pred_graph(pred_graph)
        inputs = prepare_model_input(data, cond, feat, n_samples=args.n_samples)

        # Update the object category
        cond['cat'] = cat_str
        inputs['cat'][:] = cat_ref[cat_str]
        print(f'Object category predicted by GPT: {cat_str}, {cat_ref[cat_str]}')

        output = forward(model, scheduler, inputs, omega=args.omega).cpu().numpy()
        t5 = time.time()
        print(f'Forwarded the model in {t5 - t4:.2f} seconds')

    # post-process
    post_process(output, cond, args.save_dir, args.gt_data_root, visualize=True)

    # retrieve
    for sample in os.listdir(args.save_dir):
        sample_dir = os.path.join(args.save_dir, sample)
        t6 = time.time()
        run_retrieve(sample_dir, 'object.json', args.gt_data_root)
        t7 = time.time()
        print(f'Retrieved part meshes in {t7 - t6:.2f} seconds')

    save_json_path = os.path.join(args.save_dir, "0", "object.json")
    with open(save_json_path, 'r') as file:
        json_data = json.load(file)
    create_urdf_from_json(json_data, save_json_path.replace('.json', '.urdf'))
    pybullet_render(save_json_path.replace('.json', '.urdf'), os.path.join(args.save_dir, "0"), 8)


if __name__ == '__main__':
    '''
    Script for running inference on an example image input.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path_1", type=str, default='examples/1.png', help="path to the first input image")
    parser.add_argument("--img_path_2", type=str, default='examples/1_open_1.png', help="path to the second input image")
    parser.add_argument("--ckpt_path", type=str, default='exps/singapo/final/ckpts/last.ckpt', help="path to the checkpoint of the model")
    parser.add_argument("--config_path", type=str, default='exps/singapo/final/config/parsed.yaml', help="path to the config file")
    parser.add_argument("--use_example_graph", action="store_true", default=False, help="if you don't have the openai key yet, turn on to use the example graph for inference")
    parser.add_argument("--save_dir", type=str, default='results', help="path to save the output")
    parser.add_argument("--gt_data_root", type=str, default='./', help="the root directory of the original data, used for part mesh retrieval")
    parser.add_argument("--n_samples", type=int, default=3, help="number of samples to generate given the input")
    parser.add_argument("--omega", type=float, default=0.5, help="the weight of the condition-free mode in the inference")
    parser.add_argument("--n_denoise_steps", type=int, default=100, help="number of denoising steps")
    args = parser.parse_args()

    assert os.path.exists(args.img_path_1), "The input image does not exist"
    # assert os.path.exists(args.ckpt_path), "The checkpoint does not exist"
    assert os.path.exists(args.config_path), "The config file does not exist"
    os.makedirs(args.save_dir, exist_ok=True)

    config = load_config(args.config_path)
    args.config = config

    run_demo(args)
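For reference, run_demo can also be driven programmatically instead of via the CLI. A minimal sketch mirroring the argparse defaults above; the checkpoint/config paths are the script's own defaults and are assumed to exist locally, and load_config is assumed to return an object exposing config.system.model as run_demo expects:

from types import SimpleNamespace
from inference import run_demo, load_config

config = load_config('exps/singapo/final/config/parsed.yaml')  # assumed local path
args = SimpleNamespace(
    img_path_1='examples/1.png',
    img_path_2='examples/1_open_1.png',
    ckpt_path='exps/singapo/final/ckpts/last.ckpt',
    config=config,
    save_dir='results',
    gt_data_root='./',
    n_samples=3,
    omega=0.5,               # strength of the condition-free guidance term
    n_denoise_steps=100,
)
run_demo(args)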
lightning_logs/version_0/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_1/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_2/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_3/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_4/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_5/hparams.yaml
ADDED
@@ -0,0 +1 @@
{}
lightning_logs/version_6/hparams.yaml
ADDED
@@ -0,0 +1,111 @@
name: sys_origin
exp_dir: ./exps/dipo/denoiser
data_root: /horizon-bucket/robot_lab/users/ruiqi.wu/robot/dataset/blender
n_time_samples: 16
loss_fg_weight: 0.01
img_drop_prob: 0.1
guidance_scaler: 0.5
graph_drop_prob: 0.5
model:
  name: denoiser
  in_ch: 6
  attn_dim: 128
  n_head: 4
  n_layers: 6
  dropout: 0.1
  K: 32
  mode_num: 5
  img_emb_dims:
  - 768
  - 128
  cat_drop_prob: 0.5
scheduler:
  name: ddpm
  config:
    num_train_timesteps: 1000
    beta_schedule: linear
    prediction_type: epsilon
lr_scheduler_adapter:
  name: LinearWarmupCosineAnnealingLR
  warmup_epochs: 3
  max_epochs: 200
  warmup_start_lr: 1.0e-06
  eta_min: 1.0e-05
optimizer_adapter:
  name: AdamW
  args:
    lr: 0.0005
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-15
lr_scheduler_cage:
  name: LinearWarmupCosineAnnealingLR
  warmup_epochs: 3
  max_epochs: 200
  warmup_start_lr: 1.0e-06
  eta_min: 1.0e-05
optimizer_cage:
  name: AdamW
  args:
    lr: 5.0e-05
    betas:
    - 0.9
    - 0.99
    eps: 1.0e-15
hparams:
  name: sys_origin
  exp_dir: ./exps/dipo/denoiser
  data_root: /horizon-bucket/robot_lab/users/ruiqi.wu/robot/dataset/blender
  n_time_samples: 16
  loss_fg_weight: 0.01
  img_drop_prob: 0.1
  guidance_scaler: 0.5
  graph_drop_prob: 0.5
  model:
    name: denoiser
    in_ch: 6
    attn_dim: 128
    n_head: 4
    n_layers: 6
    dropout: 0.1
    K: 32
    mode_num: 5
    img_emb_dims:
    - 768
    - 128
    cat_drop_prob: 0.5
  scheduler:
    name: ddpm
    config:
      num_train_timesteps: 1000
      beta_schedule: linear
      prediction_type: epsilon
  lr_scheduler_adapter:
    name: LinearWarmupCosineAnnealingLR
    warmup_epochs: 3
    max_epochs: 200
    warmup_start_lr: 1.0e-06
    eta_min: 1.0e-05
  optimizer_adapter:
    name: AdamW
    args:
      lr: 0.0005
      betas:
      - 0.9
      - 0.99
      eps: 1.0e-15
  lr_scheduler_cage:
    name: LinearWarmupCosineAnnealingLR
    warmup_epochs: 3
    max_epochs: 200
    warmup_start_lr: 1.0e-06
    eta_min: 1.0e-05
  optimizer_cage:
    name: AdamW
    args:
      lr: 5.0e-05
      betas:
      - 0.9
      - 0.99
      eps: 1.0e-15
lightning_logs/version_6/metrics.csv
ADDED
@@ -0,0 +1,4 @@
lr-AdamW,lr-AdamW-1,step
1e-06,1e-06,0
0.0002505,2.5500000000000003e-05,10
0.0005,5.000000000000001e-05,20
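These rows are consistent with the LinearWarmupCosineAnnealingLR entries in hparams.yaml: both AdamW groups warm up linearly from warmup_start_lr (1e-06) to their peak lr and, assuming the warmup spans the first 20 logged steps, step 10 lands exactly halfway. A standalone arithmetic check, not part of the commit:

# linear warmup: lr(t) = start + (peak - start) * t / warmup_steps
start, warmup_steps = 1e-06, 20
for peak in (0.0005, 5.0e-05):   # peak lr of the two optimizer groups
    print(start + (peak - start) * 10 / warmup_steps)
# -> 0.0002505 and 2.55e-05, matching the row logged at step 10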
metrics/__init__.py
ADDED
File without changes
metrics/aor.py
ADDED
@@ -0,0 +1,44 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import numpy as np
from copy import deepcopy
from metrics.iou import sampling_iou
from objects.motions import transform_all_parts

from objects.dict_utils import get_bbox_vertices

'''
This file computes the Average Overlap Ratio (AOR) metric
'''

def AOR(tgt, num_states=10, transform_use_plucker=False):
    tree = tgt["diffuse_tree"]
    states = np.linspace(0, 1, num_states)
    original_bbox_vertices = np.array([get_bbox_vertices(tgt, i) for i in range(len(tgt["diffuse_tree"]))], dtype=np.float32)

    ious = []
    for state_idx, state in enumerate(states):
        ious_per_state = []
        bbox_vertices = deepcopy(original_bbox_vertices)
        part_trans = transform_all_parts(bbox_vertices, tgt, state, transform_use_plucker)
        for node in tree:
            children = node['children']
            num_children = len(children)
            if num_children < 2:
                continue
            for i in range(num_children-1):
                for j in range(i+1, num_children):
                    child_id = children[i]
                    sibling_id = children[j]
                    bbox_v_0 = deepcopy(bbox_vertices[child_id])
                    bbox_v_1 = deepcopy(bbox_vertices[sibling_id])
                    iou = sampling_iou(bbox_v_0, bbox_v_1, part_trans[child_id], part_trans[sibling_id], num_samples=10000)
                    if np.isnan(iou):
                        continue
                    ious_per_state.append(iou)
        if len(ious_per_state) > 0:
            ious.append(np.mean(ious_per_state))
    if len(ious) == 0:
        return -1
    return float(np.mean(ious))
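A minimal sketch of evaluating AOR on one generated object; not part of the commit, and the results path is an assumption based on where inference.py writes object.json:

import json
from metrics.aor import AOR

with open("results/0/object.json") as f:   # assumed output of the inference script
    obj = json.load(f)                      # must contain a 'diffuse_tree' list
score = AOR(obj, num_states=10)             # returns -1 if no sibling pairs exist
print(score)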
metrics/cd.py
ADDED
@@ -0,0 +1,284 @@
"""
This script computes the Chamfer Distance (CD) between two objects\n
"""

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import trimesh
from copy import deepcopy
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from objects.motions import transform_all_parts
from objects.dict_utils import (
    zero_center_object,
    rescale_object,
    compute_overall_bbox_size,
    get_base_part_idx,
    find_part_mapping
)

def _load_and_combine_plys(dir, ply_files, scale=None, z_rotate=None, translate=None):
    """
    Load and combine the ply files into one PyTorch3D mesh

    - dir: the directory of the object from which the ply files are loaded\n
    - ply_files: the list of ply files\n
    - scale: the scale factor to apply to the vertices\n
    - z_rotate: whether to rotate the object around the z-axis by 90 degrees\n
    - translate: the translation to apply to the vertices\n

    Return:\n
    - mesh: one PyTorch3D mesh of the combined ply files
    """

    # Combine the ply files into one
    meshes = []
    for ply_file in ply_files:
        meshes.append(trimesh.load(os.path.join(dir, ply_file), force="mesh"))
    full_part_mesh = trimesh.util.concatenate(meshes)

    # Apply the transformations
    full_part_mesh.vertices -= full_part_mesh.bounding_box.centroid
    transformation = trimesh.transformations.compose_matrix(
        scale=scale,
        angles=[0, 0, np.radians(90) if z_rotate else 0],
        translate=translate,
    )
    full_part_mesh.apply_transform(transformation)

    # Create the PyTorch3D mesh
    mesh = Meshes(
        verts=torch.as_tensor(full_part_mesh.vertices, dtype=torch.float32, device='cuda').unsqueeze(
            0
        ),
        faces=torch.as_tensor(full_part_mesh.faces, dtype=torch.int32, device='cuda').unsqueeze(0),
    )

    return mesh


def _compute_chamfer_distance(
    obj1_part_points, obj2_part_points, part_mapping=None, exclude_id=-1
):
    """
    Compute the chamfer distance between the two sets of points representing the two objects

    - obj1_part_points: the set of points representing the first object\n
    - obj2_part_points: the set of points representing the second object\n
    - part_mapping (optional): the part mapping from the first object to the second object; if provided, the chamfer distance is computed between corresponding parts\n
    - exclude_id (optional): the part id to exclude from the chamfer distance computation, typically the base part id\n

    Return:\n
    - distance: the chamfer distance between the two objects
    """
    if part_mapping is not None:
        n_parts = part_mapping.shape[0]
        distance = 0
        for i in range(n_parts):
            if i == exclude_id:
                continue
            obj1_part_points_i = obj1_part_points[i]
            obj2_part_points_i = obj2_part_points[int(part_mapping[i, 0])]
            with torch.no_grad():
                obj1_part_points_i = obj1_part_points_i.cuda()
                obj2_part_points_i = obj2_part_points_i.cuda()
                # symmetric chamfer distance
                forward_distance, _ = chamfer_distance(
                    obj1_part_points_i[None, :],
                    obj2_part_points_i[None, :],
                    batch_reduction=None,
                )
                backward_distance, _ = chamfer_distance(
                    obj2_part_points_i[None, :],
                    obj1_part_points_i[None, :],
                    batch_reduction=None,
                )
                distance += (forward_distance.item() + backward_distance.item()) * 0.5
        distance /= n_parts
    else:
        # Merge the points of all parts into one tensor
        obj1_part_points = obj1_part_points.reshape(-1, 3)
        obj2_part_points = obj2_part_points.reshape(-1, 3)

        # Compute the chamfer distance between the two objects
        with torch.no_grad():
            obj1_part_points = obj1_part_points.cuda()
            obj2_part_points = obj2_part_points.cuda()
            forward_distance, _ = chamfer_distance(
                obj1_part_points[None, :],
                obj2_part_points[None, :],
                batch_reduction=None,
            )
            backward_distance, _ = chamfer_distance(
                obj2_part_points[None, :],
                obj1_part_points[None, :],
                batch_reduction=None,
            )
            distance = (forward_distance.item() + backward_distance.item()) * 0.5

    return distance


def _get_scores(
    src_dict,
    tgt_dict,
    original_src_part_points,
    original_tgt_part_points,
    part_mapping,
    num_states,
    include_base,
    src_base_idx,
):

    chamfer_distances = np.zeros(num_states, dtype=np.float32)
    joint_states = np.linspace(0, 1, num_states)
    for state_idx, state in enumerate(joint_states):

        # Reset the part point clouds
        src_part_points = deepcopy(original_src_part_points)
        tgt_part_points = deepcopy(original_tgt_part_points)

        # Transform the part point clouds to the current state using the joints
        transform_all_parts(src_part_points.numpy(), src_dict, state, dry_run=False)
        transform_all_parts(tgt_part_points.numpy(), tgt_dict, state, dry_run=False)

        # Compute the chamfer distance between the two objects
        chamfer_distances[state_idx] = _compute_chamfer_distance(
            src_part_points,
            tgt_part_points,
            part_mapping=part_mapping,
            exclude_id=-1 if include_base else src_base_idx,
        )

    # Average over the sampled states (AS) and take the resting state (RS)
    aid_cd = np.mean(chamfer_distances)
    rid_cd = chamfer_distances[0]

    return {
        "AS-CD": float(aid_cd),
        "RS-CD": float(rid_cd),
    }


def CD(
    gen_obj_dict,
    gen_obj_path,
    gt_obj_dict,
    gt_obj_path,
    num_states=5,
    num_samples=2048,
    include_base=False,
):
    """
    Compute the Chamfer Distance\n
    This metric is the average of per-part chamfer distance between the two objects over a number of articulation states\n

    - gen_obj_dict: the generated object dictionary\n
    - gen_obj_path: the directory to the predicted object\n
    - gt_obj_dict: the ground truth object dictionary\n
    - gt_obj_path: the directory to the ground truth object\n
    - num_states (optional): the number of articulation states to compute the metric\n
    - num_samples (optional): the number of samples to use\n
    - include_base (optional): whether to include the base part in the chamfer distance computation\n

    Return:\n
    - aid_score: the score over the sampled articulated states\n
    - rid_score: the score at the resting state\n
    - The score is in the range of [0, inf), lower is better
    """
    # Make copies of the dictionaries to avoid modifying the original dictionaries
    gen_dict = deepcopy(gen_obj_dict)
    gt_dict = deepcopy(gt_obj_dict)

    # Zero center the objects
    zero_center_object(gen_dict)
    zero_center_object(gt_dict)

    # Compute the scale factor by comparing the overall bbox size and scale the candidate object as a whole
    gen_bbox_size = compute_overall_bbox_size(gen_dict)
    gt_bbox_size = compute_overall_bbox_size(gt_dict)
    scale_factor = gen_bbox_size / gt_bbox_size
    rescale_object(gen_dict, scale_factor)  # rescale the working copy, not the caller's dict

    # Record the indices of the base parts of the two objects
    gen_base_idx = get_base_part_idx(gen_dict)
    gt_base_idx = get_base_part_idx(gt_dict)

    # Find mapping between the parts of the two objects based on closest bbox centers
    mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
    mapping_gt2gen = find_part_mapping(gt_dict, gen_dict, use_hungarian=True)

    # Get the number of parts of the two objects
    gen_tree = gen_dict["diffuse_tree"]
    gt_tree = gt_dict["diffuse_tree"]
    gen_num_parts = len(gen_tree)
    gt_num_parts = len(gt_tree)

    # Get the paths of the ply files of the two objects
    gen_part_ply_paths = [
        {"dir": gen_obj_path, "files": gen_tree[i]["plys"]}
        for i in range(gen_num_parts)
    ]
    gt_part_ply_paths = [
        {"dir": gt_obj_path, "files": gt_tree[i]["plys"]}
        for i in range(gt_num_parts)
    ]

    # Load the ply files of the two objects and sample points from them
    gen_part_points = torch.zeros(
        (gen_num_parts, num_samples, 3), dtype=torch.float32
    )
    for i in range(gen_num_parts):
        part_mesh = _load_and_combine_plys(
            gen_part_ply_paths[i]["dir"],
            gen_part_ply_paths[i]["files"],
            scale=scale_factor,
            translate=gen_tree[i]["aabb"]["center"],
        )
        gen_part_points[i] = sample_points_from_meshes(
            part_mesh, num_samples=num_samples
        ).squeeze(0).cpu()

    gt_part_points = torch.zeros(
        (gt_num_parts, num_samples, 3), dtype=torch.float32
    )
    for i in range(gt_num_parts):
        part_mesh = _load_and_combine_plys(
            gt_part_ply_paths[i]["dir"],
            gt_part_ply_paths[i]["files"],
            translate=gt_tree[i]["aabb"]["center"],
        )
        gt_part_points[i] = sample_points_from_meshes(
            part_mesh, num_samples=num_samples
        ).squeeze(0).cpu()

    cd_gen2gt = _get_scores(
        gen_dict,
        gt_dict,
        gen_part_points,
        gt_part_points,
        mapping_gen2gt,
        num_states,
        include_base,
        gen_base_idx,
    )

    cd_gt2gen = _get_scores(
        gt_dict,
        gen_dict,
        gt_part_points,
        gen_part_points,
        mapping_gt2gen,
        num_states,
        include_base,
        gt_base_idx,
    )

    return {
        "AS-CD": (cd_gen2gt["AS-CD"] + cd_gt2gen["AS-CD"]) / 2,
        "RS-CD": (cd_gen2gt["RS-CD"] + cd_gt2gen["RS-CD"]) / 2,
    }
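A minimal usage sketch, not part of the commit; both objects must carry per-part "plys" lists resolvable under their directories, and the paths here are assumptions:

import json
from metrics.cd import CD

with open("results/0/object.json") as f:
    gen = json.load(f)
with open("gt_obj/object.json") as f:        # hypothetical ground-truth layout
    gt = json.load(f)

scores = CD(gen, "results/0", gt, "gt_obj", num_states=5, num_samples=2048)
print(scores["AS-CD"], scores["RS-CD"])      # lower is better for both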
metrics/giou.py
ADDED
@@ -0,0 +1,142 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import numpy as np
from metrics.iou import (
    _sample_points_in_box3d,
    _apply_backward_transformations,
    _apply_forward_transformations,
    _count_points_in_box3d,
)


def giou_aabb(bbox1_vertices, bbox2_vertices):
    """
    Compute the generalized IoU between two axis-aligned bounding boxes\n
    - bbox1_vertices: the vertices of the first bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - bbox2_vertices: the vertices of the second bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n

    Return:\n
    - giou: the gIoU between the two bounding boxes
    """
    volume1 = np.prod(np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0))
    volume2 = np.prod(np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0))

    # Compute the intersection and union of the two bounding boxes
    min_bbox = np.maximum(np.min(bbox1_vertices, axis=0), np.min(bbox2_vertices, axis=0))
    max_bbox = np.minimum(np.max(bbox1_vertices, axis=0), np.max(bbox2_vertices, axis=0))
    intersection = np.prod(np.clip(max_bbox - min_bbox, a_min=0, a_max=None))
    union = volume1 + volume2 - intersection
    # Compute IoU
    iou = intersection / union if union > 0 else 0

    # Compute the smallest enclosing box
    min_enclosing_bbox = np.minimum(np.min(bbox1_vertices, axis=0), np.min(bbox2_vertices, axis=0))
    max_enclosing_bbox = np.maximum(np.max(bbox1_vertices, axis=0), np.max(bbox2_vertices, axis=0))
    volume3 = np.prod(max_enclosing_bbox - min_enclosing_bbox)

    # Compute gIoU
    giou = iou - (volume3 - union) / volume3 if volume3 > 0 else iou

    return giou


def sampling_giou(
    bbox1_vertices,
    bbox2_vertices,
    bbox1_transformations,
    bbox2_transformations,
    num_samples=10000,
):
    """
    Compute the gIoU between two bounding boxes\n
    - bbox1_vertices: the vertices of the first bounding box\n
    - bbox2_vertices: the vertices of the second bounding box\n
    - bbox1_transformations: list of transformations applied to the first bounding box\n
    - bbox2_transformations: list of transformations applied to the second bounding box\n
    - num_samples (optional): the number of samples to use per bounding box\n

    Return:\n
    - giou: the gIoU between the two bounding boxes after applying the transformations
    """
    # if no transformations are applied, use the axis-aligned bounding box gIoU
    if len(bbox1_transformations) == 0 and len(bbox2_transformations) == 0:
        return giou_aabb(bbox1_vertices, bbox2_vertices)

    # Volume of the two bounding boxes
    bbox1_volume = np.prod(
        np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0)
    )
    bbox2_volume = np.prod(
        np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0)
    )
    # Volume of the smallest enclosing box
    min_enclosing_bbox = np.minimum(np.min(bbox1_vertices, axis=0), np.min(bbox2_vertices, axis=0))
    max_enclosing_bbox = np.maximum(np.max(bbox1_vertices, axis=0), np.max(bbox2_vertices, axis=0))
    cbbox_volume = np.prod(max_enclosing_bbox - min_enclosing_bbox)

    # Sample points in the two bounding boxes
    bbox1_points = _sample_points_in_box3d(bbox1_vertices, num_samples)
    bbox2_points = _sample_points_in_box3d(bbox2_vertices, num_samples)

    # Transform the points
    forward_bbox1_points = _apply_forward_transformations(
        bbox1_points, bbox1_transformations
    )
    forward_bbox2_points = _apply_forward_transformations(
        bbox2_points, bbox2_transformations
    )

    # Transform the forward points to the other box's rest pose frame
    forward_bbox1_points_in_rest_bbox2_frame = _apply_backward_transformations(
        forward_bbox1_points, bbox2_transformations
    )
    forward_bbox2_points_in_rest_bbox1_frame = _apply_backward_transformations(
        forward_bbox2_points, bbox1_transformations
    )

    # Count the number of points in the other bounding box
    num_bbox1_points_in_bbox2 = _count_points_in_box3d(
        forward_bbox1_points_in_rest_bbox2_frame, bbox2_vertices
    )
    num_bbox2_points_in_bbox1 = _count_points_in_box3d(
        forward_bbox2_points_in_rest_bbox1_frame, bbox1_vertices
    )

    # Compute the IoU
    intersect = (
        bbox1_volume * num_bbox1_points_in_bbox2
        + bbox2_volume * num_bbox2_points_in_bbox1
    ) / 2
    union = bbox1_volume * num_samples + bbox2_volume * num_samples - intersect
    iou = intersect / union

    giou = iou - (cbbox_volume * num_samples - union) / (cbbox_volume * num_samples) if cbbox_volume > 0 else iou

    return giou


def sampling_cDist(
    part1,
    part2,
    bbox1_transformations,
    bbox2_transformations,
):
    '''
    Compute the centroid distance between two bounding boxes\n
    - part1: the first part dict, whose 'aabb' entry provides the bbox center\n
    - part2: the second part dict, whose 'aabb' entry provides the bbox center\n
    - bbox1_transformations: list of transformations applied to the first bounding box\n
    - bbox2_transformations: list of transformations applied to the second bounding box\n
    '''

    bbox1_centroid = np.array(part1['aabb']['center'], dtype=np.float32).reshape(1, 3)
    bbox2_centroid = np.array(part2['aabb']['center'], dtype=np.float32).reshape(1, 3)

    # Transform the centroids
    bbox1_transformed_centroids = _apply_forward_transformations(bbox1_centroid, bbox1_transformations)
    bbox2_transformed_centroids = _apply_forward_transformations(bbox2_centroid, bbox2_transformations)

    # Compute the centroid distance
    cDist = np.linalg.norm(bbox1_transformed_centroids - bbox2_transformed_centroids)

    return cDist
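To make the gIoU computation above concrete: for two unit boxes overlapping by half, the intersection is 0.5 and the union 1.5, so IoU = 1/3; the smallest enclosing box has volume 1.5, which equals the union, so the penalty term vanishes and gIoU equals the IoU. A standalone check, not part of the commit:

import numpy as np
from metrics.giou import giou_aabb

box1 = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
box2 = np.array([[0.5, 0.0, 0.0], [1.5, 1.0, 1.0]])
print(giou_aabb(box1, box2))   # ~0.3333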
metrics/iou.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np


def _sample_points_in_box3d(bbox_vertices, num_samples):
    """
    Sample points in an axis-aligned 3D bounding box\n
    - bbox_vertices: the vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - num_samples: the number of samples to use\n

    Return:\n
    - points: the sampled points in the form: [[x0, y0, z0], [x1, y1, z1], ...]
    """

    # Compute the bounding box size
    bbox_size = np.max(bbox_vertices, axis=0) - np.min(bbox_vertices, axis=0)

    # Sample points uniformly in the bounding box
    points = np.random.rand(num_samples, 3) * bbox_size + np.min(bbox_vertices, axis=0)

    return points


def _apply_forward_transformations(points, transformations):
    """
    Apply forward transformations to the points\n
    - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - transformations: list of transformations to apply\n

    Return:\n
    - points_transformed: the transformed points in the form: [[x0, y0, z0], [x1, y1, z1], ...]
    """
    if len(transformations) == 0:
        return points

    # To homogeneous coordinates
    points_transformed = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)

    # Apply the transformations one by one in order
    for transformation in transformations:
        if transformation["type"] == "translation":
            points_transformed = np.matmul(
                transformation["matrix"], points_transformed.T
            ).T

        elif transformation["type"] == "rotation":
            # w-component is 0 so the homogeneous coordinate is untouched by the shift
            axis_origin = np.append(transformation["rotation_axis_origin"], 0)
            points_recentered = points_transformed - axis_origin

            points_rotated = np.matmul(transformation["matrix"], points_recentered.T).T
            points_transformed = points_rotated + axis_origin

        elif transformation["type"] == "plucker":
            points_transformed = np.matmul(
                transformation["matrix"], points_transformed.T
            ).T

        else:
            raise ValueError(f"Unknown transformation type: {transformation['type']}")

    return points_transformed[..., :3]


def _apply_backward_transformations(points, transformations):
    """
    Apply backward transformations to the points\n
    - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - transformations: list of transformations to apply\n
    - The inverses of the transformations are applied in reverse order\n

    Return:\n
    - points_transformed: the transformed points in the form: [[x0, y0, z0], [x1, y1, z1], ...]

    Reference: https://mathematica.stackexchange.com/questions/106257/how-do-i-get-the-inverse-of-a-homogeneous-transformation-matrix
    """
    if len(transformations) == 0:
        return points

    # To homogeneous coordinates
    points_transformed = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)

    # Apply the inverse transformations one by one in reverse order
    for transformation in transformations[::-1]:
        # Closed-form inverse of a homogeneous transform: R^-1 = R^T, t^-1 = -R^T t
        inv_transformation = np.eye(4)
        inv_transformation[:3, :3] = transformation["matrix"][:3, :3].T
        inv_transformation[:3, 3] = -np.matmul(
            transformation["matrix"][:3, :3].T, transformation["matrix"][:3, 3]
        )

        if transformation["type"] == "translation":
            points_transformed = np.matmul(inv_transformation, points_transformed.T).T

        elif transformation["type"] == "rotation":
            axis_origin = np.append(transformation["rotation_axis_origin"], 0)
            points_recentered = points_transformed - axis_origin

            points_rotated = np.matmul(inv_transformation, points_recentered.T).T
            points_transformed = points_rotated + axis_origin

        elif transformation["type"] == "plucker":
            points_transformed = np.matmul(inv_transformation, points_transformed.T).T

        else:
            raise ValueError(f"Unknown transformation type: {transformation['type']}")

    return points_transformed[..., :3]


def _count_points_in_box3d(points, bbox_vertices):
    """
    Count the number of points in a 3D bounding box\n
    - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - bbox_vertices: the vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - The bbox is assumed to be axis-aligned\n

    Return:\n
    - num_points_in_bbox: the number of points in the bounding box
    """

    # Count the points that fall inside the bbox along all three axes
    num_points_in_bbox = np.sum(
        np.all(points >= np.min(bbox_vertices, axis=0), axis=1)
        & np.all(points <= np.max(bbox_vertices, axis=0), axis=1)
    )

    return num_points_in_bbox


def iou_aabb(bbox1_vertices, bbox2_vertices):
    """
    Compute the IoU between two axis-aligned bounding boxes\n
    - bbox1_vertices: the vertices of the first bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
    - bbox2_vertices: the vertices of the second bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n

    Return:\n
    - iou: the IoU between the two bounding boxes
    """

    # Compute the intersection and union of the two bounding boxes
    min_bbox = np.maximum(np.min(bbox1_vertices, axis=0), np.min(bbox2_vertices, axis=0))
    max_bbox = np.minimum(np.max(bbox1_vertices, axis=0), np.max(bbox2_vertices, axis=0))
    intersection = np.prod(np.clip(max_bbox - min_bbox, a_min=0, a_max=None))
    union = (
        np.prod(np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0))
        + np.prod(np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0))
        - intersection
    )

    # Compute the IoU
    iou = intersection / union if union > 0 else 0

    return iou


def sampling_iou(
    bbox1_vertices,
    bbox2_vertices,
    bbox1_transformations,
    bbox2_transformations,
    num_samples=10000,
):
    """
    Compute the IoU between two bounding boxes\n
    - bbox1_vertices: the vertices of the first bounding box\n
    - bbox2_vertices: the vertices of the second bounding box\n
    - bbox1_transformations: list of transformations applied to the first bounding box\n
    - bbox2_transformations: list of transformations applied to the second bounding box\n
    - num_samples (optional): the number of samples to use per bounding box\n

    Return:\n
    - iou: the IoU between the two bounding boxes after applying the transformations
    """
    # If no transformations are applied, use the axis-aligned bounding box IoU
    if len(bbox1_transformations) == 0 and len(bbox2_transformations) == 0:
        return iou_aabb(bbox1_vertices, bbox2_vertices)

    # Volume of the two bounding boxes
    bbox1_volume = np.prod(
        np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0)
    )
    bbox2_volume = np.prod(
        np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0)
    )

    # Sample points in the two bounding boxes
    bbox1_points = _sample_points_in_box3d(bbox1_vertices, num_samples)
    bbox2_points = _sample_points_in_box3d(bbox2_vertices, num_samples)

    # Transform the points
    forward_bbox1_points = _apply_forward_transformations(
        bbox1_points, bbox1_transformations
    )
    forward_bbox2_points = _apply_forward_transformations(
        bbox2_points, bbox2_transformations
    )

    # Transform the forward points to the other box's rest-pose frame
    forward_bbox1_points_in_rest_bbox2_frame = _apply_backward_transformations(
        forward_bbox1_points, bbox2_transformations
    )
    forward_bbox2_points_in_rest_bbox1_frame = _apply_backward_transformations(
        forward_bbox2_points, bbox1_transformations
    )

    # Count the number of points falling inside the other bounding box
    num_bbox1_points_in_bbox2 = _count_points_in_box3d(
        forward_bbox1_points_in_rest_bbox2_frame, bbox2_vertices
    )
    num_bbox2_points_in_bbox1 = _count_points_in_box3d(
        forward_bbox2_points_in_rest_bbox1_frame, bbox1_vertices
    )

    # Monte Carlo estimate of the IoU: average the two intersection estimates
    intersect = (
        bbox1_volume * num_bbox1_points_in_bbox2
        + bbox2_volume * num_bbox2_points_in_bbox1
    ) / 2
    union = bbox1_volume * num_samples + bbox2_volume * num_samples - intersect
    iou = intersect / union

    return iou
metrics/iou_cdist.py
ADDED
@@ -0,0 +1,227 @@
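A minimal usage sketch of the symmetric metric defined below. The file paths are hypothetical; the dicts are the part-graph JSONs (with a "diffuse_tree" list of parts) used throughout this repo:

import json
from metrics.iou_cdist import IoU_cDist

with open("gen_object.json") as f:   # hypothetical path
    gen = json.load(f)
with open("gt_object.json") as f:    # hypothetical path
    gt = json.load(f)

scores = IoU_cDist(gen, gt, num_states=5)
print(scores)  # {'AS-IoU': ..., 'AS-cDist': ..., 'RS-IoU': ..., 'RS-cDist': ...}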
"""
This file computes the IoU-based and centroid-distance-based metrics in a symmetric manner\n
"""

import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import numpy as np
from copy import deepcopy
from objects.dict_utils import (
    get_base_part_idx,
    get_bbox_vertices,
    remove_handles,
    compute_overall_bbox_size,
    rescale_object,
    find_part_mapping,
    zero_center_object,
)
from objects.motions import transform_all_parts
from metrics.giou import sampling_giou, sampling_cDist


def _get_scores(
    src_dict,
    tgt_dict,
    original_src_bbox_vertices,
    original_tgt_bbox_vertices,
    mapping,
    num_states,
    rotation_fix_range,
    num_samples,
    iou_include_base,
):
    # Record the index of the base part of the src object
    src_base_idx = get_base_part_idx(src_dict)

    # Compute the sum of IoU between the generated object and the candidate object over a number of articulation states
    num_parts_in_src = len(src_dict["diffuse_tree"])
    iou_per_part_and_state = np.zeros((num_parts_in_src, num_states), dtype=np.float32)
    cDist_per_part_and_state = np.zeros(
        (num_parts_in_src, num_states), dtype=np.float32
    )

    states = np.linspace(0, 1, num_states)
    for state_idx, state in enumerate(states):

        # Get a fresh copy of the bounding box vertices in rest pose
        src_bbox_vertices = deepcopy(original_src_bbox_vertices)
        tgt_bbox_vertices = deepcopy(original_tgt_bbox_vertices)

        # Transform the objects to the current state using the joints
        src_part_transformations = transform_all_parts(
            src_bbox_vertices,
            src_dict,
            state,
            rotation_fix_range=rotation_fix_range,
        )
        tgt_part_transformations = transform_all_parts(
            tgt_bbox_vertices,
            tgt_dict,
            state,
            rotation_fix_range=rotation_fix_range,
        )

        # Compute the IoU between the two objects using the transformed bounding boxes and the part mapping
        for src_part_idx in range(num_parts_in_src):

            # Get the index of the corresponding part in the candidate object
            tgt_part_idx = int(mapping[src_part_idx, 0])

            # Always use a fresh copy of the bounding box vertices in rest pose in case dry_run=False is incorrectly set
            src_part_bbox_vertices = deepcopy(original_src_bbox_vertices)[src_part_idx]
            tgt_part_bbox_vertices = deepcopy(original_tgt_bbox_vertices)[tgt_part_idx]

            # Compute the sampling-based IoU between the two parts
            iou_per_part_and_state[src_part_idx, state_idx] = sampling_giou(
                src_part_bbox_vertices,
                tgt_part_bbox_vertices,
                src_part_transformations[src_part_idx],
                tgt_part_transformations[tgt_part_idx],
                num_samples=num_samples,
            )
            # Compute the centroid distance between the two matched parts
            cDist_per_part_and_state[src_part_idx, state_idx] = sampling_cDist(
                src_dict["diffuse_tree"][src_part_idx],
                tgt_dict["diffuse_tree"][tgt_part_idx],
                src_part_transformations[src_part_idx],
                tgt_part_transformations[tgt_part_idx],
            )

    # IoU and cDist at the resting state
    per_part_iou_avg_at_rest = iou_per_part_and_state[:, 0]
    per_part_cDist_avg_at_rest = cDist_per_part_and_state[:, 0]

    # Average the IoU over the states
    per_part_iou_avg_over_states = np.sum(iou_per_part_and_state, axis=1) / num_states
    # Average the cDist over the states
    per_part_cDist_avg_over_states = (
        np.sum(cDist_per_part_and_state, axis=1) / num_states
    )

    # Remove the base part if specified
    if not iou_include_base:
        per_part_iou_avg_over_states = np.delete(
            per_part_iou_avg_over_states, src_base_idx
        )
        per_part_iou_avg_at_rest = np.delete(per_part_iou_avg_at_rest, src_base_idx)
        per_part_cDist_avg_over_states = np.delete(
            per_part_cDist_avg_over_states, src_base_idx
        )
        per_part_cDist_avg_at_rest = np.delete(per_part_cDist_avg_at_rest, src_base_idx)

    aid_iou = float(np.mean(per_part_iou_avg_over_states)) if len(per_part_iou_avg_over_states) > 0 else 0
    aid_cdist = float(np.mean(per_part_cDist_avg_over_states)) if len(per_part_cDist_avg_over_states) > 0 else 1
    rid_iou = float(np.mean(per_part_iou_avg_at_rest)) if len(per_part_iou_avg_at_rest) > 0 else 0
    rid_cdist = float(np.mean(per_part_cDist_avg_at_rest)) if len(per_part_cDist_avg_at_rest) > 0 else 1

    return {
        "AS-IoU": 1. - aid_iou,
        "AS-cDist": aid_cdist,
        "RS-IoU": 1. - rid_iou,
        "RS-cDist": rid_cdist
    }


def IoU_cDist(
    gen_obj_dict,
    gt_obj_dict,
    num_states=2,
    compare_handles=False,
    iou_include_base=False,
    rotation_fix_range=True,
    num_samples=10000,
):
    """
    Compute the IoU-based and centroid-distance-based metrics\n
    This metric is the average sum of IoU between parts in the two objects over the sampled articulation states and at the resting state\n

    - gen_obj_dict: the dictionary of the generated object\n
    - gt_obj_dict: the dictionary of the gt object\n
    - num_states: the number of articulation states to compute the metric\n
    - compare_handles (optional): whether to compare the handles\n
    - iou_include_base (optional): whether to include the base part in the IoU computation\n
    - rotation_fix_range (optional): whether to fix the rotation range to 90 degrees for revolute joints\n
    - num_samples (optional): the number of samples to use\n

    Return:\n
    - scores: a dictionary of the computed scores\n
        - "AS-IoU": the average IoU over the articulation states\n
        - "AS-cDist": the average centroid distance over the articulation states\n
        - "RS-IoU": the average IoU at the resting state\n
        - "RS-cDist": the average centroid distance at the resting state\n
    """
    # Make copies of the dictionaries to avoid modifying the original dictionaries
    gen_dict = deepcopy(gen_obj_dict)
    gt_dict = deepcopy(gt_obj_dict)

    # Strip the handles from the objects if not comparing them
    if not compare_handles:
        gen_dict = remove_handles(gen_dict)
        gt_dict = remove_handles(gt_dict)

    # Zero-center the objects
    zero_center_object(gen_dict)
    zero_center_object(gt_dict)

    # Scale the generated object as a whole to match the size of the gt object
    gen_bbox_size = compute_overall_bbox_size(gen_dict)
    gt_bbox_size = compute_overall_bbox_size(gt_dict)
    scale_factor = gt_bbox_size / gen_bbox_size
    rescale_object(gen_dict, scale_factor)

    # Hungarian matching of parts, in both directions for the symmetric metric
    mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
    # for i in range(mapping_gen2gt.shape[0]):
    #     if mapping_gen2gt[i][0] < 100:
    #         gen_dict['diffuse_tree'][i]["parent"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["parent"]
    #         gen_dict['diffuse_tree'][i]["children"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["children"]
    #         gen_dict['diffuse_tree'][i]["id"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["id"]
    #     mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
    mapping_gt2gen = find_part_mapping(gt_dict, gen_dict, use_hungarian=True)

    # Save the original bounding box vertices in rest pose
    original_gen_bbox_vertices = np.array(
        [get_bbox_vertices(gen_dict, i) for i in range(len(gen_dict["diffuse_tree"]))],
        dtype=np.float32,
    )
    original_gt_bbox_vertices = np.array(
        [get_bbox_vertices(gt_dict, i) for i in range(len(gt_dict["diffuse_tree"]))],
        dtype=np.float32,
    )

    scores_gen2gt = _get_scores(
        gen_dict,
        gt_dict,
        original_gen_bbox_vertices,
        original_gt_bbox_vertices,
        mapping_gen2gt,
        num_states,
        rotation_fix_range,
        num_samples,
        iou_include_base,
    )

    scores_gt2gen = _get_scores(
        gt_dict,
        gen_dict,
        original_gt_bbox_vertices,
        original_gen_bbox_vertices,
        mapping_gt2gen,
        num_states,
        rotation_fix_range,
        num_samples,
        iou_include_base,
    )

    # Symmetrize by averaging the two directions
    scores = {
        "AS-IoU": (scores_gen2gt["AS-IoU"] + scores_gt2gen["AS-IoU"]) / 2,
        "AS-cDist": (scores_gen2gt["AS-cDist"] + scores_gt2gen["AS-cDist"]) / 2,
        "RS-IoU": (scores_gen2gt["RS-IoU"] + scores_gt2gen["RS-IoU"]) / 2,
        "RS-cDist": (scores_gen2gt["RS-cDist"] + scores_gt2gen["RS-cDist"]) / 2,
    }

    return scores
models/__init__.py
ADDED
@@ -0,0 +1,19 @@
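The registry below decouples model definitions from their construction site. A hypothetical registration and lookup (ToyModel is illustrative, not a repo model):

import models

@models.register("toy")
class ToyModel:
    def __init__(self, config):
        self.config = config

net = models.make("toy", {"lr": 1e-4})
assert isinstance(net, ToyModel)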
models = {}


def register(name):
    def decorator(cls):
        models[name] = cls
        return cls

    return decorator


def make(name, config):
    if name == 'model_B9':
        name = 'denoiser_singapo'
    model = models[name](config)
    return model


from . import denoiser
models/denoiser.py
ADDED
@@ -0,0 +1,415 @@
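A minimal instantiation sketch of the denoiser defined below. The hyperparameter values are illustrative rather than taken from configs/config.yaml; attn_dim must be divisible by n_head, and img_emb_dims must end at attn_dim so the per-layer image adapters project patch features into the attention width:

from omegaconf import OmegaConf
from models.denoiser import Denoiser

hparams = OmegaConf.create({
    "K": 32,           # max number of parts
    "in_ch": 6,        # each of the 5 attributes is a 6-dim vector
    "attn_dim": 512,
    "n_head": 8,
    "n_layers": 4,
    "dropout": 0.1,
    "mode_num": 5,     # aabb, jtype, jaxis, range, label
    "img_emb_dims": [768, 512],  # image-backbone patch dim -> attn_dim
})
net = Denoiser(hparams)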
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import models
from torch import nn
from diffusers.models.attention import Attention, FeedForward
from models.utils import (
    PEmbeder,
    FinalLayer,
    VisAttnProcessor,
    MyAdaLayerNormZero
)

class RAPCrossAttnBlock(nn.Module):
    def __init__(self, dim, num_layers, num_heads, head_dim, dropout=0.0, img_emb_dims=None):
        super().__init__()
        self.layers = nn.ModuleList([
            Attention(
                query_dim=dim,
                cross_attention_dim=dim,
                heads=num_heads,
                dim_head=head_dim,
                dropout=dropout,
                bias=True,
                cross_attention_norm="layer_norm",
                processor=VisAttnProcessor(),
            )
            for _ in range(num_layers)
        ])
        self.norms = nn.ModuleList([
            nn.LayerNorm(dim) for _ in range(num_layers)
        ])

        # image embedding layers (MLP adapter: patch dim -> attention dim)
        img_emb_layers = []
        for i in range(len(img_emb_dims) - 1):
            img_emb_layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            img_emb_layers.append(nn.LeakyReLU(inplace=True))
        img_emb_layers.pop(-1)  # no activation after the last linear layer
        self.img_emb = nn.Sequential(*img_emb_layers)
        self.init_img_emb_weights()

    def init_img_emb_weights(self):
        for m in self.img_emb.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, img_first, img_second):
        """
        Inputs:
            img_first: (B, Np, D)
            img_second: (B, Np, D)
        Output:
            fused_feat: (B, Np, D)
        """
        img_first = self.img_emb(img_first)
        img_second = self.img_emb(img_second)
        fused = img_second
        for norm, attn in zip(self.norms, self.layers):
            normed = norm(fused)
            delta, _ = attn(normed, encoder_hidden_states=img_first, attention_mask=None)
            fused = fused + delta  # residual connection
        return fused


class Attn_Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: int = None,
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        final_dropout: bool = False,
        class_dropout_prob: float = 0.0,  # for classifier-free guidance
        img_emb_dims=None,
    ):
        super().__init__()

        self.norm1 = MyAdaLayerNormZero(dim, num_embeds_ada_norm, class_dropout_prob)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm4 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm5 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm6 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.local_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )

        self.global_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )

        self.graph_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )

        self.img_attn = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_norm="layer_norm",
            processor=VisAttnProcessor(),  # to be removed for release model
        )

        self.img_attn_second = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_norm="layer_norm",
            processor=VisAttnProcessor(),  # to be removed for release model
        )

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # image embedding layers
        layers = []
        for i in range(len(img_emb_dims) - 1):
            layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            layers.append(nn.LeakyReLU(inplace=True))
        layers.pop(-1)
        self.img_emb = nn.Sequential(*layers)
        self.init_img_emb_weights()

    def init_img_emb_weights(self):
        for m in self.img_emb.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self,
        hidden_states,
        img_patches,
        fuse_feat,
        pad_mask,
        attr_mask,
        graph_mask,
        timestep,
        class_labels,
        label_free=False,
    ):
        # image patches embedding
        img_emb = self.img_emb(img_patches)

        # adaptive normalization, taking timestep and class_labels as the input condition
        norm_hidden_states, gate_1, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = (
            self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype,
                label_free=label_free
            )
        )

        # local attribute self-attention
        attr_out = self.local_attn(norm_hidden_states, attention_mask=attr_mask)
        attr_out = gate_1.unsqueeze(1) * attr_out
        hidden_states = hidden_states + attr_out

        # global attribute self-attention
        norm_hidden_states = self.norm2(hidden_states)
        global_out = self.global_attn(norm_hidden_states, attention_mask=pad_mask)
        global_out = gate_2.unsqueeze(1) * global_out
        hidden_states = hidden_states + global_out

        # graph relation self-attention
        norm_hidden_states = self.norm3(hidden_states)
        graph_out = self.graph_attn(norm_hidden_states, attention_mask=graph_mask)
        graph_out = gate_3.unsqueeze(1) * graph_out
        hidden_states = hidden_states + graph_out

        img_first, img_second = img_emb.chunk(2, dim=1)

        # cross attention with image patches
        norm_hidden_states = self.norm4(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_first.shape[1]  # number of image patches
        mode_num = Na // 32
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        bboxes = reshaped[:, :, 0, :]  # (B, K, D)
        # cross attention between bbox attributes and image patches
        bbox_img_out, bbox_cross_attn_map = self.img_attn(
            bboxes,
            encoder_hidden_states=img_first,
            attention_mask=None,
        )  # cross_attn_map: (B, n_head, K, Np)

        # to reshape the cross_attn_map back to (B, n_head, Na, Np), redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 0, :] = bbox_cross_attn_map
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)

        # assemble the output: bbox slots take the cross-attention output, other attributes pass through
        img_out = torch.empty(size=(B, Na // mode_num, mode_num, D), device=hidden_states.device, dtype=hidden_states.dtype)
        img_out[:, :, 0, :] = bbox_img_out
        img_out[:, :, 1:, :] = reshaped[:, :, 1:, :]
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out

        norm_hidden_states = self.norm6(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_second.shape[1]  # number of image patches
        mode_num = Na // 32
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        joints = reshaped  # (B, K, mode_num, D)
        joints = joints.reshape(B, Na // mode_num * 5, D)
        # cross attention between all part attributes and the fused image features
        joint_img_out, bbox_cross_attn_map = self.img_attn_second(
            joints,
            encoder_hidden_states=fuse_feat,
            attention_mask=None,
        )  # cross_attn_map: (B, n_head, K*5, Np)

        # to reshape the cross_attn_map back to (B, n_head, Na, Np), redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 1:5, :] = bbox_cross_attn_map.reshape(
        #     B, bbox_cross_attn_map.shape[1], Na // mode_num, 4, Np
        # )
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)

        # assemble the output of the cross attention over all attributes
        img_out = joint_img_out.reshape(B, Na // mode_num, 5, D)
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out

        # feed-forward
        norm_hidden_states = self.norm5(hidden_states)
        norm_hidden_states = (
            norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        )
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states

        return hidden_states


@models.register("denoiser")
class Denoiser(nn.Module):
    """
    Denoiser based on CAGE's attribute attention block + our ICA module, with 4 sequential attentions: LA -> GA -> GRA -> ICA
    Different image adapters for each layer.
    The image cross attention is with key-padding masks (object mask, part mask)
    *** The ICA only applies to the bbox attributes, not other attributes such as motion params. ***
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.K = self.hparams.get("K", 32)

        in_ch = hparams.in_ch
        attn_dim = hparams.attn_dim
        mid_dim = attn_dim // 2
        n_head = hparams.n_head
        head_dim = attn_dim // n_head
        num_embeds_ada_norm = 6 * attn_dim

        # embedding layers for the different node attributes
        self.aabb_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.jaxis_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.range_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.label_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.jtype_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        # self.node_type_emb = nn.Sequential(
        #     nn.Linear(in_ch, mid_dim),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(mid_dim, attn_dim),
        # )
        # positional encoding for nodes and attributes
        self.pe_node = PEmbeder(self.K, attn_dim)
        self.pe_attr = PEmbeder(self.hparams.mode_num, attn_dim)

        # attention layers
        self.attn_layers = nn.ModuleList(
            [
                Attn_Block(
                    dim=attn_dim,
                    num_attention_heads=n_head,
                    attention_head_dim=head_dim,
                    class_dropout_prob=hparams.get("cat_drop_prob", 0.0),
                    dropout=hparams.dropout,
                    activation_fn="geglu",
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=False,
                    norm_elementwise_affine=True,
                    final_dropout=False,
                    img_emb_dims=hparams.get("img_emb_dims", None),
                )
                for d in range(hparams.n_layers)
            ]
        )

        self.image_interaction = RAPCrossAttnBlock(
            dim=attn_dim,
            num_layers=6,
            num_heads=n_head,
            head_dim=head_dim,
            dropout=hparams.dropout,
            img_emb_dims=hparams.get("img_emb_dims", None),
        )

        self.final_layer = FinalLayer(attn_dim, in_ch)

    def forward(
        self,
        x,
        cat,
        timesteps,
        feat,
        key_pad_mask=None,
        graph_mask=None,
        attr_mask=None,
        label_free=False,
    ):
        B = x.shape[0]
        x = x.view(B, self.K, 5 * 6)

        # embedding layers for the different attributes
        x_aabb = self.aabb_emb(x[..., :6])
        x_jtype = self.jtype_emb(x[..., 6:12])
        x_jaxis = self.jaxis_emb(x[..., 12:18])
        x_range = self.range_emb(x[..., 18:24])
        x_label = self.label_emb(x[..., 24:30])
        # x_node_type = self.node_type_emb(x[..., 30:36])

        # concatenate all attribute embeddings
        x_ = torch.cat(
            [x_aabb, x_jtype, x_jaxis, x_range, x_label], dim=2
        )  # (B, K, 5*attn_dim)
        x_ = x_.view(B, self.K * self.hparams.mode_num, self.hparams.attn_dim)

        # positional encoding for nodes and attributes
        idx_attr = torch.tensor(
            [0, 1, 2, 3, 4], device=x.device, dtype=torch.long
        ).repeat(self.K)
        idx_node = torch.arange(
            self.K, device=x.device, dtype=torch.long
        ).repeat_interleave(self.hparams.mode_num)
        x_ = self.pe_attr(self.pe_node(x_, idx=idx_node), idx=idx_attr)

        # fuse the two input image features
        Np = feat.shape[1]  # total number of image patches across the two views
        img_first, img_second = feat.chunk(2, dim=1)
        fused_img_feat = self.image_interaction(img_first, img_second)  # (B, Np, D)

        # attention layers
        for i, attn_layer in enumerate(self.attn_layers):
            x_ = attn_layer(
                hidden_states=x_,
                img_patches=feat,
                fuse_feat=fused_img_feat,
                timestep=timesteps,
                class_labels=cat,
                pad_mask=key_pad_mask,
                graph_mask=graph_mask,
                attr_mask=attr_mask,
                label_free=label_free,
            )

        y = self.final_layer(x_, timesteps, cat)
        return {
            'noise_pred': y,
            'attn_maps': None,
        }
models/utils.py
ADDED
@@ -0,0 +1,199 @@
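A quick shape check for the PEmbeder defined below, which adds one learned embedding per index and broadcasts over the batch (the sizes are illustrative):

import torch
from models.utils import PEmbeder

pe = PEmbeder(vocab_size=32, d_model=512)  # one embedding per node slot
x = torch.randn(2, 32, 512)                # (B, K, D)
print(pe(x).shape)                         # torch.Size([2, 32, 512])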
import torch
from torch import nn
from typing import Optional
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding

class FinalLayer(nn.Module):
    """
    Final layer of the diffusion model that outputs the final logits.
    """
    def __init__(self, in_ch, out_ch=None, dropout=0.0):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        self.linear = nn.Linear(in_ch, out_ch)
        self.norm = AdaLayerNormTC(in_ch, 2 * in_ch, dropout)

    def forward(self, x, t, cond=None):
        assert cond is not None
        x = self.norm(x, t, cond)
        x = self.linear(x)
        return x


class AdaLayerNormTC(nn.Module):
    """
    Norm layer modified to incorporate timestep and condition embeddings.
    """

    def __init__(self, embedding_dim, num_embeddings, dropout):
        super().__init__()
        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, dropout
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(
            embedding_dim, elementwise_affine=False, eps=torch.finfo(torch.float16).eps
        )

    def forward(self, x, timestep, cond):
        emb = self.linear(self.silu(self.emb(timestep, cond, hidden_dtype=None)))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class PEmbeder(nn.Module):
    """
    Positional embedding layer.
    """
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")

    def forward(self, x, idx=None):
        if idx is None:
            idx = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
        return x + self.embed(idx)


class CombinedTimestepLabelEmbeddings(nn.Module):
    '''Modified from diffusers.models.embeddings.CombinedTimestepLabelEmbeddings'''
    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None, label_free=False):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        force_drop_ids = None  # training mode
        if label_free:  # inference mode: force_drop_ids is set to all ones so the labels are dropped in class_embedder
            force_drop_ids = torch.ones_like(class_labels, dtype=torch.bool, device=class_labels.device)
        class_labels = self.class_embedder(class_labels, force_drop_ids)  # (N, D)
        conditioning = timesteps_emb + class_labels  # (N, D)
        return conditioning


class MyAdaLayerNormZero(nn.Module):
    """
    Adaptive layer norm zero (adaLN-Zero), borrowed from diffusers.models.attention.AdaLayerNormZero.
    Extended to produce extra scale parameters (gate_2, gate_3) for the intermediate attention layers.
    """

    def __init__(self, embedding_dim, num_embeddings, class_dropout_prob):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(
            num_embeddings, embedding_dim, class_dropout_prob
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 8 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, class_labels, hidden_dtype=None, label_free=False):
        emb_t_cls = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype, label_free=label_free)
        emb = self.linear(self.silu(emb_t_cls))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            gate_2,
            gate_3,
        ) = emb.chunk(8, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3


class VisAttnProcessor:
    r"""
    This code is adapted from diffusers.models.attention_processor.AttnProcessor.
    Used for visualizing the attention maps when testing, NOT for training.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # (the `scale` deprecation check from the original diffusers processor was removed here)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)  # (B*n_head, Nq, head_dim)
        key = attn.head_to_batch_dim(key)      # (B*n_head, Nk, head_dim)
        value = attn.head_to_batch_dim(value)  # (B*n_head, Nk, head_dim)

        if attention_mask is not None:
            if attention_mask.dtype == torch.bool:
                # convert a boolean mask into an additive float mask
                attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype, device=query.device)
                attn_mask = attn_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
            else:
                attn_mask = attention_mask
                assert attn_mask.dtype == query.dtype, f"query and attention_mask must have the same dtype, but got {query.dtype} and {attention_mask.dtype}."
        else:
            attn_mask = None
        attention_probs = attn.get_attention_scores(query, key, attn_mask)  # (B*n_head, Nq, Nk)
        hidden_states = torch.bmm(attention_probs, value)                   # (B*n_head, Nq, head_dim)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        attention_probs = attention_probs.reshape(batch_size, attn.heads, query.shape[1], sequence_length)

        return hidden_states, attention_probs
my_utils/__init__.py
ADDED
File without changes
my_utils/callbacks.py
ADDED
@@ -0,0 +1,36 @@
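A hypothetical wiring of the two callbacks defined below into a Lightning Trainer; it assumes the LightningModule exposes hparams.exp_dir, which ConfigSnapshotCallback reads in setup():

import lightning as L
from my_utils.misc import load_config
from my_utils.callbacks import ConfigSnapshotCallback, GPUCacheCleanCallback

config = load_config("configs/config.yaml")
trainer = L.Trainer(
    callbacks=[ConfigSnapshotCallback(config), GPUCacheCleanCallback()],
)
# the parsed config is written to <exp_dir>/config/parsed.yaml at fit start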
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
from my_utils.misc import dump_config
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_only

class ConfigSnapshotCallback(Callback):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, trainer, pl_module, stage) -> None:
        self.savedir = os.path.join(pl_module.hparams.exp_dir, 'config')

    @rank_zero_only
    def save_config_snapshot(self):
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config)

    def on_fit_start(self, trainer, pl_module):
        self.save_config_snapshot()


class GPUCacheCleanCallback(Callback):
    def on_train_batch_start(self, *args, **kwargs):
        torch.cuda.empty_cache()

    def on_validation_batch_start(self, *args, **kwargs):
        torch.cuda.empty_cache()

    def on_test_batch_start(self, *args, **kwargs):
        torch.cuda.empty_cache()

    def on_predict_batch_start(self, *args, **kwargs):
        torch.cuda.empty_cache()
my_utils/lr_schedulers.py
ADDED
@@ -0,0 +1,104 @@
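A short usage sketch of the scheduler defined below, stepped per iteration as its docstring recommends (the optimizer and step counts are illustrative):

import torch
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-3)
sched = LinearWarmupCosineAnnealingLR(opt, warmup_epochs=10, max_epochs=100)
for _ in range(100):
    opt.step()    # training step elided
    sched.step()  # lr ramps 0 -> 1e-3 over 10 steps, then cosine-decays toward eta_min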
"""
Code copied from lightning-bolts
"""
import math
import warnings
from typing import List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and
    base_lr followed by a cosine annealing schedule between base_lr and eta_min.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.

    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
        It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
        :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
        epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
        train and validation methods.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler."""
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr in self.base_lrs
            ]

        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]
my_utils/misc.py
ADDED
@@ -0,0 +1,35 @@
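The resolvers registered below allow arithmetic inside YAML config values. A small sketch (the keys are illustrative; importing my_utils.misc first is what registers the resolvers):

from omegaconf import OmegaConf
import my_utils.misc  # noqa: F401 -- registers the add/sub/mul/div resolvers

conf = OmegaConf.create({"attn_dim": 512, "head_dim": "${div:${attn_dim},8}"})
OmegaConf.resolve(conf)
print(conf.head_dim)  # 64.0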
from omegaconf import OmegaConf


# ============ Register OmegaConf Resolvers ============= #
OmegaConf.register_new_resolver('add', lambda a, b: a + b)
OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
OmegaConf.register_new_resolver('div', lambda a, b: a / b)
# ======================================================= #


def prompt(question):
    inp = input(f"{question} (y/n)").lower().strip()
    if inp and inp == 'y':
        return True
    if inp and inp == 'n':
        return False
    return prompt(question)


def load_config(*yaml_files, cli_args=[]):
    yaml_confs = [OmegaConf.load(f) for f in yaml_files]
    cli_conf = OmegaConf.from_cli(cli_args)
    conf = OmegaConf.merge(*yaml_confs, cli_conf)
    OmegaConf.resolve(conf)
    return conf


def config_to_primitive(config, resolve=True):
    return OmegaConf.to_container(config, resolve=resolve)


def dump_config(path, config):
    with open(path, 'w') as fp:
        OmegaConf.save(config=config, f=fp)
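Usage note: a hypothetical example of how these helpers and the registered resolvers compose; the YAML file and its contents are illustrative assumptions, not files shipped in this repo.

from my_utils.misc import load_config, dump_config

# Suppose config.yaml contains:
#   base_lr: 0.001
#   batch_size: 32
#   scaled_lr: ${mul:${base_lr},${batch_size}}   # resolved via the registered 'mul' resolver
conf = load_config('config.yaml', cli_args=['batch_size=64'])
print(conf.scaled_lr)  # 0.064 -- resolver applied after the CLI override is merged
dump_config('resolved.yaml', conf)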
my_utils/plot.py
ADDED
@@ -0,0 +1,122 @@
# import os, sys
# sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
# import matplotlib
# matplotlib.use('Agg')
# import numpy as np
# import networkx as nx
# from io import BytesIO
# from PIL import Image, ImageDraw
# from matplotlib import pyplot as plt
# from sklearn.decomposition import PCA
# from my_utils.refs import graph_color_ref

# def add_text(text, imgarr):
#     '''
#     Function to add text to image

#     Args:
#     - text (str): text to add
#     - imgarr (np.array): image array

#     Returns:
#     - img (np.array): image array with text
#     '''
#     img = Image.fromarray(imgarr)
#     I = ImageDraw.Draw(img)
#     I.text((10, 10), text, fill='black')
#     return np.asarray(img)

# def get_color(ref, n_nodes):
#     '''
#     Function to color the nodes

#     Args:
#     - ref (list): list of color reference
#     - n_nodes (int): number of nodes

#     Returns:
#     - colors (list): list of colors
#     '''
#     N = len(ref)
#     colors = []
#     for i in range(n_nodes):
#         colors.append(np.array([[int(i) for i in ref[i%N][4:-1].split(',')]]) / 255.)
#     return colors


# def make_grid(images, cols=5):
#     """
#     Arrange list of images into a N x cols grid.

#     Args:
#     - images (list): List of Numpy arrays representing the images.
#     - cols (int): Number of columns for the grid.

#     Returns:
#     - grid (numpy array): Numpy array representing the image grid.
#     """
#     # Determine the dimensions of each image
#     img_h, img_w, _ = images[0].shape
#     rows = len(images) // cols

#     # Initialize a blank canvas
#     grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)

#     # Place each image onto the grid
#     for idx, img in enumerate(images):
#         y = (idx // cols) * img_h
#         x = (idx % cols) * img_w
#         grid[y: y + img_h, x: x + img_w] = img

#     return grid

# def viz_graph(info_dict, res=256):
#     '''
#     Function to plot the directed graph

#     Args:
#     - info_dict (dict): output json containing the graph information
#     - res (int): resolution of the image

#     Returns:
#     - img_arr (np.array): image array
#     '''
#     # build tree
#     tree = info_dict['diffuse_tree']
#     edges = []
#     for node in tree:
#         edges += [(node['id'], child) for child in node['children']]
#     G = nx.DiGraph()
#     G.add_edges_from(edges)

#     # plot tree
#     plt.figure(figsize=(res/100, res/100))

#     colors = get_color(graph_color_ref, len(tree))
#     pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
#     node_order = sorted(G.nodes())
#     nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)

#     buf = BytesIO()
#     plt.savefig(buf, format="png", dpi=100)
#     buf.seek(0)
#     img = Image.open(buf)
#     img_arr = np.asarray(img)
#     buf.close()
#     plt.clf()
#     plt.close()
#     return img_arr[:, :, :3]

# def viz_patch_feat_pca(feat):
#     pca = PCA(n_components=3)
#     pca.fit(feat)
#     feat_pca = pca.transform(feat)

#     t = np.array(feat_pca)
#     t_min = t.min(axis=0, keepdims=True)
#     t_max = t.max(axis=0, keepdims=True)
#     normalized_t = (t - t_min) / (t_max - t_min)

#     array = (normalized_t * 255).astype(np.uint8)
#     img_array = array.reshape(16, 16, 3)
#     return img_array
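Usage note: although this module is committed fully commented out, the grid-assembly logic of make_grid is self-contained numpy; a runnable restatement for reference (the tile contents and sizes below are illustrative assumptions).

import numpy as np

def make_grid(images, cols=5):
    # Tile equally-sized HxWx3 images row-major into a (rows*H, cols*W, 3) canvas
    img_h, img_w, _ = images[0].shape
    rows = len(images) // cols
    grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)
    for idx, img in enumerate(images):
        y = (idx // cols) * img_h
        x = (idx % cols) * img_w
        grid[y: y + img_h, x: x + img_w] = img
    return grid

tiles = [np.full((16, 16, 3), v, dtype=np.uint8) for v in range(10)]
print(make_grid(tiles, cols=5).shape)  # (32, 80, 3): 2 rows of 5 tiles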
my_utils/refs.py
ADDED
@@ -0,0 +1,122 @@
# reference of object categories
cat_ref = {
    "Table": 0,
    "Dishwasher": 1,
    "StorageFurniture": 2,
    "Refrigerator": 3,
    "WashingMachine": 4,
    "Microwave": 5,
    "Oven": 6,
}

data_mode_ref = {
    "aabb_max": 0,
    "aabb_min": 1,
    "joint_type": 2,
    "axis_dir": 3,
    "axis_ori": 4,
    "joint_range": 5,
    "label": 6
}

# reference of semantic labels for each part
sem_ref = {
    "fwd": {
        "door": 0,
        "drawer": 1,
        "base": 2,
        "handle": 3,
        "wheel": 4,
        "knob": 5,
        "shelf": 6,
        "tray": 7,
    },
    "bwd": {
        0: "door",
        1: "drawer",
        2: "base",
        3: "handle",
        4: "wheel",
        5: "knob",
        6: "shelf",
        7: "tray",
    },
}

# reference of joint types for each part
joint_ref = {
    "fwd": {"fixed": 1, "revolute": 2, "prismatic": 3, "screw": 4, "continuous": 5},
    "bwd": {1: "fixed", 2: "revolute", 3: "prismatic", 4: "screw", 5: "continuous"},
}


import plotly.express as px

# palette for joint type color
joint_color_ref = px.colors.qualitative.Set1
# palette for graph node color
# graph_color_ref = px.colors.qualitative.Bold + px.colors.qualitative.Prism
# graph_color_ref = [
#     "rgb(200, 200, 200)",  # creamy orange-yellow
#     "rgb(255, 196, 200)",  # strawberry-milk pink
#     "rgb(154, 228, 186)",  # avocado green
#     "rgb(252, 208, 140)",  # creamy orange-yellow
#     "rgb(217, 189, 250)",  # pale purple
#     "rgb(203, 237, 164)",  # matcha green
#     "rgb(188, 229, 235)",  # cyan-blue gray
#     "rgb(179, 199, 243)",  # misty blue
#     "rgb(255, 224, 130)",  # pale lemon yellow
#     "rgb(222, 179, 212)",  # pinkish purple
#     "rgb(148, 212, 224)",  # ice blue
# ]
graph_color_ref = [
    "rgb(160, 160, 160)",  # creamy orange-gray → darker gray-white, higher contrast
    "rgb(255, 130, 145)",  # strawberry-milk pink → brighter and redder
    "rgb(80, 200, 150)",   # avocado green → deeper and greener
    "rgb(255, 180, 60)",   # creamy orange-yellow → more orange and brighter
    "rgb(180, 140, 255)",  # pale purple → more saturated purple
    "rgb(130, 210, 50)",   # matcha green → brighter, yellower green
    "rgb(90, 190, 220)",   # cyan-blue gray → more blue for higher contrast
    "rgb(100, 150, 255)",  # misty blue → saturated cool blue
    "rgb(255, 200, 0)",    # pale lemon yellow → pure lemon yellow
    "rgb(200, 100, 190)",  # pinkish purple → more purple
    "rgb(80, 180, 255)",   # ice blue → colder, brighter blue
    "rgb(255, 130, 145)",  # strawberry-milk pink → brighter and redder
    "rgb(80, 200, 150)",   # avocado green → deeper and greener
    "rgb(255, 180, 60)",   # creamy orange-yellow → more orange and brighter
    "rgb(180, 140, 255)",  # pale purple → more saturated purple
    "rgb(130, 210, 50)",   # matcha green → brighter, yellower green
    "rgb(90, 190, 220)",   # cyan-blue gray → more blue for higher contrast
    "rgb(100, 150, 255)",  # misty blue → saturated cool blue
    "rgb(255, 200, 0)",    # pale lemon yellow → pure lemon yellow
    "rgb(200, 100, 190)",  # pinkish purple → more purple
    "rgb(80, 180, 255)",   # ice blue → colder, brighter blue
    "rgb(255, 130, 145)",  # strawberry-milk pink → brighter and redder
    "rgb(80, 200, 150)",   # avocado green → deeper and greener
    "rgb(255, 180, 60)",   # creamy orange-yellow → more orange and brighter
    "rgb(180, 140, 255)",  # pale purple → more saturated purple
    "rgb(130, 210, 50)",   # matcha green → brighter, yellower green
    "rgb(90, 190, 220)",   # cyan-blue gray → more blue for higher contrast
    "rgb(100, 150, 255)",  # misty blue → saturated cool blue
    "rgb(255, 200, 0)",    # pale lemon yellow → pure lemon yellow
    "rgb(200, 100, 190)",  # pinkish purple → more purple
    "rgb(80, 180, 255)",   # ice blue → colder, brighter blue
    "rgb(255, 130, 145)",  # strawberry-milk pink → brighter and redder
    "rgb(80, 200, 150)",   # avocado green → deeper and greener
    "rgb(255, 180, 60)",   # creamy orange-yellow → more orange and brighter
    "rgb(180, 140, 255)",  # pale purple → more saturated purple
    "rgb(130, 210, 50)",   # matcha green → brighter, yellower green
    "rgb(90, 190, 220)",   # cyan-blue gray → more blue for higher contrast
    "rgb(100, 150, 255)",  # misty blue → saturated cool blue
    "rgb(255, 200, 0)",    # pale lemon yellow → pure lemon yellow
    "rgb(200, 100, 190)",  # pinkish purple → more purple
    "rgb(80, 180, 255)",   # ice blue → colder, brighter blue
]
# palette for semantic label color
semantic_color_ref = px.colors.qualitative.Vivid_r
# attention map visualization color
attn_color_ref = px.colors.sequential.Viridis

from matplotlib.colors import LinearSegmentedColormap

cmap_attn = LinearSegmentedColormap.from_list("mycmap", attn_color_ref, N=256)
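Usage note: a purely illustrative round trip through the label tables above, showing how the paired "fwd"/"bwd" dictionaries invert each other.

from my_utils.refs import sem_ref, joint_ref

label_id = sem_ref["fwd"]["drawer"]        # 1
assert sem_ref["bwd"][label_id] == "drawer"

joint_id = joint_ref["fwd"]["prismatic"]   # 3
print(joint_ref["bwd"][joint_id])          # "prismatic"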
my_utils/render.py
ADDED
@@ -0,0 +1,482 @@
# import os, sys
# sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
# import trimesh
# import pyrender
# import numpy as np
# # import open3d as o3d
# from copy import deepcopy
# os.environ['PYOPENGL_PLATFORM'] = 'egl'
# from my_utils.refs import semantic_color_ref, graph_color_ref, joint_color_ref

# def get_rotation_axis_angle(k, theta):
#     '''
#     Rotation matrix converter from axis-angle using Rodrigues' rotation formula

#     Args:
#         k (np.ndarray): 3D unit vector representing the axis to rotate about.
#         theta (float): Angle to rotate with in radians.

#     Returns:
#         R (np.ndarray): 3x3 rotation matrix.
#     '''
#     if np.linalg.norm(k) == 0.:
#         return np.eye(3)
#     k = k / np.linalg.norm(k)
#     kx, ky, kz = k[0], k[1], k[2]
#     cos, sin = np.cos(theta), np.sin(theta)
#     R = np.zeros((3, 3), dtype=np.float32)
#     R[0, 0] = cos + (kx**2) * (1 - cos)
#     R[0, 1] = kx * ky * (1 - cos) - kz * sin
#     R[0, 2] = kx * kz * (1 - cos) + ky * sin
#     R[1, 0] = kx * ky * (1 - cos) + kz * sin
#     R[1, 1] = cos + (ky**2) * (1 - cos)
#     R[1, 2] = ky * kz * (1 - cos) - kx * sin
#     R[2, 0] = kx * kz * (1 - cos) - ky * sin
#     R[2, 1] = ky * kz * (1 - cos) + kx * sin
#     R[2, 2] = cos + (kz**2) * (1 - cos)
#     return R

# def rescale_axis(jtype, axis_d, axis_o, box_center):
#     '''
#     Function to rescale the axis for rendering

#     Args:
#     - jtype (int): joint type
#     - axis_d (np.array): axis direction
#     - axis_o (np.array): axis origin
#     - box_center (np.array): bounding box center

#     Returns:
#     - center (np.array): rescaled axis origin
#     - axis_d (np.array): rescaled axis direction
#     '''
#     if jtype == 0 or jtype == 1:
#         return [0., 0., 0.], [0., 0., 0.]
#     if jtype == 3 or jtype == 4:
#         center = box_center
#     else:
#         center = axis_o + np.dot(axis_d, box_center-axis_o) * axis_d
#     return center.tolist(), axis_d.tolist()

# # def get_axis_mesh(k, axis_o, bbox_center, joint_type):
# #     '''
# #     Function to get the axis mesh

# #     Args:
# #     - k (np.array): axis direction
# #     - center (np.array): axis origin
# #     - bbox_center (np.array): bounding box center
# #     - joint_type (int): joint type
# #     '''
# #     if joint_type == 0 or joint_type == 1 or np.linalg.norm(k) == 0. :
# #         return None

# #     k = k / np.linalg.norm(k)

# #     if joint_type == 3 or joint_type == 4:  # prismatic or screw
# #         axis_o = bbox_center
# #     else:  # revolute or continuous
# #         axis_o = axis_o + np.dot(k, bbox_center-axis_o) * k
# #     axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.015, cone_radius=0.03, cylinder_height=1.0, cone_height=0.08)
# #     arrow = np.array([0., 0., 1.], dtype=np.float32)
# #     n = np.cross(arrow, k)
# #     rad = np.arccos(np.dot(arrow, k))
# #     R_arrow = get_rotation_axis_angle(n, rad)
# #     axis.rotate(R_arrow, center=(0, 0, 0))
# #     axis.translate(axis_o[:3])
# #     axis.compute_vertex_normals()
# #     vertices = np.asarray(axis.vertices)
# #     faces = np.asarray(axis.triangles)
# #     trimesh_axis = trimesh.Trimesh(vertices=vertices, faces=faces)
# #     # trimesh_axis.visual.vertex_colors = np.array([0, 0, 0, 1.0], dtype=np.float32)
# #     trimesh_axis.visual.vertex_colors = np.repeat(np.array([0, 0, 0, 1.0]), vertices.shape[0], axis=0)
# #     return trimesh_axis

# def get_camera_pose(eye, look_at, up):
#     """
#     Compute the 4x4 transformation matrix for a camera pose.

#     Parameters:
#         eye (np.ndarray): 3D position of the camera.
#         look_at (np.ndarray): 3D point the camera is looking at.
#         up (np.ndarray): Up vector.

#     Returns:
#         pose (np.ndarray): 4x4 transformation matrix representing the camera pose.
#     """
#     # Compute the forward, right, and new up vectors
#     forward = (look_at - eye)
#     forward = forward / np.linalg.norm(forward)

#     right = np.cross(forward, up)
#     right = right / np.linalg.norm(right)

#     new_up = np.cross(right, forward)
#     new_up = new_up / np.linalg.norm(new_up)

#     # Create rotation matrix
#     pose = np.eye(4)
#     pose[0:3, 0] = right
#     pose[0:3, 1] = new_up
#     pose[0:3, 2] = -forward  # Negative because the camera looks along the negative Z axis in its local coordinate
#     pose[0:3, 3] = eye

#     return pose

# def get_rotation_axis_angle_box(axis, angle):
#     axis = axis / np.linalg.norm(axis)
#     return trimesh.transformations.rotation_matrix(angle, axis)

# def get_colored_box(center, size, jtype=None, jrange=None, axis_d=None, axis_o=None):
#     '''
#     Create a solid color box and its animated state if joint info is provided

#     Args:
#         center (np.array): box center (3,)
#         size (np.array): box size (3,)
#         color (list or array): RGBA color, e.g. [255, 0, 0, 255]
#         jtype (int): joint type (2=rot, 3=slide, 4=screw, 5=continuous)
#         jrange (list): joint motion range
#         axis_d (np.array): axis direction (3,)
#         axis_o (np.array): axis origin (3,)

#     Returns:
#         box: trimesh.Trimesh at rest
#         box_anim: trimesh.Trimesh after transformation
#     '''
#     size = np.clip(size, a_min=0.005, a_max=3.0)
#     center = np.clip(center, a_min=-3.0, a_max=3.0)

#     # Rest state box
#     box = trimesh.creation.box(extents=size)
#     box.apply_translation(center)

#     # Animated state (deepcopy + transform)
#     box_anim = deepcopy(box)

#     if jtype is not None:
#         if jtype == 2:  # revolute
#             theta = np.deg2rad(jrange[1])
#             T = trimesh.transformations.translation_matrix(axis_o)
#             R_3 = get_rotation_axis_angle(axis_d, theta)
#             R = np.eye(4, dtype=np.float32)
#             R[:3, :3] = R_3
#             T_inv = trimesh.transformations.translation_matrix(-axis_o)
#             box_anim.apply_transform(T @ R @ T_inv)

#         elif jtype == 3:  # prismatic
#             dist = float(jrange[1])
#             T = trimesh.transformations.translation_matrix(axis_d * dist)
#             box_anim.apply_transform(T)

#         elif jtype == 4:  # screw
#             theta = np.pi / 4
#             dist = float(jrange[1])
#             T1 = trimesh.transformations.translation_matrix(-axis_o)
#             R = get_rotation_axis_angle(axis_d, theta)
#             T2 = trimesh.transformations.translation_matrix(axis_o + axis_d * dist)
#             box_anim.apply_transform(T1 @ R @ T2)

#         elif jtype == 5:  # continuous
#             theta = np.pi / 4
#             T = trimesh.transformations.translation_matrix(-axis_o)
#             R_3 = get_rotation_axis_angle(axis_d, theta)
#             R = np.eye(4, dtype=np.float32)
#             R[:3, :3] = R_3
#             T_inv = trimesh.transformations.translation_matrix(axis_o)
#             box_anim.apply_transform(T @ R @ T_inv)

#     return box, box_anim

# # def get_bbox_mesh_pair(center, size, radius=0.01, jtype=None, jrange=None, axis_d=None, axis_o=None):
# #     '''
# #     Function to get the bounding box mesh pair

# #     Args:
# #     - center (np.array): bounding box center
# #     - size (np.array): bounding box size
# #     - radius (float): radius of the cylinder
# #     - jtype (int): joint type
# #     - jrange (list): joint range
# #     - axis_d (np.array): axis direction
# #     - axis_o (np.array): axis origin

# #     Returns:
# #     - trimesh_box (trimesh object): trimesh object for the bbox at resting state
# #     - trimesh_box_anim (trimesh object): trimesh object for the bbox at opening state
# #     '''

# #     size = np.clip(size, a_max=3, a_min=0.005)
# #     center = np.clip(center, a_max=3, a_min=-3)

# #     line_box = o3d.geometry.TriangleMesh()
# #     z_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[2])
# #     y_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[1])
# #     R_y = get_rotation_axis_angle(np.array([1., 0., 0.], dtype=np.float32), np.pi / 2)
# #     y_cylinder.rotate(R_y, center=(0, 0, 0))
# #     x_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[0])
# #     R_x = get_rotation_axis_angle(np.array([0., 1., 0.], dtype=np.float32), np.pi / 2)
# #     x_cylinder.rotate(R_x, center=(0, 0, 0))

# #     z1 = deepcopy(z_cylinder)
# #     z1.translate(np.array([-size[0] / 2, size[1] / 2, 0.], dtype=np.float32))
# #     line_box += z1.translate(center[:3])
# #     z2 = deepcopy(z_cylinder)
# #     z2.translate(np.array([size[0] / 2, size[1] / 2, 0.], dtype=np.float32))
# #     line_box += z2.translate(center[:3])
# #     z3 = deepcopy(z_cylinder)
# #     z3.translate(np.array([-size[0] / 2, -size[1] / 2, 0.], dtype=np.float32))
# #     line_box += z3.translate(center[:3])
# #     z4 = deepcopy(z_cylinder)
# #     z4.translate(np.array([size[0] / 2, -size[1] / 2, 0.], dtype=np.float32))
# #     line_box += z4.translate(center[:3])

# #     y1 = deepcopy(y_cylinder)
# #     y1.translate(np.array([-size[0] / 2, 0., size[2] / 2], dtype=np.float32))
# #     line_box += y1.translate(center[:3])
# #     y2 = deepcopy(y_cylinder)
# #     y2.translate(np.array([size[0] / 2, 0., size[2] / 2], dtype=np.float32))
# #     line_box += y2.translate(center[:3])
# #     y3 = deepcopy(y_cylinder)
# #     y3.translate(np.array([-size[0] / 2, 0., -size[2] / 2], dtype=np.float32))
# #     line_box += y3.translate(center[:3])
# #     y4 = deepcopy(y_cylinder)
# #     y4.translate(np.array([size[0] / 2, 0., -size[2] / 2], dtype=np.float32))
# #     line_box += y4.translate(center[:3])

# #     x1 = deepcopy(x_cylinder)
# #     x1.translate(np.array([0., -size[1] / 2, size[2] / 2], dtype=np.float32))
# #     line_box += x1.translate(center[:3])
# #     x2 = deepcopy(x_cylinder)
# #     x2.translate(np.array([0., size[1] / 2, size[2] / 2], dtype=np.float32))
# #     line_box += x2.translate(center[:3])
# #     x3 = deepcopy(x_cylinder)
# #     x3.translate(np.array([0., -size[1] / 2, -size[2] / 2], dtype=np.float32))
# #     line_box += x3.translate(center[:3])
# #     x4 = deepcopy(x_cylinder)
# #     x4.translate(np.array([0., size[1] / 2, -size[2] / 2]))
# #     line_box += x4.translate(center[:3])

# #     # transform
# #     line_box_anim = deepcopy(line_box)
# #     if jtype == 2:  # revolute
# #         theta = np.deg2rad(jrange[1])
# #         line_box_anim.translate(-axis_o)
# #         R = get_rotation_axis_angle(axis_d, theta)
# #         line_box_anim.rotate(R, center=(0, 0, 0))
# #         line_box_anim.translate(axis_o)
# #     elif jtype == 3:  # prismatic
# #         dist = np.array(jrange[1], dtype=np.float32)
# #         line_box_anim.translate(axis_d * dist)
# #     elif jtype == 4:  # screw
# #         dist = np.array(jrange[1], dtype=np.float32)
# #         theta = 0.25 * np.pi
# #         R = get_rotation_axis_angle(axis_d, theta)
# #         line_box_anim.translate(-axis_o)
# #         line_box_anim.rotate(R, center=(0, 0, 0))
# #         line_box_anim.translate(axis_o)
# #         line_box_anim.translate(axis_d * dist)
# #     elif jtype == 5:  # continuous
# #         theta = 0.25 * np.pi
# #         R = get_rotation_axis_angle(axis_d, theta)
# #         line_box_anim.translate(-axis_o)
# #         line_box_anim.rotate(R, center=(0, 0, 0))
# #         line_box_anim.translate(axis_o)

# #     vertices = np.asarray(line_box.vertices)
# #     faces = np.asarray(line_box.triangles)
# #     trimesh_box = trimesh.Trimesh(vertices=vertices, faces=faces)
# #     trimesh_box.visual.vertex_colors = np.array([0.0, 1.0, 1.0, 1.0], dtype=np.float32)

# #     vertices_anim = np.asarray(line_box_anim.vertices)
# #     faces_anim = np.asarray(line_box_anim.triangles)
# #     trimesh_box_anim = trimesh.Trimesh(vertices=vertices_anim, faces=faces_anim)
# #     trimesh_box_anim.visual.vertex_colors = np.array([0.0, 1.0, 1.0, 1.0], dtype=np.float32)

# #     return trimesh_box, trimesh_box_anim


# def get_color_from_palette(palette, idx):
#     '''
#     Function to get the color from the palette

#     Args:
#     - palette (list): list of color reference
#     - idx (int): index of the color

#     Returns:
#     - color (np.array): color in the index of idx
#     '''
#     ref = palette[idx % len(palette)]
#     ref_list = [int(i) for i in ref[4:-1].split(',')]
#     if idx % len(palette) == 0:
#         ref_list.append(120)
#     else:
#         ref_list.append(255)
#     color = np.array([ref_list], dtype=np.float32) / 255.
#     return color


# def render_anim_parts(aabbs, axiss, resolution=256):
#     '''
#     Function to render the 3D bounding boxes and axes in the scene

#     Args:
#         aabbs: list of trimesh objects for the bounding box of each part
#         axiss: list of trimesh objects for the axis of each part
#         resolution: resolution of the rendered image

#     Returns:
#         color_img: rendered image
#     '''
#     n_parts = len(aabbs)
#     # build mesh for each 3D bounding box
#     scene = pyrender.Scene()
#     for i in range(n_parts):
#         scene.add(aabbs[i])
#         if axiss[i] is not None:
#             scene.add(axiss[i])

#     # Add light to the scene
#     scene.ambient_light = np.full(shape=3, fill_value=1.5, dtype=np.float32)
#     light = pyrender.DirectionalLight(color=np.ones(2), intensity=5.0)

#     # Add camera to the scene
#     pose = get_camera_pose(eye=np.array([1.5, 1.2, 4.5]), look_at=np.array([0, 0, 0]), up=np.array([0, 1, 0]))
#     camera = pyrender.PerspectiveCamera(yfov=np.pi / 5.0, aspectRatio=1.0)
#     scene.add(light, pose=pose)
#     scene.add(camera, pose=pose)

#     # Offscreen Rendering
#     offscreen_renderer = pyrender.OffscreenRenderer(resolution, resolution)

#     # Render the scene
#     color_img, _ = offscreen_renderer.render(scene)

#     # Cleanup
#     offscreen_renderer.delete()
#     scene.clear()
#     return color_img


# def draw_boxes_axiss_anim(aabbs_0, aabbs_1, axiss, mode='graph', resolution=256, types=None):
#     '''
#     Function to draw the 3D bounding boxes and axes of the two frames

#     Args:
#         aabbs_0: list of trimesh objects for the bounding box of each part in the resting state
#         aabbs_1: list of trimesh objects for the bounding box of each part in the open state
#         axiss: list of trimesh objects for the axis of each part
#         mode:
#             'graph' using palette corresponding to graph node,
#             'jtype' using palette corresponding to joint type,
#             'semantic' using palette corresponding to semantic label
#         resolution: resolution of the rendered image
#         types: ids corresponding to each joint type or semantic label, if mode is 'jtype' or 'semantic'
#     '''
#     n_parts = len(aabbs_0)
#     ren_aabbs_0 = []
#     ren_aabbs_1 = []
#     ren_axiss = []
#     if mode == 'graph':
#         palette = graph_color_ref
#         # Add meshes to the scene
#         for i in range(n_parts):
#             color = get_color_from_palette(palette, i)
#             aabb_0 = pyrender.Mesh.from_trimesh(aabbs_0[i], smooth=False)
#             aabb_0.primitives[0].color_0 = color.repeat(aabb_0.primitives[0].positions.shape[0], axis=0)
#             ren_aabbs_0.append(aabb_0)
#             aabb_1 = pyrender.Mesh.from_trimesh(aabbs_1[i], smooth=False)
#             aabb_1.primitives[0].color_0 = color.repeat(aabb_1.primitives[0].positions.shape[0], axis=0)
#             ren_aabbs_1.append(aabb_1)
#             if axiss[i] is not None:
#                 axis = pyrender.Mesh.from_trimesh(axiss[i], smooth=False)
#                 axis.primitives[0].color_0 = color.repeat(axis.primitives[0].positions.shape[0], axis=0)
#                 ren_axiss.append(axis)
#             else:
#                 ren_axiss.append(None)
#     elif mode == 'jtype' or mode == 'semantic':
#         assert types is not None
#         palette = joint_color_ref if mode == 'jtype' else semantic_color_ref
#         # Add meshes to the scene
#         for i in range(n_parts):
#             color = get_color_from_palette(palette, types[i])
#             aabb_0 = pyrender.Mesh.from_trimesh(aabbs_0[i], smooth=False)
#             aabb_0.primitives[0].color_0 = color.repeat(aabb_0.primitives[0].positions.shape[0], axis=0)
#             ren_aabbs_0.append(aabb_0)
#             aabb_1 = pyrender.Mesh.from_trimesh(aabbs_1[i], smooth=False)
#             aabb_1.primitives[0].color_0 = color.repeat(aabb_1.primitives[0].positions.shape[0], axis=0)
#             ren_aabbs_1.append(aabb_1)

#             if axiss[i] is not None:
#                 axis = pyrender.Mesh.from_trimesh(axiss[i], smooth=False)
#                 ren_axiss.append(axis)
#             else:
#                 ren_axiss.append(None)
#     else:
#         raise ValueError('mode must be either graph or type')

#     img0 = render_anim_parts(ren_aabbs_0, ren_axiss, resolution=resolution)
#     img1 = render_anim_parts(ren_aabbs_1, ren_axiss, resolution=resolution)
#     return np.concatenate([img0, img1], axis=1)

# def prepare_meshes(info_dict):
#     """
#     Function to prepare the bbox and axis meshes for visualization

#     Args:
#     - info_dict (dict): output json containing the graph information
#     """
#     from my_utils.refs import joint_ref, sem_ref
#     tree = info_dict["diffuse_tree"]
#     bbox_0, bbox_1, axiss, labels, jtypes = [], [], [], [], []
#     root_id = 0
#     # get root id
#     for node in tree:
#         if node["parent"] == -1:
#             root_id = node["id"]
#     for node in tree:
#         # retrieve info
#         box_cen = np.array(node["aabb"]["center"], dtype=np.float32)
#         box_size = np.array(node["aabb"]["size"], dtype=np.float32)
#         axis_d = np.array(node["joint"]["axis"]["direction"], dtype=np.float32)
#         axis_o = np.array(node["joint"]["axis"]["origin"], dtype=np.float32)
#         jtype = joint_ref["fwd"][node["joint"]["type"]]
#         # construct meshes for bbox in two states (closed and fully open)
#         if node["id"] == root_id or node["parent"] == root_id:  # use the joint info directly
#             bb_0, bb_1 = get_colored_box(
#                 box_cen,
#                 box_size,
#                 jtype=jtype,
#                 jrange=node["joint"]["range"],
#                 axis_d=axis_d,
#                 axis_o=axis_o,
#             )
#         else:  # use the parent joint info
#             parent_id = node["parent"]
#             bb_0, bb_1 = get_colored_box(
#                 box_cen,
#                 box_size,
#                 jtype=joint_ref["fwd"][tree[parent_id]["joint"]["type"]],
#                 jrange=tree[parent_id]["joint"]["range"],
#                 axis_d=np.array(tree[parent_id]["joint"]["axis"]["direction"], dtype=np.float32),
#                 axis_o=np.array(tree[parent_id]["joint"]["axis"]["origin"], dtype=np.float32),
#             )
#         # construct mesh for joint axis
#         axis_mesh = get_axis_mesh(axis_d, axis_o, box_cen, node["joint"]["type"])
#         # append
#         bbox_0.append(bb_0)
#         bbox_1.append(bb_1)
#         axiss.append(axis_mesh)
#         labels.append(sem_ref["fwd"][node["name"]])
#         jtypes.append(jtype)

#     return {
#         "bbox_0": bbox_0,
#         "bbox_1": bbox_1,
#         "axiss": axiss,
#         "labels": labels,
#         "jtypes": jtypes,
#     }
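Usage note: the module above is committed fully commented out, but the Rodrigues formula its get_rotation_axis_angle helper expands entrywise is easy to sanity-check; below is a standalone restatement in compact matrix form (an illustrative sketch, not an import from the disabled module).

import numpy as np

def rodrigues(k, theta):
    # R = I*cos(theta) + (1 - cos(theta)) * k k^T + sin(theta) * [k]_x
    k = k / np.linalg.norm(k)
    K = np.array([[0, -k[2], k[1]],
                  [k[2], 0, -k[0]],
                  [-k[1], k[0], 0]])
    return np.eye(3) * np.cos(theta) + (1 - np.cos(theta)) * np.outer(k, k) + np.sin(theta) * K

R = rodrigues(np.array([0., 0., 1.]), np.pi / 2)  # 90 degrees about the z axis
print(np.round(R @ np.array([1., 0., 0.]), 6))    # -> [0. 1. 0.], as expected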
my_utils/savermixins.py
ADDED
@@ -0,0 +1,55 @@
import os
import json
import torch
import imageio
import numpy as np

class SaverMixin():

    @property
    def save_dir(self):
        return self.hparams.save_dir

    def convert_format(self, data):
        if isinstance(data, np.ndarray):
            return data
        elif isinstance(data, torch.Tensor):
            return data.cpu().numpy()
        elif isinstance(data, list):
            return [self.convert_format(d) for d in data]
        elif isinstance(data, dict):
            return {k: self.convert_format(v) for k, v in data.items()}
        else:
            raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data))

    def get_save_path(self, filename):
        save_path = os.path.join(self.save_dir, filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        return save_path

    def save_rgb_image(self, filename, img):
        imageio.imwrite(self.get_save_path(filename), img)

    def save_rgb_video(self, filename, stage='fit', filter=None):
        img_dir = os.path.join(self.logger.log_dir, 'images', stage)

        writer_graph = imageio.get_writer(os.path.join(img_dir, filename), fps=1)

        for file in sorted(os.listdir(img_dir)):
            if file.endswith('.png') and 'gt' not in file:
                if filter is not None:
                    if filter in file:
                        writer_graph.append_data(imageio.imread(os.path.join(img_dir, file)))
                else:
                    writer_graph.append_data(imageio.imread(os.path.join(img_dir, file)))

        writer_graph.close()

    def save_json(self, filename, data):
        save_path = self.get_save_path(filename)
        with open(save_path, 'w') as f:
            json.dump(data, f)
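Usage note: a minimal sketch of how this mixin can be consumed; the DummySaver host class and SimpleNamespace hparams below are illustrative assumptions (in this repo the mixin would more plausibly be mixed into a LightningModule whose hparams carry save_dir).

from types import SimpleNamespace
import numpy as np
from my_utils.savermixins import SaverMixin

class DummySaver(SaverMixin):  # hypothetical host class
    def __init__(self, save_dir):
        self.hparams = SimpleNamespace(save_dir=save_dir)  # mixin reads hparams.save_dir

saver = DummySaver('outputs/demo')
saver.save_json('result/info.json', {'ok': True})  # parent dirs created on demand
saver.save_rgb_image('result/img.png', np.zeros((8, 8, 3), dtype=np.uint8))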
objects/__init__.py
ADDED
File without changes
objects/dict_utils.py
ADDED
@@ -0,0 +1,299 @@
import numpy as np
from scipy.optimize import linear_sum_assignment


def get_base_part_idx(obj_dict):
    """
    Get the index of the base part in the object dictionary\n

    - obj_dict: the object dictionary\n

    Return:\n
    - base_part_idx: the index of the base part
    """

    # Adjust for NAP's corner case
    base_part_ids = np.where(
        [part["parent"] == -1 for part in obj_dict["diffuse_tree"]]
    )[0]
    if len(base_part_ids) > 0:
        return base_part_ids[0].item()
    else:
        raise ValueError("No base part found")


def get_bbox_vertices(obj_dict, part_idx):
    """
    Get the 8 vertices of the bounding box\n
    The order of the vertices is the same as the order that pytorch3d.ops.box3d_overlap expects\n
    (This order is not necessary since we are not using pytorch3d.ops.box3d_overlap anymore)\n

    - bbox_center: the center of the bounding box in the form: [cx, cy, cz]\n
    - bbox_size: the size of the bounding box in the form: [lx, ly, lz]\n

    Return:\n
    - bbox_vertices: the 8 vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]
    """

    part = obj_dict["diffuse_tree"][part_idx]
    bbox_center = np.array(part["aabb"]["center"], dtype=np.float32)
    bbox_size_half = np.array(part["aabb"]["size"], dtype=np.float32) / 2

    bbox_vertices = np.zeros((8, 3), dtype=np.float32)

    # Get the 8 vertices of the bounding box in the order that pytorch3d.ops.box3d_overlap expects:
    # 0: (x0, y0, z0)    # 1: (x1, y0, z0)    # 2: (x1, y1, z0)    # 3: (x0, y1, z0)
    # 4: (x0, y0, z1)    # 5: (x1, y0, z1)    # 6: (x1, y1, z1)    # 7: (x0, y1, z1)
    bbox_vertices[0, :] = bbox_center - bbox_size_half
    bbox_vertices[1, :] = bbox_center + np.array(
        [bbox_size_half[0], -bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
    )
    bbox_vertices[2, :] = bbox_center + np.array(
        [bbox_size_half[0], bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
    )
    bbox_vertices[3, :] = bbox_center + np.array(
        [-bbox_size_half[0], bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
    )
    bbox_vertices[4, :] = bbox_center + np.array(
        [-bbox_size_half[0], -bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
    )
    bbox_vertices[5, :] = bbox_center + np.array(
        [bbox_size_half[0], -bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
    )
    bbox_vertices[6, :] = bbox_center + bbox_size_half
    bbox_vertices[7, :] = bbox_center + np.array(
        [-bbox_size_half[0], bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
    )

    return bbox_vertices


def compute_overall_bbox_size(obj_dict):
    """
    Compute the overall bounding box size of the object\n

    - obj_dict: the object dictionary\n

    Return:\n
    - bbox_size: the overall bounding box size in the form: [lx, ly, lz]
    """

    bbox_min = np.zeros((len(obj_dict["diffuse_tree"]), 3), dtype=np.float32)
    bbox_max = np.zeros((len(obj_dict["diffuse_tree"]), 3), dtype=np.float32)

    # For each part, compute the bounding box and store the min and max vertices
    for part_idx, part in enumerate(obj_dict["diffuse_tree"]):
        bbox_center = np.array(part["aabb"]["center"], dtype=np.float32)
        bbox_size_half = np.array(part["aabb"]["size"], dtype=np.float32) / 2
        bbox_min[part_idx] = bbox_center - bbox_size_half
        bbox_max[part_idx] = bbox_center + bbox_size_half

    # Compute the overall bounding box size
    bbox_min = np.min(bbox_min, axis=0)
    bbox_max = np.max(bbox_max, axis=0)
    bbox_size = bbox_max - bbox_min
    return bbox_size


def remove_handles(obj_dict):
    """
    Remove the handles from the object dictionary and adjust the id, parent, and children of the parts\n

    - obj_dict: the object dictionary\n

    Return:\n
    - obj_dict: the object dictionary without the handles
    """

    # Find the indices of the handles
    handle_idxs = np.array(
        [
            i
            for i in range(len(obj_dict["diffuse_tree"]))
            if obj_dict["diffuse_tree"][i]["name"] == "handle"
            and obj_dict["diffuse_tree"][i]["parent"] != -1
        ]
    )  # Added to avoid corner case of NAP where the handle is the base part

    # Remove the handles from the object dictionary and adjust the id, parent, and children of the parts
    for handle_idx in handle_idxs:
        handle = obj_dict["diffuse_tree"][handle_idx]
        parent_idx = handle["parent"]
        if handle_idx in obj_dict["diffuse_tree"][parent_idx]["children"]:
            obj_dict["diffuse_tree"][parent_idx]["children"].remove(handle_idx)
        obj_dict["diffuse_tree"].pop(handle_idx)

        # Adjust the id, parent, and children of the parts
        for part in obj_dict["diffuse_tree"]:
            if part["id"] > handle_idx:
                part["id"] -= 1
            if part["parent"] > handle_idx:
                part["parent"] -= 1
            for i in range(len(part["children"])):
                if part["children"][i] > handle_idx:
                    part["children"][i] -= 1

        handle_idxs -= 1

    return obj_dict


# def normalize_object(obj_dict):
#     """
#     Normalize the object as a whole\n
#     Make the base part to be centered at the origin and have a size of 2\n

#     obj_dict: the object dictionary
#     """
#     # Find the base part and compute the translation and scaling factors
#     tree = obj_dict["diffuse_tree"]
#     for part in tree:
#         if part["parent"] == -1:
#             translate = -np.array(part["aabb"]["center"], dtype=np.float32)
#             scale = 2.0 / np.array(part["aabb"]["size"], dtype=np.float32)
#             break

#     for part in tree:
#         part["aabb"]["center"] = (
#             np.array(part["aabb"]["center"], dtype=np.float32) + translate
#         ) * scale
#         part["aabb"]["size"] = np.array(part["aabb"]["size"], dtype=np.float32) * scale
#         if part["joint"]["type"] != "fixed":
#             part["joint"]["axis"]["origin"] = (
#                 np.array(part["joint"]["axis"]["origin"], dtype=np.float32) + translate
#             ) * scale

def zero_center_object(obj_dict):
    """
    Zero center the object as a whole\n

    - obj_dict: the object dictionary
    """

    bbox_min = np.zeros((len(obj_dict["diffuse_tree"]), 3))
    bbox_max = np.zeros((len(obj_dict["diffuse_tree"]), 3))

    # For each part, compute the bounding box and store the min and max vertices
    for part_idx, part in enumerate(obj_dict["diffuse_tree"]):
        bbox_center = np.array(part["aabb"]["center"])
        bbox_size_half = np.array(part["aabb"]["size"]) / 2
        bbox_min[part_idx] = bbox_center - bbox_size_half
        bbox_max[part_idx] = bbox_center + bbox_size_half

    # Compute the overall bounding box size
    bbox_min = np.min(bbox_min, axis=0)
    bbox_max = np.max(bbox_max, axis=0)
    bbox_center = (bbox_min + bbox_max) / 2

    translate = -bbox_center

    for part in obj_dict["diffuse_tree"]:
        part["aabb"]["center"] = np.array(part["aabb"]["center"]) + translate
        if part["joint"]["type"] != "fixed":
            part["joint"]["axis"]["origin"] = np.array(part["joint"]["axis"]["origin"]) + translate


def rescale_object(obj_dict, scale_factor):
    """
    Rescale the object as a whole\n

    - obj_dict: the object dictionary\n
    - scale_factor: the scale factor to rescale the object
    """

    for part in obj_dict["diffuse_tree"]:
        part["aabb"]["center"] = (
            np.array(part["aabb"]["center"], dtype=np.float32) * scale_factor
        )
        part["aabb"]["size"] = (
            np.array(part["aabb"]["size"], dtype=np.float32) * scale_factor
        )
        if part["joint"]["type"] != "fixed":
            part["joint"]["axis"]["origin"] = (
                np.array(part["joint"]["axis"]["origin"], dtype=np.float32)
                * scale_factor
            )


def find_part_mapping(obj1_dict, obj2_dict, use_hungarian=False):
    """
    Find the correspondences from the first object to the second object based on closest bbox centers\n

    - obj1_dict: the first object dictionary\n
    - obj2_dict: the second object dictionary\n

    Return:\n
    - mapping: the mapping from the first object to the second object in the form: [[obj_part_idx, distance], ...]
    """
    if use_hungarian:
        return hungarian_matching(obj1_dict, obj2_dict)

    # Initialize the distances to be +inf
    mapping = np.ones((len(obj1_dict["diffuse_tree"]), 2)) * np.inf

    # For each part in the first object, find the closest part in the second object based on the bounding box center
    for req_part_idx, req_part in enumerate(obj1_dict["diffuse_tree"]):
        for obj_part_idx, obj_part in enumerate(obj2_dict["diffuse_tree"]):
            distance = np.linalg.norm(
                np.array(req_part["aabb"]["center"])
                - np.array(obj_part["aabb"]["center"])
            )
            if distance < mapping[req_part_idx, 1]:
                mapping[req_part_idx, :] = [obj_part_idx, distance]

    return mapping


def hungarian_matching(obj1_dict, obj2_dict):
    """
    Find the correspondences from the first object to the second object based on closest bbox centers using Hungarian algorithm\n

    - obj1_dict: the first object dictionary\n
    - obj2_dict: the second object dictionary\n

    Return:\n
    - mapping: the mapping from the first object to the second object in the form: [[obj_part_idx], ...]
    """
    INF = 9999999

    tree1 = obj1_dict["diffuse_tree"]
    tree2 = obj2_dict["diffuse_tree"]

    n_parts1 = len(tree1)
    n_parts2 = len(tree2)
    n_parts_max = max(n_parts1, n_parts2)

    # Initialize the cost matrix
    cost_matrix = np.ones((n_parts_max, n_parts_max), dtype=np.float32) * INF
    for i in range(n_parts1):
        for j in range(n_parts2):
            cost_matrix[i, j] = np.linalg.norm(
                np.array(tree1[i]["aabb"]["center"], dtype=np.float32)
                - np.array(tree2[j]["aabb"]["center"], dtype=np.float32)
            )

    # Find the correspondences using the Hungarian algorithm
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    # Valid correspondences are those with all cost less than INF
    valid_correspondences = np.where(cost_matrix[row_ind, col_ind] < INF)[0]
    invalid_correspondences = np.where(np.logical_not(cost_matrix[row_ind, col_ind] < INF))[0]

    row_i = row_ind[valid_correspondences]
    col_i = col_ind[valid_correspondences]

    # Construct the mapping
    mapping = np.zeros(
        (n_parts1, 2), dtype=np.float32
    )
    mapping[row_i, 0] = col_i
    mapping[row_i, 1] = cost_matrix[row_i, col_i]

    # assign the index of the most closely matched part
    if n_parts1 > n_parts2:
        row_j = row_ind[invalid_correspondences]
        col_j = cost_matrix[row_j, :].argmin(axis=1)
        mapping[row_j, 0] = col_j
        mapping[row_j, 1] = cost_matrix[row_j, col_j]

    return mapping
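Usage note: a toy illustration of the part matching above; the two object dictionaries are minimal hand-made stubs following the diffuse_tree/aabb schema, not real dataset entries.

import numpy as np
from objects.dict_utils import find_part_mapping

def stub(centers):
    # Build a minimal object dict with only the fields the matchers read
    return {"diffuse_tree": [{"aabb": {"center": c, "size": [1, 1, 1]}} for c in centers]}

obj_a = stub([[0, 0, 0], [1, 0, 0]])
obj_b = stub([[1.1, 0, 0], [0.05, 0, 0]])

print(find_part_mapping(obj_a, obj_b))                      # greedy nearest-center match
print(find_part_mapping(obj_a, obj_b, use_hungarian=True))  # one-to-one Hungarian match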
objects/motions.py
ADDED
@@ -0,0 +1,99 @@
import os
import sys
import numpy as np
import quaternion
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from objects.dict_utils import get_base_part_idx

def transform_all_parts(part_vertices, obj_dict, joint_state,
                        rotation_fix_range=True, dry_run=True):
    """
    Transform all parts of the object according to the joint state\n

    - part_vertices: vertices of the object in rest pose in the form:\n
        - [K_parts, N_vertices, 3]\n
    - obj_dict: the object dictionary\n
    - joint_state: the joint state in the range of [0, 1]\n
    - rotation_fix_range (optional): whether to fix the rotation range to 90 degrees for revolute joints\n
    - dry_run (optional): if True, only return the transformation matrices without changing the vertices\n

    Return:\n
    - part_transformations: records of the transformations applied to the parts\n
    """
    part_transformations = [[] for _ in range(len(obj_dict["diffuse_tree"]))]
    if joint_state == 0.0:
        return part_transformations

    # Get a visit order of the parts such that children parts are visited before parents
    part_visit_order = []
    base_idx = get_base_part_idx(obj_dict)
    indices_to_visit = [base_idx]
    while len(indices_to_visit) > 0:  # Breadth-first traversal
        current_idx = indices_to_visit.pop(0)
        part_visit_order.append(current_idx)
        indices_to_visit += obj_dict["diffuse_tree"][current_idx]["children"]
    part_visit_order.reverse()

    # Transform the parts in the visit order - children first, then parents
    for i in part_visit_order:
        part = obj_dict["diffuse_tree"][i]
        joint = part["joint"]
        children_idxs = part["children"]

        # Store the transformation used to transform the part and its children
        applied_transformation_matrix = np.eye(4, dtype=np.float32)
        applied_rotation_axis_origin = np.array([np.nan, np.nan, np.nan], dtype=np.float32)
        applied_transformation_type = "none"

        if joint["type"] == "prismatic":
            # Translate the part and its children
            translation = np.array(joint["axis"]["direction"], dtype=np.float32) * joint["range"][1] * joint_state

            if not dry_run:
                part_vertices[[i] + children_idxs] += translation

            # Store the transformation used
            applied_transformation_matrix[:3, 3] = translation
            applied_transformation_type = "translation"

        elif joint["type"] == "revolute" or joint["type"] == "continuous":
            if joint["type"] == "revolute":
                if not rotation_fix_range:
                    # Use the full range as specified in the object file
                    rotation_radian = np.radians(joint["range"][1] * joint_state)
                else:
                    # Fix the rotation range to 90 degrees
                    rotation_range_sign = np.sign(joint["range"][1])
                    rotation_radian = np.radians(rotation_range_sign * 90 * joint_state)

            else:
                rotation_radian = np.radians(360 * joint_state)

            # Prepare the rotation matrix via axis-angle representation and quaternion
            rotation_axis_origin = np.array(joint["axis"]["origin"], dtype=np.float32)
            rotation_axis_direction = np.array(joint["axis"]["direction"], dtype=np.float32) / np.linalg.norm(joint["axis"]["direction"])
            rotation_matrix = quaternion.as_rotation_matrix(quaternion.from_rotation_vector(rotation_radian * rotation_axis_direction))

            if not dry_run:
                # Rotate the part and its children
                vertices_to_rotate = (part_vertices[[i] + children_idxs] - rotation_axis_origin)
                part_vertices[[i] + children_idxs] = np.matmul(rotation_matrix, vertices_to_rotate.transpose([0, 2, 1])).transpose([0, 2, 1]) + rotation_axis_origin

            # Store the transformation used
            applied_transformation_matrix[:3, :3] = rotation_matrix
            applied_rotation_axis_origin = rotation_axis_origin
            applied_transformation_type = "rotation"

        # Record the transformation used
        if applied_transformation_type != "none":
            record = {
                "type": applied_transformation_type,
                "matrix": applied_transformation_matrix,
                "rotation_axis_origin": applied_rotation_axis_origin
            }
            for idx in [i] + children_idxs:
                part_transformations[idx].append(record)

    return part_transformations
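A minimal dry-run sketch (not part of the commit) of how transform_all_parts is called. The toy object dictionary is hypothetical; it assumes the repo root is on sys.path, that numpy-quaternion is installed, and that get_base_part_idx resolves the part whose parent is -1:

import numpy as np
from objects.motions import transform_all_parts

# Toy object: a static base (index 0) carrying a revolute door (index 1)
obj_dict = {
    "diffuse_tree": [
        {"id": 0, "parent": -1, "children": [1],
         "joint": {"type": "fixed", "axis": {"origin": [0, 0, 0], "direction": [0, 0, 1]}, "range": [0, 0]}},
        {"id": 1, "parent": 0, "children": [],
         "joint": {"type": "revolute", "axis": {"origin": [0.5, 0, 0], "direction": [0, 0, 1]}, "range": [0, 90]}},
    ]
}
part_vertices = np.zeros((2, 8, 3), dtype=np.float32)  # [K_parts, N_vertices, 3]

# Half-open state; dry_run=True (the default) leaves part_vertices untouched
# and only returns the 4x4 matrices that would have been applied per part.
# The base's "fixed" joint matches no branch, so only the door gets a record.
records = transform_all_parts(part_vertices, obj_dict, joint_state=0.5)
print(records[1][0]["type"])    # "rotation" (45 degrees with rotation_fix_range)
print(records[1][0]["matrix"])  # the corresponding homogeneous transform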
requirements.txt
ADDED
@@ -0,0 +1,21 @@
--extra-index-url https://download.pytorch.org/whl/cu121

torch==2.2.2
torchvision==0.17.2
pytorch-lightning==2.4.0
lightning==2.3.3
matplotlib
numpy==1.26.4
gradio==5.34.2
wandb
omegaconf
imageio
diffusers
plotly
pybullet
pyrender
trimesh
numpy-quaternion
openai
spaces
json_repair
retrieval/__init__.py
ADDED
File without changes
retrieval/obj_retrieval.py
ADDED
@@ -0,0 +1,509 @@
import os
import sys
import random
import json
import numpy as np
from copy import deepcopy

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.iou_cdist import IoU_cDist
import networkx as nx

all_categories = [
    "Table",
    "StorageFurniture",
    "WashingMachine",
    "Microwave",
    "Dishwasher",
    "Refrigerator",
    "Oven",
]

all_categories_acd = [
    'armoire',
    'bookcase',
    'chestofdrawers',
    'hangingcabinet',
    'kitchencabinet'
]


def get_hash(file, key="diffuse_tree", ignore_handles=True, dag=False):
    tree = file[key]
    if dag:
        G = nx.DiGraph()
    else:
        G = nx.Graph()
    for node in tree:
        if ignore_handles and "handle" in node["name"].lower():
            continue
        G.add_node(node["id"])
        if node["parent"] != -1:
            G.add_edge(node["id"], node["parent"])
    hashcode = nx.weisfeiler_lehman_graph_hash(G)
    return hashcode

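A quick sanity check (not part of the commit) of what get_hash computes: two kinematic trees with the same topology hash to the same Weisfeiler-Lehman code even when their ids differ, and handle nodes are skipped. The toy part dictionaries below are hypothetical:

import networkx as nx

def toy_hash(tree):
    # Mirrors get_hash above: one node per part, one edge per parent link,
    # skipping any part whose name contains "handle"
    G = nx.Graph()
    for node in tree:
        if "handle" in node["name"].lower():
            continue
        G.add_node(node["id"])
        if node["parent"] != -1:
            G.add_edge(node["id"], node["parent"])
    return nx.weisfeiler_lehman_graph_hash(G)

cabinet = [
    {"id": 0, "name": "base", "parent": -1},
    {"id": 1, "name": "door", "parent": 0},
    {"id": 2, "name": "handle", "parent": 1},  # ignored by the hash
]
fridge = [
    {"id": 7, "name": "body", "parent": -1},
    {"id": 9, "name": "door", "parent": 7},
]
# Same base-door topology once handles are dropped -> identical hash
assert toy_hash(cabinet) == toy_hash(fridge)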
def _verify_mesh_exists(dir, ply_files, verbose=False):
    """
    Verify that the mesh files exist\n

    - dir: the directory of the object\n
    - ply_files: the list of mesh files\n
    - verbose (optional): whether to print the progress\n

    return:\n
    - True if the mesh files exist, False otherwise
    """

    for ply_file in ply_files:
        if not os.path.exists(os.path.join(dir, ply_file)):
            if verbose:
                print(f" - {os.path.join(dir, ply_file)} does not exist!!!")
            return False
    return True


def _generate_output_part_dicts(
    candidate_dict,
    part_idx,
    candidate_dir,
    requirement_part_bbox_sizes,
    bbox_size_eps=1e-3,
    verbose=False,
):
    """
    Generate the output part dictionaries for all requirement parts that are fulfilled by the candidate part and compute the scale factors of the parts\n

    - candidate_dict: the candidate object dictionary\n
    - part_idx: the index of the part in the candidate object\n
    - candidate_dir: the directory of the candidate object\n
    - requirement_part_bbox_sizes: the bounding box sizes of the requirement parts in the form: [[lx1, ly1, lz1], [lx2, ly2, lz2], ...]\n
    - bbox_size_eps (optional): the epsilon to avoid zero volume parts\n
    - verbose (optional): whether to print the progress\n

    Return:\n
    - part_dicts: the output part dictionaries in the form:
        - [{name, dir, files, scale_factor=[sx, sy, sz], z_rotate_90}, ...]
        - z_rotate_90 is True if the part needs to be rotated by 90 degrees around the z-axis
        - [{}, ...] if any of the mesh files do not exist
    """

    part_dicts = [{} for _ in range(len(requirement_part_bbox_sizes))]
    fixed_portion = {
        "name": candidate_dict["diffuse_tree"][part_idx]["name"],
        "dir": candidate_dir,
        "files": candidate_dict["diffuse_tree"][part_idx]["plys"],
        "z_rotate_90": False,
    }

    # Verify that the mesh files exist
    if not _verify_mesh_exists(fixed_portion["dir"], fixed_portion["files"], verbose):
        if verbose:
            print(
                f" - ! Found invalid mesh files in {fixed_portion['dir']}, skipping..."
            )
        return part_dicts  # List of empty dicts

    candidate_bbox_size = np.array(
        candidate_dict["diffuse_tree"][part_idx]["aabb"]["size"]
    )
    candidate_bbox_size = np.maximum(
        candidate_bbox_size, bbox_size_eps
    )  # Avoid zero volume parts

    for i, requirement_part_bbox_size in enumerate(requirement_part_bbox_sizes):
        part_dicts[i] = deepcopy(fixed_portion)

        # For non-handle parts, compute the scale factor normally
        if fixed_portion["name"] != "handle":
            part_dicts[i]["scale_factor"] = list(
                np.array(requirement_part_bbox_size) / candidate_bbox_size
            )

        # For handles, consider the orientation of the selected handle and the orientation of the requirement handle
        else:
            requirement_handle_is_horizontal = (
                requirement_part_bbox_size[0] > requirement_part_bbox_size[1]
            )
            candidate_handle_is_horizontal = (
                candidate_bbox_size[0] > candidate_bbox_size[1]
            )

            # If the orientations differ, rotate the requirement handle by 90 degrees around the z-axis before computing the scale factor
            if requirement_handle_is_horizontal != candidate_handle_is_horizontal:
                rotated_requirement_part_bbox_size = [
                    requirement_part_bbox_size[1],
                    requirement_part_bbox_size[0],
                    requirement_part_bbox_size[2],
                ]
                part_dicts[i]["scale_factor"] = list(
                    np.array(rotated_requirement_part_bbox_size) / candidate_bbox_size
                )
                part_dicts[i]["z_rotate_90"] = True

            # If the orientations are the same, compute the scale factor normally
            else:
                part_dicts[i]["scale_factor"] = list(
                    np.array(requirement_part_bbox_size) / candidate_bbox_size
                )

    return part_dicts

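A small numeric illustration (not part of the commit; the box sizes are hypothetical) of the handle branch above: a horizontal requirement handle matched against a vertical candidate handle has its x/y extents swapped before the per-axis scale factors are taken, and z_rotate_90 is flagged so the mesh is rotated to compensate:

import numpy as np

requirement_size = np.array([0.30, 0.04, 0.03])  # horizontal handle: lx > ly
candidate_size = np.array([0.05, 0.25, 0.03])    # vertical handle:   lx < ly

req_horizontal = requirement_size[0] > requirement_size[1]   # True
cand_horizontal = candidate_size[0] > candidate_size[1]      # False

if req_horizontal != cand_horizontal:
    # Swap x/y extents of the requirement box and flag a 90-degree z-rotation,
    # exactly as in the handle branch of _generate_output_part_dicts
    z_rotate_90 = True
    scale_factor = requirement_size[[1, 0, 2]] / candidate_size
else:
    z_rotate_90 = False
    scale_factor = requirement_size / candidate_size

print(scale_factor, z_rotate_90)  # [0.8 1.2 1. ] True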
def find_obj_candidates(
    requirement_dict,
    dataset_dir,
    hashbook_path,
    num_states=5,
    metric_compare_handles=False,
    metric_iou_include_base=True,
    metric_num_samples=10000,
    keep_top=5,
    gt_file_name="object.json",
    verbose=False,
):
    """
    Find the best object candidates for selecting the base part using AID\n

    - requirement_dict: the object dictionary of the requirement\n
    - dataset_dir: the directory of the dataset to search in\n
    - hashbook_path: the path to the hashbook for filtering candidates\n
    - num_states: the number of states to average the metric over\n
    - metric_compare_handles (optional): whether to compare handles when computing the metric\n
    - metric_iou_include_base (optional): whether to include the base when computing the IoU\n
    - metric_num_samples (optional): the number of samples to use when computing the metric\n
    - keep_top (optional): the number of top candidates to keep\n
    - gt_file_name (optional): the name of the ground truth json file, which describes a candidate object\n
    - verbose (optional): whether to print the progress\n

    return:\n
    - a list of best object candidates of the form:
        - {"category", "dir", "score"}
    """
    dataset_dir = os.path.abspath(dataset_dir)

    # Load the hashbook
    with open(hashbook_path, "r") as f:
        hashbook = json.load(f)

    if 'acd' in hashbook_path:
        all_categories = all_categories_acd
    else:
        all_categories = [
            "Table",
            "StorageFurniture",
            "WashingMachine",
            "Microwave",
            "Dishwasher",
            "Refrigerator",
            "Oven",
        ]

    # Resolve paths to directories
    category_specified = False
    requirement_category = ""

    # If the category is specified, only search in that category; otherwise search in all categories
    if "obj_cat" in requirement_dict["meta"]:
        requirement_category = requirement_dict["meta"]["obj_cat"]
        category_specified = True
        if requirement_category == "StroageFurniture":  # Handle a known typo in the metadata
            requirement_category = "StorageFurniture"
    category_dirs = (
        [os.path.join(dataset_dir, requirement_category)]
        if category_specified
        else [os.path.join(dataset_dir, category) for category in all_categories]
    )

    # Extract requirement data
    requirement_part_names = []
    requirement_part_bboxes = []
    for part in requirement_dict["diffuse_tree"]:
        requirement_part_names.append(part["name"])
        requirement_part_bboxes.append(
            np.concatenate([part["aabb"]["center"], part["aabb"]["size"]])
        )

    # Compute hash of the requirement graph
    requirement_graph_hash = get_hash(requirement_dict)

    # Prefetch the list of ids of candidate objects with the same hash
    if category_specified and requirement_graph_hash in hashbook[requirement_category]:
        same_hash_obj_ids = hashbook[requirement_category][requirement_graph_hash]
    else:
        # Use all categories if the category is not specified
        same_hash_obj_ids = []
        for category in all_categories:
            if requirement_graph_hash in hashbook[category]:
                same_hash_obj_ids += hashbook[category][requirement_graph_hash]

    # Iterate through all candidate objects and keep the top k candidates
    best_obj_candidates = []
    for category_dir in category_dirs:
        obj_ids = os.listdir(category_dir)
        for i, obj_id in enumerate(obj_ids):
            if verbose:
                print(
                    f"\r - Finding candidates from {category_dir.split('/')[-1]}: {i+1}/{len(obj_ids)}",
                    end="",
                )

            # Load the candidate object
            obj_dir = os.path.join(category_dir, obj_id)
            if os.path.exists(os.path.join(obj_dir, gt_file_name)):
                with open(os.path.join(obj_dir, gt_file_name), "r") as f:
                    obj_dict = json.load(f)
                if "diffuse_tree" not in obj_dict:  # Rename for compatibility
                    obj_dict["diffuse_tree"] = obj_dict.pop("arti_tree")

                # Compute the metric for selecting the base if the hash matches or if there are no objects with the same hash
                if obj_id in same_hash_obj_ids or len(same_hash_obj_ids) == 0:
                    scores = IoU_cDist(
                        requirement_dict,
                        obj_dict,
                        num_states=num_states,
                        compare_handles=metric_compare_handles,
                        iou_include_base=metric_iou_include_base,
                        num_samples=metric_num_samples,
                    )
                    base_score = scores["AS-cDist"]

                    # Add the candidate to the list of best candidates and keep the top k candidates
                    best_obj_candidates.append(
                        {
                            "category": category_dir.split("/")[-1],
                            "dir": obj_dir,
                            "score": base_score,
                        }
                    )
                    best_obj_candidates = sorted(
                        best_obj_candidates, key=lambda x: x["score"]
                    )[:keep_top]
    if verbose:
        print()

    return best_obj_candidates


def pick_and_rescale_parts(
    requirement_dict,
    obj_candidates,
    dataset_dir,
    gt_file_name="object.json",
    verbose=False,
):
    """
    Pick and rescale parts from the object candidates

    - requirement_dict: the object dictionary of the requirement\n
    - obj_candidates: the list of best object candidates for selecting the base part\n
    - dataset_dir: the directory of the dataset to search in\n
    - gt_file_name (optional): the name of the ground truth file, which describes a candidate object\n
    - verbose (optional): whether to print the progress\n

    return:\n
    - parts_to_render: a list of selected parts for the requirement parts in the form:
        - [{name, dir, files, scale_factor=[sx, sy, sz], z_rotate_90}, ...]
        - z_rotate_90 is True if the part needs to be rotated by 90 degrees around the z-axis
    """

    # Extract requirement data
    if 'acd' in dataset_dir:
        all_categories = all_categories_acd
    else:
        all_categories = [
            "Table",
            "StorageFurniture",
            "WashingMachine",
            "Microwave",
            "Dishwasher",
            "Refrigerator",
            "Oven",
        ]
    requirement_part_names = []
    requirement_part_bbox_sizes = []
    for part in requirement_dict["diffuse_tree"]:
        if part['name'] == 'wheel':
            part['name'] = 'handle'
        requirement_part_names.append(part["name"])
        requirement_part_bbox_sizes.append(part["aabb"]["size"])

    # Collect the unique part names and store the indices of the parts with the same name
    unique_requirement_part_names = {}
    for i, part_name in enumerate(requirement_part_names):
        if part_name not in unique_requirement_part_names:
            unique_requirement_part_names[part_name] = [i]
        else:
            unique_requirement_part_names[part_name].append(i)

    parts_to_render = [{} for _ in range(len(requirement_part_names))]

    # Iterate through the object candidates selected for the base part first
    for candidate in obj_candidates:
        if all(len(part) > 0 for part in parts_to_render):
            break  # All parts are fulfilled

        if not os.path.exists(os.path.join(candidate["dir"], gt_file_name)):
            continue
        # Load the candidate object
        with open(os.path.join(candidate["dir"], gt_file_name), "r") as f:
            candidate_dict = json.load(f)

        # Pick parts from the candidate if the part name matches and the part requirement is not yet fulfilled
        for candidate_part_idx, part in enumerate(candidate_dict["diffuse_tree"]):
            part_needed = part["name"] in unique_requirement_part_names
            if not part_needed:
                continue

            part_not_fulfilled = any(
                len(parts_to_render[i]) == 0
                for i in unique_requirement_part_names[part["name"]]
            )
            if not part_not_fulfilled:
                continue

            # Get the indices of the requirement parts that are fulfilled by this candidate part and their bounding box sizes
            fulfill_part_idxs = unique_requirement_part_names[part["name"]]
            fulfill_part_bbox_sizes = [
                requirement_part_bbox_sizes[i] for i in fulfill_part_idxs
            ]
            # Generate all output part dictionaries at once
            part_dicts = _generate_output_part_dicts(
                candidate_dict,
                candidate_part_idx,
                candidate["dir"],
                fulfill_part_bbox_sizes,
                verbose=verbose,
            )
            # Update the output part dictionaries
            for part_dict_idx, part_idx in enumerate(fulfill_part_idxs):
                parts_to_render[part_idx].update(part_dicts[part_dict_idx])

    # If there are still parts that are not fulfilled
    if any(len(part) == 0 for part in parts_to_render):
        # Collect the remaining part names
        remaining_part_names = list(
            set(
                requirement_part_names[i]
                for i in range(len(requirement_part_names))
                if len(parts_to_render[i]) == 0
            )
        )
        if verbose:
            print(
                f" - Parts {remaining_part_names} are not fulfilled by the selected candidates, searching in the dataset..."
            )

        # If the category is specified, only search in that category; otherwise search in all categories
        requirement_category = requirement_dict["meta"]["obj_cat"]
        if requirement_category == "StroageFurniture":  # Handle a known typo in the metadata
            requirement_category = "StorageFurniture"
        category_specified = requirement_category != ""
        if category_specified:
            category_dirs = [os.path.join(dataset_dir, requirement_category)]
        else:
            category_dirs = [
                os.path.join(dataset_dir, category) for category in all_categories
            ]

        # Iterate through all objects
        retry = True  # Retry if the category is specified but some parts are still not fulfilled (see the end of the while loop)
        retry_time = 0
        while retry:
            print(retry_time)
            retry_time += 1
            for category_dir in category_dirs:
                obj_ids = os.listdir(category_dir)
                random.shuffle(obj_ids)  # Randomize the order of the objects
                for i, obj_id in enumerate(obj_ids):
                    print(
                        f"- Finding missing parts from {category_dir.split('/')[-1]}: {i+1}/{len(obj_ids)} \n"
                    )

                    # Load the candidate object
                    obj_dir = os.path.join(category_dir, obj_id)
                    if not os.path.exists(os.path.join(obj_dir, gt_file_name)):
                        continue
                    with open(os.path.join(obj_dir, gt_file_name), "r") as f:
                        candidate_dict = json.load(f)

                    # Pick parts from the candidate if the part name matches and the requirement is not yet fulfilled
                    for candidate_part_idx, part in enumerate(
                        candidate_dict["diffuse_tree"]
                    ):
                        part_needed = part["name"] in remaining_part_names

                        if part_needed:
                            # Get the indices of the requirement parts that are fulfilled by this candidate part and their bounding box sizes
                            fulfill_part_idxs = unique_requirement_part_names[
                                part["name"]
                            ]
                            fulfill_part_bbox_sizes = [
                                requirement_part_bbox_sizes[i]
                                for i in fulfill_part_idxs
                            ]

                            # Generate all output part dictionaries at once
                            part_dicts = _generate_output_part_dicts(
                                candidate_dict,
                                candidate_part_idx,
                                obj_dir,
                                fulfill_part_bbox_sizes,
                                verbose=verbose,
                            )

                            # Update the output part dictionaries
                            for part_dict_idx, part_idx in enumerate(
                                fulfill_part_idxs
                            ):
                                parts_to_render[part_idx].update(
                                    part_dicts[part_dict_idx]
                                )

                    if all(len(part) > 0 for part in parts_to_render):
                        if verbose:
                            print(" -> Found all missing parts")
                        break
                if all(len(part) > 0 for part in parts_to_render):
                    retry = False
                    break

            # If the category is specified but some parts are still not fulfilled, search in all categories
            if category_specified and any(len(part) == 0 for part in parts_to_render):
                if verbose:
                    print(
                        f" - Required category is {requirement_category}, but some parts are still not fulfilled, searching in all categories..."
                    )
                category_specified = False
                retry = True
                category_dirs = [
                    os.path.join(dataset_dir, category)
                    for category in all_categories
                    if category != requirement_category
                ]

    # Raise an error if there are still parts that are not fulfilled
    if any(len(part) == 0 for part in parts_to_render):
        raise RuntimeError(
            "Failed to fulfill all requirements, some parts may not exist in the dataset"
        )

    return parts_to_render
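A usage sketch (not part of the commit) of how these two functions chain together; the dataset path, hashbook path, and requirement file below are hypothetical placeholders:

import json
from retrieval.obj_retrieval import find_obj_candidates, pick_and_rescale_parts

# Hypothetical paths; adjust to wherever the dataset and hashbook actually live
dataset_dir = "data/partnet_mobility"
hashbook_path = "retrieval/retrieval_hash_no_handles.json"

with open("results/object.json", "r") as f:  # a generated requirement dict
    requirement_dict = json.load(f)

# Stage 1: rank whole objects by AS-cDist to pick base-part candidates
candidates = find_obj_candidates(
    requirement_dict, dataset_dir, hashbook_path, keep_top=5, verbose=True
)

# Stage 2: fill every requirement part with a retrieved mesh, rescaled per axis
parts_to_render = pick_and_rescale_parts(
    requirement_dict, candidates, dataset_dir, verbose=True
)
for part in parts_to_render:
    print(part["name"], part["scale_factor"], part["z_rotate_90"])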
retrieval/retrieval_hash_acd.json
ADDED
@@ -0,0 +1,329 @@
{
  "armoire": {
    "ff24f9310d003cd7b7894b2d6ec79a03": ["B07H8V49M2"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["B07GFDZVYY"],
    "5144181ac27497fdfa9bdb5b8b799630": ["B07GFSJ69T", "8415f258d5e129e9bb63cab54c7c207e04c0dfed", "06764d11dcec69878cc762c36482be5ef2865443", "0760a3dd43bd9dd9c0ec1ea4e033c6e121d92df5", "12001de4686bf2e4b9c721c93b35a424dd48249f"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["B07GFW9GFX", "16b93d86d1a466f5982e60c8d322ddd8f312056e"],
    "c24dd733315066f7c7da3d578f954d8c": ["B07H8PS4FZ", "cd17a7b8d78ee79bc52015841577a2652f9e0625", "3033be75f2ac15885e8f4813dc55e168af09e91f"],
    "02746714ea1f7e20dd4ac8f18d8cdac3": ["3bc24e0ea79ade13d4703ca23ece1e4019f9b70b"],
    "dd3473c941b94dd6654e9f89bb51cac9": ["128b5f2d072869004b7a218ab674f93f73a66670"],
    "cde6b48ed870286595c1455af7aff8bd": ["69340d8678701e72f39cc01890da1b8af3fd603d", "24f8284e4bdef4397e5b12dc4f2b74a137d63dbb"],
    "5a8eac0760a558d4174437be478ec0aa": ["a88c710a0b90706398a0fd7a9d73123338d04354"],
    "c502b67eb6d91d909ba398fa39bec60c": ["28098e6540f2f2fc07f7fc6a00edd6ce371a2618"],
    "00232256ef3ac441f59b36bfc7bd190c": ["11f8b552a802c6233a1332713568f05f901b725c"]
  },
  "cabinet": {
    "c502b67eb6d91d909ba398fa39bec60c": ["B07D42T6CX"],
    "d25563e624d9195ce94b1f768fdc503d": ["B07MGL8651"],
    "2fce5c033589d2dfe24fa67dc6885386": ["B07QD6V13M"]
  },
  "nightstand": {
    "5144181ac27497fdfa9bdb5b8b799630": ["B072ZK8897", "xxxx54c93c4fxabe2x463exafd8xcae01106dc7d", "037f34132f162235d80ce46f67c4fa2238d94da0", "155c182834f40ebb1d5666d3a72ee828e097097a"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["B072ZMHBKQ", "fed2a84682713eabe2b8d0e1e950d891c7442d5f"],
    "5a8eac0760a558d4174437be478ec0aa": ["1fa3dec03ab0afb0749373b2c8da8bf77c92e271", "265e819d1f027cde8dac05468a70455f62cbf069", "4bd1e4ec215403d3239f09517e44568f78da3b40"],
    "cde6b48ed870286595c1455af7aff8bd": ["b8bdd9cbc1ca695a206583afdc26f1a4c3987303", "0dfabed4818c34cbaa5ef41a3bcaf89177744b2c"]
  },
  "table": {
    "5144181ac27497fdfa9bdb5b8b799630": ["B07K6RNQDH", "B07JXXR83F"],
    "c24dd733315066f7c7da3d578f954d8c": ["B082YNHQLX"],
    "ddc2ef1be48dc58fe68226818824b648": ["B075Z93NHP"],
    "ec04032eee6bc67c5fdd5ec6705c3137": ["B075Z93NKX"],
    "cde6b48ed870286595c1455af7aff8bd": ["B075Z96KQL"]
  },
  "bookcase": {
    "5a8eac0760a558d4174437be478ec0aa": ["2c681d7e64e0410d76156f500cd2df798975a25d", "36d7feaf2471b67aa638609fe2c2278fba4a15a0"],
    "ff24f9310d003cd7b7894b2d6ec79a03": ["a194a188dfcd8e3796529c0263448ba047ce632f"],
    "c502b67eb6d91d909ba398fa39bec60c": ["d2f7d40d8e1a56ede42103027842e000d9cacd3e"],
    "ddc2ef1be48dc58fe68226818824b648": ["36d90c43b90a526247d61e709f349b3ed54081e2"],
    "1e8c6b47706f002757c3370366001f06": ["d5ba163ba97f94c7aa4a4a625eb0547b8894e1a5"],
    "61f645001e86ad8a32357cc828ae33cb": ["3b558ce715c88cca63b307f8e0e9b665ca57ef43"],
    "c24dd733315066f7c7da3d578f954d8c": ["44641367282a6a9616c91439944d4160a0e6f66a"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["17a8f7c040254439fe8771f66f1c526c71d24904"],
    "0e3c4946251b437ca90e5fe70efdea5b": ["9418cb5d4e1b7ff0ebd469e28dfdbaa99bc61f4d"]
  },
  "desk": {
    "ddc2ef1be48dc58fe68226818824b648": ["319c007ee7d07ca84c797a512ad4a98c9abc42da", "1619e2e6a18d5d374963a3a4280f7a6d76356079"],
    "cde6b48ed870286595c1455af7aff8bd": ["7489bcf3f226d315b4627bc442422f44f9eff092", "429de8a53ddf94a47ec553147dc8603b803be056"],
    "5a8eac0760a558d4174437be478ec0aa": ["3a12fb2ab85d734b4b11bc6f541be9961e7ecd23", "03f226dd0e012925a8674564c4de19cb786c9a88"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["b1c0fc607c063010c4b9300c955a0e3e5f7001fc"],
    "02746714ea1f7e20dd4ac8f18d8cdac3": ["01be253cbfd14b947e9dbe09d0b1959e97d72122"],
    "69144809aea48cb46eae9c3950f24a15": ["d7cec0e53dadbc4291064708c84e3614b79ac3c9"],
    "f38a9419ca785a395579ce42491c830e": ["561105c73bf76152a2b32e4f55f80db6a25ac0d4"],
    "d25563e624d9195ce94b1f768fdc503d": ["128ed2ced9a101aa0d131fd224012fd52198003f"],
    "17a09dc7b6207f11cd18889788802b88": ["6d264e3023b940b0b1d31b77e04d0c845853c1f0"]
  },
  "dishwasher": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["cb3dd8f2c8de396606e0794f6effc921aff7235d", "aad01c69d7a27a1740e422f5f64b781816bd86fa"]
  },
  "microwave": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["a630fbac79cdc164c9344a588f72d207b6d25e33", "bf840db863fc9c2646b2f8f372e4847b2fd42e34", "c3b2adbd3b89bdcd01f1e813bc4d2e06975ec727", "68407fbf5c7296351b2c26c2e59510effc87a637", "81da9279eae235c3faced51d516e970acdac5e84"],
    "c502b67eb6d91d909ba398fa39bec60c": ["9849af0395972ff84e40f5c1a51db8a25e3ef6f7"]
  },
  "oven": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["702d3ab650d34ee1dd236b0df5882573ec70c504"],
    "5144181ac27497fdfa9bdb5b8b799630": ["00b0d5e167ae6b42666de010025efad4506563f1", "c1bdae17057dfa88d5e3894433642030fd66c7d6", "197a447eb68ba32ab44c50948b3af1e63048e174", "239c5c38a53badc24ca6950ee78d8f6c115c3074", "4a326efb8ab35d8ce823575d4c3b7033e8c3e5e8"],
    "cde6b48ed870286595c1455af7aff8bd": ["ef32e2cd6dd99d883f627e238f55ad0766240d44", "f8bff67d469223c8e8bf44553834bd5482a96ecc", "41efc8ed9c8d433e9ff877f3b9ea1c0eda45479c"]
  },
  "refrigerator": {
    "c502b67eb6d91d909ba398fa39bec60c": ["1c7874f93ca418d7edebd150a7422095fd76897a"],
    "5144181ac27497fdfa9bdb5b8b799630": ["9449ef7831c43bc0db23aac79ea442aa71d0db11", "0bb1cdb98fbdfda5b41abaa39aa7f82321e58b72", "1c33bd447d70d4d22116c434d912f1fae78e02b7"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["cae4c60830bba615ff533dc23ffee6e6e5c7d14e"]
  },
  "sink_cabinet": {
    "cde6b48ed870286595c1455af7aff8bd": ["56dc6fc7669736b5fd6a85d1b14a01d029beff59"],
    "c502b67eb6d91d909ba398fa39bec60c": ["f637f110fbed653b7983d9fc6a6d53795b384461"],
    "5144181ac27497fdfa9bdb5b8b799630": ["5f074f91cc2ce2a4d5a62e6cce77c435e5dbf457", "6219ef05f4a7b56419749e45a45143df8af44495", "027f20642dd34e7914fc4fc4efa70fbb54bcecbb", "112dc87e26450400941c6eaff60866bc19badc64"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["76102a94ca13d9cdaf7d5a77262cddd4df9a806c"],
    "c24dd733315066f7c7da3d578f954d8c": ["829bf87b541238bb7579a05d18d4f7f1d1f98af1"],
    "ada4be0df4d7d600d6729eb4a621f99a": ["f1de6498f43789e1c27150b3ae1f9b5bfc051775"]
  },
  "tv_stand": {
    "0e3c4946251b437ca90e5fe70efdea5b": ["a77e6006efcacc637e1c2a49e72232ee0f435e35", "212778c0c265e0db358ad7c8c1fa9a4bcfe41bd7"],
    "cde6b48ed870286595c1455af7aff8bd": ["ba231dd136e3bc77fb04ade17235b923aa7b2f07"],
    "5144181ac27497fdfa9bdb5b8b799630": ["056be15536045e9a6a94b9b93cff62f72d43c326", "1aebe7d2500bfbcb0c6ec787a98a2b3701099ed7"],
    "94d192237d5fe1b065910cb51d8ee711": ["63b7d75b724e079aab99030084f0eae1b43b7498"]
  },
  "washer": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["6cd2dc2611c27f758c972b4874efad8c8cbd5d29", "912447108f21083d877aab6653742fceccc6ce7d", "00b285ed70673826ccb6941929a64abfaf5f9239", "eccf8ba37c804b9067e675f09eec2c13e951f61c", "031024e00569909d466d87df9fead90355ba29e5", "faefe63ba896c06920aa6d23b05ab83f3d6d37ea", "2c2b2914fc526f6e8bbe65511ecf58bba0027ec0", "4bc6f883a374400355bcf95e611a4e8f8b950ed5"]
  },
  "chestofdrawers": {
    "5144181ac27497fdfa9bdb5b8b799630": ["762665bc7a958151874c15edfc2711b161376678", "d96b2b9537c7c721d5a79b375aefbfacecd04f65"],
    "47259e6c2fba7c74a9b725012e01ebba": ["d621c3a39d9291f0943531204d68e705633986c9"],
    "c24dd733315066f7c7da3d578f954d8c": ["293c42d04758bffcc63070066170cbb1f09918cc", "1c0bbc026e76c09885dc5c6f156a6c3dec605d10"],
    "25bfc0f15836b69b830cf66b5217bed6": ["807955fd4dcb59b67789b24f2e7bc167027c870a"],
    "322c8717a498a3e832420518775f8ffc": ["2cdb28938dff1f9ab13aee7630cea51f44f60952"],
    "cde6b48ed870286595c1455af7aff8bd": ["831e2cff446337af8791038161c7d32d96726b20", "93767fc355a1afbfff79d54a25204069e0543d2b", "11808e4bfc4534caf787fa15ba07bc2cbee95fdd"],
    "3ab3e03b34bc406737d81bd5db0ee212": ["4c067979055ae739365d340a1699036be5c136c7"],
    "c502b67eb6d91d909ba398fa39bec60c": ["848c2cc9428329445b3de7dc982469c583b16c1d"],
    "7936dbd6c88dda2542d5509f6078a0a9": ["09a0a4a031e33e738214b812f71dd838232df54c"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["609530b858465d9795ac43ba435fdd5f12d95956", "652403bb9a0b199ebdebe538a44bc897eab624f6"],
    "ddc2ef1be48dc58fe68226818824b648": ["bc59a49ebf99164d5ed88bd6eaff12ec4ed86d0a"]
  },
  "hangingcabinet": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["B07B8NBQQP", "dcef0d475b7fdea2530898215feafddac1fe9bcc", "bb966a4f853df29b8b37b89e157aed4eb3936aec"],
    "c502b67eb6d91d909ba398fa39bec60c": ["be1b0ef31886b4b2b42cf0bf1b7df548917c9943"],
    "5a8eac0760a558d4174437be478ec0aa": ["3432a71596b6cd7e944b6f19cf6d713fe17fc8bd", "4468c13bc98184bcc403027164ede52b178e5d20", "644c2d1505103189b6ae49f5a58b97ed41202149"]
  },
  "kitchencabinet": {
    "c502b67eb6d91d909ba398fa39bec60c": ["3a1c604565d5fa1f411f3fb437a816094ae69122", "88c87a19b5e883787b5707d90545e25360594822"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["bc3c5c45d5a9126bff85470c7c48d4e2b7ebfd0d"],
    "86f7cf811774c9dc1f8ac7ebefafd51c": ["0c42ee5b635c8acdb6bf235ff9420d742a16fd30"],
    "c24dd733315066f7c7da3d578f954d8c": ["10444a1165c1745f960eb6183d24dd05b60e781f"],
    "cde6b48ed870286595c1455af7aff8bd": ["1a570da5ae0c4d1d51c94be81d7248f155a0dcef"]
  }
}
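The file above (and its PartNet-Mobility counterpart below) maps each category to a dict from Weisfeiler-Lehman graph hash to the object ids that share that kinematic topology; find_obj_candidates uses it to prefilter candidates before scoring. A minimal lookup sketch (not part of the commit), assuming the repo root is on sys.path and using a hypothetical toy requirement:

import json
from retrieval.obj_retrieval import get_hash

with open("retrieval/retrieval_hash_acd.json", "r") as f:
    hashbook = json.load(f)

# Toy requirement: a base with two doors (handle parts are ignored by get_hash)
requirement = {"diffuse_tree": [
    {"id": 0, "name": "base", "parent": -1},
    {"id": 1, "name": "door", "parent": 0},
    {"id": 2, "name": "door", "parent": 0},
]}
code = get_hash(requirement)
same_topology_ids = hashbook.get("armoire", {}).get(code, [])
print(f"{len(same_topology_ids)} armoire(s) share this part graph")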
retrieval/retrieval_hash_no_handles.json
ADDED
@@ -0,0 +1,722 @@
{
  "Table": {
    "c502b67eb6d91d909ba398fa39bec60c": ["25160", "32761", "28594", "25959", "34178", "23724", "21473", "26545", "26692", "27478", "19384", "26800", "25756", "23807", "26886", "31249", "21718", "23372", "32601"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["32566", "26525", "29557", "24644", "20453", "26806", "26670", "20279", "20411", "20555", "27189", "30869", "19855", "24931", "22339", "28668", "26073", "26652", "21467", "20985", "22241", "22508", "29921", "27044", "23511"],
    "5144181ac27497fdfa9bdb5b8b799630": ["27619", "26875", "27267", "19825", "33914", "23472", "22692", "32174", "19179", "29133", "31601", "32086", "30238", "25308", "32746", "28164", "32932", "34617", "23782"],
    "dd3473c941b94dd6654e9f89bb51cac9": ["30666"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["33930", "32052", "32213"],
    "8999452509fdfa98335c5ba44ed05498": ["26503", "20043", "29525", "26657", "30341"],
    "cde6b48ed870286595c1455af7aff8bd": ["19836", "22301", "30663", "25144", "25493", "32354", "22433", "32324", "30857", "26387", "32259", "33457", "20745", "33116", "25913"],
    "0e3c4946251b437ca90e5fe70efdea5b": ["19898", "26608", "22367", "24152"],
    "c24dd733315066f7c7da3d578f954d8c": ["34610", "33810"]
  },
  "Dishwasher": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["12480", "12579", "12542", "12085", "12559", "12092", "12580", "12565", "12612", "11622", "12530", "12071", "12654", "12259", "12558", "11700", "12553", "12414", "12543", "12561", "12590", "12540", "12621", "12531", "12614", "12560", "12428", "12606", "12552", "12592", "12617", "12583", "12605", "12596", "11661"],
    "cde6b48ed870286595c1455af7aff8bd": ["12349", "11826"],
    "c502b67eb6d91d909ba398fa39bec60c": ["12065", "12484"],
    "5144181ac27497fdfa9bdb5b8b799630": ["12597"]
  },
  "WashingMachine": {
    "5144181ac27497fdfa9bdb5b8b799630": ["100283", "103369", "103361", "100282", "103425", "103480", "103778"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["103490", "103775"],
    "cde6b48ed870286595c1455af7aff8bd": ["103452", "103518", "103521"],
    "c502b67eb6d91d909ba398fa39bec60c": ["103508", "103776"],
    "c24dd733315066f7c7da3d578f954d8c": ["103781", "103528"]
  },
  "Microwave": {
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["7304", "7236", "7221", "7263", "7292", "7306", "7320", "7310"],
    "cde6b48ed870286595c1455af7aff8bd": ["7273", "7167", "7296", "7349"],
    "5144181ac27497fdfa9bdb5b8b799630": ["7366", "7119"],
    "c502b67eb6d91d909ba398fa39bec60c": ["7128"]
  },
  "Oven": {
    "ddc2ef1be48dc58fe68226818824b648": ["101946", "101930"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["7130", "101971", "7290", "7138"],
    "c502b67eb6d91d909ba398fa39bec60c": ["101773", "102055"],
    "5144181ac27497fdfa9bdb5b8b799630": ["7220", "7187"],
    "0e3c4946251b437ca90e5fe70efdea5b": ["101921", "101943", "101917", "101947"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["102018", "102044", "102060", "101931"],
    "538827adb8c90adf1322121db4e66fef": ["101924"],
    "c24dd733315066f7c7da3d578f954d8c": ["101940", "7347", "102019", "7332", "7120", "7201", "101909", "7179"],
    "dd3473c941b94dd6654e9f89bb51cac9": ["101808", "102001"],
    "ebb9c0168d323bca8e92227bdaa7a788": ["101908"]
  },
  "Refrigerator": {
    "5144181ac27497fdfa9bdb5b8b799630": ["12066", "10036", "11712", "11178", "10867", "10900", "12059", "11299", "10620", "10685", "11846", "12043", "12248", "10347", "10489", "10751", "12036", "10143", "10655", "10612", "10068", "10638", "12050", "11231", "10586", "11550", "11304", "10627"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["10849", "12054", "12038", "10373", "10944", "12252", "10905", "11260", "12055", "10144", "10797", "11211", "12250", "12042", "12249"],
    "cde6b48ed870286595c1455af7aff8bd": ["11709"]
  },
  "Safe": {
    "5144181ac27497fdfa9bdb5b8b799630": ["102423", "101583", "102318", "101612", "102301", "101591", "101594", "102380", "101613", "101599", "101619", "101605", "102389", "101363", "102387"],
    "cde6b48ed870286595c1455af7aff8bd": ["102311", "101579", "101611", "102384", "102309", "101623", "101593", "101603", "102316", "101584", "102381"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["101604", "101564", "102418"]
  },
  "StorageFurniture": {
    "cde6b48ed870286595c1455af7aff8bd": ["47711", "45841", "45427", "46653", "48740", "46544", "44962", "48517", "47296", "46230", "45213", "44853", "40453", "45632", "45949", "48855", "44781", "48797", "45132", "46981", "46440", "48491", "45661", "45503", "47853", "47252", "48010", "46879", "48876", "46537", "45622", "46641", "46334", "48253", "48063", "47233", "45940", "45290", "46874", "47438", "45135", "46130", "46123", "47088", "45756", "47944", "47178"],
    "3ba4ffe16dfe637510ed1c3676ec6cb0": ["47168", "45176", "45908", "45249", "46127", "45633", "45936", "47388", "45091", "45385", "48721", "45667", "41004", "41529", "45671", "45130", "45504", "45638", "45087", "45164", "48036", "47651", "45950", "45623", "46132", "45693", "45783", "45007", "47180", "45690", "48271", "45267", "46092", "45134", "47021", "46744", "47187", "48686", "46556", "46408", "46616", "46439", "45448", "47963", "45212", "45413", "45645", "38516", "45822", "47315", "45916", "45173", "47686", "45699", "45244", "46889", "46906", "46430", "45717", "45937", "45203", "45372", "47133", "46966", "45419", "45779", "45297", "46922", "46944", "46417", "45850", "45620", "46044", "45248", "48452", "47817", "45964", "46029", "46179", "46787", "41452", "45166", "48167", "48243", "45177", "45961", "47632", "46033", "47391", "47316", "41086", "45384", "45526", "45915", "47281", "45415", "45606", "47742", "46401", "45323", "45178", "47514", "46117", "46197", "48467", "46452", "48519", "45855", "47729", "46057", "45600", "35059", "45691", "45524", "48490", "47182", "45910", "48413", "45247", "48023", "48746", "47419", "45621", "45403", "46427", "45443", "45922", "47601", "45853", "45516", "46107"],
    "8fae315ff2d1ea185b4ed6d0d092ae0e": ["45271", "46199", "47669", "49188", "46145", "47926", "45948", "47290", "40417", "48623", "47235", "49062", "47648", "45612"],
    "1e8c6b47706f002757c3370366001f06": ["47466"],
    "c502b67eb6d91d909ba398fa39bec60c": ["45092", "46443", "45696", "45189", "46108", "48356", "45238", "41003", "45642", "49025", "46955", "45749", "45801", "45374", "46084", "47227", "47578", "45243", "45636", "45387", "41083", "48258", "41085", "46060", "44817", "46598", "46002", "48169", "46762", "45710", "48018", "45262", "47207", "46856", "46466", "47595", "47701", "45332", "49140", "45219", "48263", "45677", "47976",
+
"46549",
|
| 606 |
+
"45194",
|
| 607 |
+
"45759",
|
| 608 |
+
"46893",
|
| 609 |
+
"46120",
|
| 610 |
+
"47089"
|
| 611 |
+
],
|
| 612 |
+
"5144181ac27497fdfa9bdb5b8b799630": [
|
| 613 |
+
"45575",
|
| 614 |
+
"45594",
|
| 615 |
+
"45235",
|
| 616 |
+
"46655",
|
| 617 |
+
"49182",
|
| 618 |
+
"45687",
|
| 619 |
+
"46403",
|
| 620 |
+
"46768",
|
| 621 |
+
"46896",
|
| 622 |
+
"45001",
|
| 623 |
+
"47613",
|
| 624 |
+
"47254",
|
| 625 |
+
"46180",
|
| 626 |
+
"47024",
|
| 627 |
+
"46825",
|
| 628 |
+
"48379",
|
| 629 |
+
"49132",
|
| 630 |
+
"48878",
|
| 631 |
+
"47577",
|
| 632 |
+
"47565",
|
| 633 |
+
"45573",
|
| 634 |
+
"44826",
|
| 635 |
+
"46847",
|
| 636 |
+
"46732",
|
| 637 |
+
"45168",
|
| 638 |
+
"46277",
|
| 639 |
+
"47238",
|
| 640 |
+
"45746",
|
| 641 |
+
"47808",
|
| 642 |
+
"45662",
|
| 643 |
+
"48381",
|
| 644 |
+
"45963",
|
| 645 |
+
"45354",
|
| 646 |
+
"45676",
|
| 647 |
+
"47278",
|
| 648 |
+
"47529",
|
| 649 |
+
"46437",
|
| 650 |
+
"45378",
|
| 651 |
+
"46563",
|
| 652 |
+
"47570",
|
| 653 |
+
"45444",
|
| 654 |
+
"48700",
|
| 655 |
+
"45780",
|
| 656 |
+
"47099",
|
| 657 |
+
"46490",
|
| 658 |
+
"45523",
|
| 659 |
+
"47747",
|
| 660 |
+
"46045",
|
| 661 |
+
"45305",
|
| 662 |
+
"40147",
|
| 663 |
+
"49133",
|
| 664 |
+
"46700",
|
| 665 |
+
"46236",
|
| 666 |
+
"45505",
|
| 667 |
+
"48859",
|
| 668 |
+
"46166",
|
| 669 |
+
"46456",
|
| 670 |
+
"45162",
|
| 671 |
+
"45776",
|
| 672 |
+
"45420",
|
| 673 |
+
"46481",
|
| 674 |
+
"45767",
|
| 675 |
+
"45423",
|
| 676 |
+
"45790",
|
| 677 |
+
"49038",
|
| 678 |
+
"45670",
|
| 679 |
+
"47954",
|
| 680 |
+
"48479",
|
| 681 |
+
"46019",
|
| 682 |
+
"46801",
|
| 683 |
+
"45984",
|
| 684 |
+
"45159",
|
| 685 |
+
"49042",
|
| 686 |
+
"45784",
|
| 687 |
+
"48177",
|
| 688 |
+
"46859",
|
| 689 |
+
"46741",
|
| 690 |
+
"46134",
|
| 691 |
+
"45694",
|
| 692 |
+
"46480",
|
| 693 |
+
"45689",
|
| 694 |
+
"46037",
|
| 695 |
+
"47185",
|
| 696 |
+
"45397",
|
| 697 |
+
"48013",
|
| 698 |
+
"46699",
|
| 699 |
+
"48513",
|
| 700 |
+
"45146",
|
| 701 |
+
"45463",
|
| 702 |
+
"41510",
|
| 703 |
+
"45747"
|
| 704 |
+
],
|
| 705 |
+
"8999452509fdfa98335c5ba44ed05498": [
|
| 706 |
+
"46172",
|
| 707 |
+
"46014",
|
| 708 |
+
"45261",
|
| 709 |
+
"46109",
|
| 710 |
+
"46380"
|
| 711 |
+
],
|
| 712 |
+
"0e3c4946251b437ca90e5fe70efdea5b": [
|
| 713 |
+
"48497"
|
| 714 |
+
],
|
| 715 |
+
"c24dd733315066f7c7da3d578f954d8c": [
|
| 716 |
+
"48051"
|
| 717 |
+
],
|
| 718 |
+
"87036528afd9c9dd03a6ab72efec0136": [
|
| 719 |
+
"45725"
|
| 720 |
+
]
|
| 721 |
+
}
|
| 722 |
+
}
|
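This hash book maps each category to graph-hash keys, each listing the PartNet-Mobility-style model ids that share that part-graph structure (the keys appear to be Weisfeiler-Lehman graph hashes; cf. scripts/graph_pred/eval.py below). A minimal lookup sketch, with the path and key values taken from this file:

import json

with open("retrieval/retrieval_hash_no_handles.json") as f:
    hashbook = json.load(f)

# all "Safe" models whose part graph hashes to this value
candidates = hashbook["Safe"]["5144181ac27497fdfa9bdb5b8b799630"]
print(len(candidates), candidates[:3])  # 15 ['102423', '101583', '102318']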
scripts/graph_pred/api.py
ADDED
@@ -0,0 +1,210 @@
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import re
import json
import base64
import argparse
from PIL import Image
from io import BytesIO
from openai import AzureOpenAI
from scripts.graph_pred.prompt_workflow_new import messages
import json_repair

# Initialize the OpenAI client
endpoint = os.environ.get("ENDPOINT")
api_key = os.environ.get("API_KEY")
api_version = os.environ.get("API_VERSION")
model_name = os.environ.get("MODEL_NAME")
client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
)


def encode_image(image_path: str, center_crop=False):
    """Resize and encode the image as base64"""
    # load the image
    image = Image.open(image_path)

    # resize the image to 224x224
    if center_crop:  # (resize to 256x256 and then center crop to 224x224)
        image = image.resize((256, 256))
        width, height = image.size
        left = (width - 224) / 2
        top = (height - 224) / 2
        right = (width + 224) / 2
        bottom = (height + 224) / 2
        image = image.crop((left, top, right, bottom))
    else:
        image = image.resize((224, 224))

    # convert the image to bytes
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
    return encoded_image


def display_image(image_data):
    """Display the image from the base64 encoded image data"""
    img = Image.open(BytesIO(base64.b64decode(image_data)))
    img.show()
    img.close()


def convert_format(src):
    '''Convert the JSON format from the response to a tree format'''
    def _sort_nodes(tree):
        num_nodes = len(tree)
        sorted_tree = [dict() for _ in range(num_nodes)]
        for node in tree:
            sorted_tree[node["id"]] = node
        return sorted_tree

    def _traverse(node, parent_id, current_id):
        for key, value in node.items():
            node_id = current_id[0]
            current_id[0] += 1

            # Create the node
            tree_node = {
                "id": node_id,
                "parent": parent_id,
                "name": key,
                "children": [],
            }

            # Traverse children if they exist
            if isinstance(value, list):
                for child in value:
                    child_id = _traverse(child, node_id, current_id)
                    tree_node["children"].append(child_id)

            # Add this node to the tree
            tree.append(tree_node)
        return node_id

    tree = []
    current_id = [0]
    _traverse(src, -1, current_id)
    diffuse_tree = _sort_nodes(tree)
    return diffuse_tree


def predict_graph_twomode(image_path, first_img_data=None, second_img_data=None, debug=False, center_crop=False):
    '''Predict the part connectivity graph from the image'''
    # Encode the images
    if first_img_data is None or second_img_data is None:
        first_img_data = encode_image(image_path, center_crop)
        second_img_data = encode_image(image_path.replace('close', 'open'), center_crop)
    # if debug:
    #     display_image(image_data)  # for double checking the image
    #     breakpoint()
    new_message = messages.copy()
    new_message.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{first_img_data}"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{second_img_data}"},
                }
            ],
        },
    )
    # Get the completion from the model
    completion = client.chat.completions.create(
        model=model_name,
        messages=new_message,
        response_format={"type": "text"},
        temperature=1,
        max_tokens=4096,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    print('processing the response...')

    # Extract the response
    content = completion.choices[0].message.content

    src = json.loads(re.search(r"```json\n(.*?)\n```", content, re.DOTALL).group(1))
    print(src)
    # Convert the JSON format to tree format
    diffuse_tree = convert_format(src)

    return {"diffuse_tree": diffuse_tree, "original_response": content}


def save_response(save_path, response):
    '''Save the response to a json file'''
    with open(save_path, "w") as file:
        json.dump(response, file, indent=4)


def gpt_infer_image_category(image1, image2):
    system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."

    text_prompt = (
        "Given two images of an object, determine its category. "
        "The category must be one of the following: Table, Dishwasher, StorageFurniture, "
        "Refrigerator, WashingMachine, Microwave, Oven. "
        "Output only the category name and nothing else. Do not include any other text."
    )

    content_user = [
        {
            "type": "text",
            "text": text_prompt,
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image1}"},
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image2}"},
        },
    ]
    payload = {
        "messages": [
            {"role": "system", "content": system_role},
            {"role": "user", "content": content_user},
        ],
        "temperature": 0.1,
        "max_tokens": 500,
        "top_p": 0.1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
        "model": model_name,
    }
    completion = client.chat.completions.create(**payload)
    response = completion.choices[0].message.content
    json_repair.loads(response)  # parse check (result unused); the raw category string is returned

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Predict the part connectivity graph from an image")
    parser.add_argument("--img_path", type=str, required=True, help="path to the image")
    parser.add_argument("--save_path", type=str, required=True, help="path to save the response")
    parser.add_argument("--center_crop", action="store_true", help="whether to center crop the image to 224x224, otherwise resize to 224x224")
    args = parser.parse_args()

    try:
        # only the two-image (closed + open state) predictor is defined in this module
        response = predict_graph_twomode(args.img_path, center_crop=args.center_crop)
        save_response(args.save_path, response)
    except Exception as e:
        with open('openai_err.log', 'a') as f:
            f.write('---------------------------\n')
            f.write(f'{args.img_path}\n')
            f.write(f'{e}\n')
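For reference, a small sketch of what convert_format above produces: the nested graph returned by the model is flattened into the id-sorted diffuse_tree list, with ids assigned in depth-first order. (Importing this module requires the Azure OpenAI environment variables, since the client is constructed at import time.)

from scripts.graph_pred.api import convert_format

src = {"base": [{"door": [{"handle": []}]}, {"drawer": [{"handle": []}]}]}
for node in convert_format(src):
    print(node)
# {'id': 0, 'parent': -1, 'name': 'base',   'children': [1, 3]}
# {'id': 1, 'parent': 0,  'name': 'door',   'children': [2]}
# {'id': 2, 'parent': 1,  'name': 'handle', 'children': []}
# {'id': 3, 'parent': 0,  'name': 'drawer', 'children': [4]}
# {'id': 4, 'parent': 3,  'name': 'handle', 'children': []}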
scripts/graph_pred/eval.py
ADDED
@@ -0,0 +1,62 @@
import os
import json
import argparse
import networkx as nx
from tqdm import tqdm

def get_hash(file, key='diffuse_tree'):
    tree = file[key]
    G = nx.DiGraph()
    for node in tree:
        G.add_node(node['id'])
        if node['parent'] != -1:
            G.add_edge(node['id'], node['parent'])
    hashcode = nx.weisfeiler_lehman_graph_hash(G)
    return hashcode

if __name__ == "__main__":
    '''Script to evaluate the accuracy of the generated graphs'''

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_dir', type=str, required=True, help='path to the experiment directory')
    parser.add_argument('--gt_data_root', type=str, required=True, help='root directory of the ground-truth data')
    parser.add_argument('--gt_json_name', type=str, default='object.json', help='name of the ground-truth json file')
    args = parser.parse_args()

    assert os.path.exists(args.exp_dir), "The experiment directory does not exist"
    assert os.path.exists(args.gt_data_root), "The ground-truth data root does not exist"

    exp_dir = args.exp_dir
    gt_data_dir = args.gt_data_root

    acc = 0
    files = sorted(os.listdir(exp_dir))
    total = len(files)
    wrong_files = []
    for file in tqdm(files):
        tokens = file.split('@')
        gt_dir = f'{gt_data_dir}'
        for token in tokens[:-1]:
            gt_dir = os.path.join(gt_dir, token)
        with open(os.path.join(gt_dir, args.gt_json_name)) as f:
            gt = json.load(f)
        # load json files
        with open(os.path.join(exp_dir, file)) as f:
            pred = json.load(f)
        # get hash for the graph
        pred_hash = get_hash(pred)
        gt_hash = get_hash(gt)
        # compare hash
        if pred_hash == gt_hash:
            acc += 1
        else:
            wrong_files.append(file)

    with open(os.path.join(os.path.dirname(exp_dir), f'acc_{os.path.basename(exp_dir)}.json'), 'w') as f:
        json.dump({'acc': acc/total, 'wrong_files': wrong_files}, f, indent=4)
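A quick sanity check of the hash-based comparison above: the Weisfeiler-Lehman hash depends only on the graph's topology, so two part trees with the same shape but different node ids hash identically, while a differently shaped tree does not (node ids below are illustrative):

import networkx as nx

def tree_hash(edges):
    G = nx.DiGraph()
    G.add_edges_from(edges)  # edges point child -> parent, as in get_hash above
    return nx.weisfeiler_lehman_graph_hash(G)

chain_a = tree_hash([(1, 0), (2, 1)])  # base <- door <- handle
chain_b = tree_hash([(9, 4), (7, 9)])  # the same chain with different ids
star    = tree_hash([(1, 0), (2, 0)])  # base with two direct children
assert chain_a == chain_b and chain_a != star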
scripts/graph_pred/prompt_workflow_new.py
ADDED
@@ -0,0 +1,363 @@
import os
from PIL import Image
from io import BytesIO
import base64

def encode_image(image_path: str, center_crop=False):
    """Resize and encode the image as base64"""
    # load the image
    image = Image.open(image_path)

    # resize the image
    if center_crop:  # (resize to 256x256 and then center crop to 224x224)
        image = image.resize((256, 256))
        width, height = image.size
        left = (width - 224) / 2
        top = (height - 224) / 2
        right = (width + 224) / 2
        bottom = (height + 224) / 2
        image = image.crop((left, top, right, bottom))
    else:
        image = image.resize((512, 512))

    # convert the image to bytes
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
    return encoded_image

system_prompt = """
You are an expert in the recognition, structural parsing, and physical‑feasibility validation of articulated objects from image inputs.

You will be provided with two rendered images of the same object:
1. A closed‑state image (all movable parts in their fully closed positions)
2. An open‑state image (all movable parts in their fully opened positions)

Your task is to analyze the object's articulated structure and generate a connectivity graph describing the part relationships.

You must follow this workflow:

1. Part Detection:
   - Detect candidate parts with their coarse position in the **closed-state image**, with optional assistance from the **open-state image** to resolve ambiguous or occluded parts.
   - Allowed part types: ['base', 'door', 'drawer', 'handle', 'knob', 'tray']
   - Ignore small decorative things directly attached to the base.
   - There must be exactly one "base"; "tray" is only allowed if the object is a microwave (but a "tray" does not have to be present in a microwave).

2. Step-by-Step Reasoning:
   1. Part Listing: List all detected parts and their counts (do not infer attachment yet)
   2. Validation: Enforce structural rules:
      - Exactly one base
      - Each door or drawer may have at most two handles or knobs
      - Every handle/knob must attach to a door or drawer
      - Trays only appear in microwaves
   3. Attachment Inference: For each non-base part, infer its parent (e.g., "drawer_1 (attached to base)").
      - Use the open-state image to assist in verifying attachments if the closed-state image is ambiguous.
   4. Connectivity Graph Construction: Output a JSON tree of part relationships. The base is the root.

Example output:
```json
{
  "base": [
    { "door": [ { "handle": [] } ] },
    { "drawer": [ { "handle": [] } ] }
  ]
}
```

Final Output: You **MUST** output a single JSON tree representing the part connectivity of the object. Use the open-state image to improve the accuracy and completeness of your structural interpretation, but base your parsing primarily on the closed-state image.
Note: You **MUST** give the final output in JSON format.
"""

answer_sys = """
Great — please upload the two images of the same object:

One in the closed state (all movable parts fully closed)

One in the open state (all movable parts fully open)

Once I receive both, I’ll proceed with:

Part detection on the closed image

Structured reasoning and validation

Attachment inference

Final part connectivity graph in JSON format

Let me know when you're ready to continue!
"""

answer_1 = """
Great! Let’s go through the articulated object parsing workflow based on the provided **closed-state** and **open-state** images.

---

### **1. Part Detection**

From the **closed-state image** (first image), with verification from the **open-state image** (second image):

- **Base**: 1 base structure (entire wooden cabinet)
- **Doors**: 4 doors
  - 2 on the upper section (left & right)
  - 2 on the lower section (left & right)
- **Drawers**: 4 drawers
  - 3 in a row at the middle
  - 1 larger drawer centered below the middle row
- **Handles**: 8 handles
  - 1 per drawer
  - 1 per door

---

### **2. Step-by-Step Reasoning**

#### 2.1 **Part Listing**

- `base`: 1
- `door`: 4
- `drawer`: 4
- `handle`: 8

#### 2.2 **Validation**

- ✅ Exactly one base
- ✅ Each door has 1 handle → 4 handles for 4 doors
- ✅ Each drawer has 1 handle → 4 handles for 4 drawers
- ✅ No trays → ✓ (not a microwave)
- ✅ Every handle is attached to a door or drawer (verified using open-state image)

All rules are satisfied.

#### 2.3 **Attachment Inference**

From closed/open-state comparison:

- Each **door** is directly attached to the **base**
- Each **drawer** is mounted to slide out of the **base**
- Each **handle** is on its respective **door** or **drawer**

---

### **3. Connectivity Graph Construction**

Final hierarchical structure:

```json
{
  "base": [
    { "door_1": [ { "handle_1": [] } ] },
    { "door_2": [ { "handle_2": [] } ] },
    { "door_3": [ { "handle_3": [] } ] },
    { "door_4": [ { "handle_4": [] } ] },
    { "drawer_1": [ { "handle_5": [] } ] },
    { "drawer_2": [ { "handle_6": [] } ] },
    { "drawer_3": [ { "handle_7": [] } ] },
    { "drawer_4": [ { "handle_8": [] } ] }
  ]
}
```
"""

answer_2 = """
Thanks for the new image pair! Let's analyze the articulated structure of the **washing machine** based on the provided closed-state and open-state images.

---

### **1. Part Detection**

From the **closed-state image**, validated using the **open-state image**:

- **Base**: 1 main body (the washing machine structure)
- **Door**: 1 circular front-loading door
- **Handle**: Notice that there is no handle attached to the door
- **Knob**: 1 control knob on the top of the machine

---

### **2. Step-by-Step Reasoning**

#### 2.1 **Part Listing**

- `base`: 1
- `door`: 1
- `knob`: 1

#### 2.2 **Validation**

- ✅ Exactly one base
- ✅ No drawers, trays → Valid
- ✅ Trays not present (correct: not a microwave)

→ **All structural constraints satisfied**

#### 2.3 **Attachment Inference**

- The **door** is front-mounted and clearly opens via a **hinge** on the left side (from viewer's perspective)
- The **knob** is directly mounted on the **base** (top section)
---

### **3. Connectivity Graph Construction**

```json
{
  "base": [
    { "door": [] },
    { "knob": [] }
  ]
}
```
"""

answer_3 = """
Thanks! Let’s analyze the articulated structure of this oven using your provided closed and open state images.

---

### ✅ Articulated Object Parsing (Oven)

#### **1. Part Detection**

From the **closed-state** image (with **open-state** verification):

- **base**: 1 (the main oven body)
- **door**: 1 (the front oven door)
- **handle**: 1 (long horizontal handle on the door)

> 🔍 *The 6 small control knobs on top are **ignored** as instructed.*

---

#### **2. Step-by-Step Reasoning**

**2.1 Part Listing:**

- `base`: 1
- `door`: 1
- `handle`: 1

**2.2 Validation:**

- ✅ Exactly one base
- ✅ Door has 1 handle
- ✅ Handle is attached to the door
- ✅ No trays (✓ valid for an oven)
- ✅ No knobs or extra handles beyond what's allowed

✔ All constraints are satisfied.

**2.3 Attachment Inference:**

- The **door** is bottom-hinged and connected to the **base**
- The **handle** is directly mounted on the door

---

#### **3. Final Connectivity Graph**

```json
{
  "base": [
    { "door": [ { "handle": [] } ] }
  ]
}
```
"""

answer_4 = """
Thanks! Let's now analyze the **wardrobe** based on the closed and open-state images you've provided.

---

### ✅ Articulated Object Parsing – Wardrobe

---

### **1. Part Detection**

- **Base**: 1 (entire wardrobe structure)
- **Doors**: 4 vertical panels (two left, two right)
- **Drawers**: 4 horizontal drawers at the bottom
- **Handles**: 8 in total
  - 4 for doors (1 each)
  - 4 for drawers (1 each)

---

### **2. Step-by-Step Reasoning**

#### **2.1 Part Listing**

- `base`: 1
- `door`: 4
- `drawer`: 4
- `handle`: 8

#### **2.2 Validation**

- ✅ Exactly one base
- ✅ Each door has 1 handle → valid
- ✅ Each drawer has 1 handle → valid
- ✅ No trays → ✓ valid (not a microwave)
- ✅ Every handle is attached to a door or drawer (confirmed via open-state image)

✔ All structural rules are satisfied.

---

#### **2.3 Attachment Inference**

- Each **door** is hinged to the **base**
- Each **drawer** slides out of the **base**
- Each **handle** is mounted on one **door** or **drawer**

---

### **3. Final Connectivity Graph**

```json
{
  "base": [
    { "door_1": [ { "handle_1": [] } ] },
    { "door_2": [ { "handle_2": [] } ] },
    { "door_3": [ { "handle_3": [] } ] },
    { "door_4": [ { "handle_4": [] } ] },
    { "drawer_1": [ { "handle_5": [] } ] },
    { "drawer_2": [ { "handle_6": [] } ] },
    { "drawer_3": [ { "handle_7": [] } ] },
    { "drawer_4": [ { "handle_8": [] } ] }
  ]
}
```
"""

answer_all = [answer_sys, answer_1, answer_2, answer_3, answer_4]
messages = [
    {"role": "user", "content": system_prompt},
    {"role": "assistant", "content": answer_sys}
]

for i in range(4):
    root_path = './scripts/imgs_reference/'
    close_path = os.path.join(root_path, f'close{i + 1}.png')
    open_path = os.path.join(root_path, f'open{i + 1}.png')
    close_img = encode_image(close_path, center_crop=False)
    open_img = encode_image(open_path, center_crop=False)
    messages.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{close_img}"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{open_img}"},
                }
            ],
        }
    )
    messages.append({"role": "assistant", "content": answer_all[i + 1]})
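Each few-shot answer above ends with a fenced json block; the caller (scripts/graph_pred/api.py) pulls that block out with a regex before converting it to the diffuse_tree format. A self-contained sketch of the same extraction on a toy reply:

import re
import json

reply = 'All rules satisfied.\n```json\n{"base": [{"door": [{"handle": []}]}]}\n```'
graph = json.loads(re.search(r"```json\n(.*?)\n```", reply, re.DOTALL).group(1))
print(graph)  # {'base': [{'door': [{'handle': []}]}]}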
scripts/json2urdf.py
ADDED
@@ -0,0 +1,160 @@
import json
from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree
from xml.dom.minidom import parseString
import pybullet as p
import pybullet_data
import os
from PIL import Image  # use Pillow to save images
import numpy as np
import trimesh  # for processing 3D meshes
import imageio
import math

def degrees_to_radians(degrees):
    """Convert an angle from degrees to radians."""
    return degrees * math.pi / 180.0

def ply_to_obj(ply_filename, obj_filename, urdf_filename, part, json_data):
    """Convert a PLY file to OBJ format using trimesh."""
    base_path = '/'.join(urdf_filename.split('/')[:-1])
    mesh = trimesh.load(os.path.join(base_path, ply_filename), force='mesh')
    print(ply_filename, mesh.bounding_box.centroid)
    # mesh.vertices -= (mesh.bounding_box.centroid)
    # find the base part (parent == -1)
    base_part_id = next((p['id'] for p in json_data['diffuse_tree'] if p['parent'] == -1), None)
    if 'joint' in part.keys():
        mesh.vertices -= (mesh.bounding_box.centroid + part['joint']['axis']['origin'])
        while part['parent'] != base_part_id:
            parent_part = next((p for p in json_data['diffuse_tree'] if p['id'] == part['parent']), None)
            if parent_part is None:
                break
            mesh.vertices -= (parent_part['joint']['axis']['origin'])
            part = parent_part
    else:
        mesh.vertices -= (mesh.bounding_box.centroid)
    mesh.export(os.path.join(base_path, obj_filename))

def create_urdf_from_json(json_data, urdf_filename, parent_part=None):
    robot = Element('robot', name='articulate_object')

    def add_link(parent, part, urdf_filename, json_data, base=False):
        link = SubElement(parent, 'link', name=f"link_{part['id']}")

        # convert the PLY file to an OBJ file
        ply_path = part['objs'][0]
        obj_path = os.path.splitext(ply_path)[0] + '.obj'
        # if not os.path.exists(obj_path):
        ply_to_obj(ply_path, obj_path, urdf_filename, part, json_data)

        visual = SubElement(link, 'visual')
        origin = SubElement(visual, 'origin', xyz=" ".join(map(str, part['aabb']['center'])), rpy="0 0 0")
        geometry = SubElement(visual, 'geometry')
        mesh = SubElement(geometry, 'mesh', filename=obj_path)  # use the converted OBJ path

    def add_joint(parent, child_part, parent_part):
        joint = SubElement(parent, 'joint', name=f"{parent_part['id']}_{child_part['id']}_joint", type=child_part['joint']['type'])

        # for i in range(3):
        #     child_part['joint']['axis']['origin'][i] -= (child_part['aabb']['size'][i])
        origin = SubElement(joint, 'origin', xyz=" ".join(map(str, child_part['joint']['axis']['origin'])), rpy="0 0 0")
        # origin = SubElement(joint, 'origin', xyz=" 0 0 0 ", rpy="0 0 0")
        axis = SubElement(joint, 'axis', xyz=" ".join(map(str, child_part['joint']['axis']['direction'])))
        if child_part['joint']['type'] == 'revolute':
            child_part['joint']['range'][0] = degrees_to_radians(child_part['joint']['range'][0])
            child_part['joint']['range'][1] = degrees_to_radians(child_part['joint']['range'][1])

        lower, upper = child_part['joint']['range']
        if upper < lower:
            lower, upper = upper, lower

        limit = SubElement(
            joint, 'limit',
            lower=str(lower),
            upper=str(upper),
            effort="10",
            velocity="1"
        )
        parent_element = SubElement(joint, 'parent', link=f"link_{parent_part['id']}")
        child_element = SubElement(joint, 'child', link=f"link_{child_part['id']}")

    base_part = json_data['diffuse_tree'][0]
    add_link(robot, base_part, urdf_filename, json_data, base=True)

    for part in json_data['diffuse_tree'][1:]:
        base_part = next((p for p in json_data['diffuse_tree'] if p['parent'] == -1), None)
        parent_part = next((p for p in json_data['diffuse_tree'] if p['id'] == part['parent']), None)
        add_link(robot, part, urdf_filename, json_data)
        if parent_part:
            add_joint(robot, part, parent_part)

    xmlstr = parseString(tostring(robot)).toprettyxml(indent="  ")

    with open(urdf_filename, "w") as f:
        f.write(xmlstr)


def pybullet_render(urdf_path, target_dir, num_frames, distance=3, fov=60):
    physicsClient = p.connect(p.DIRECT)
    p.setAdditionalSearchPath(pybullet_data.getDataPath())
    try:
        robot = p.loadURDF(urdf_path, [0, 0, 0])
    except Exception as e:
        print(e)
        return
    for i in range(-1, p.getNumJoints(robot)):
        rgba = [np.random.uniform(0.2, 1.0), np.random.uniform(0.2, 1.0), np.random.uniform(0.2, 1.0), 1]
        p.changeVisualShape(robot, linkIndex=i, rgbaColor=rgba)
    p.resetBasePositionAndOrientation(robot, [0, 0, 0], [0, 0.7071, 0.7071, 0])
    joint_info = []
    for i in range(p.getNumJoints(robot)):
        info = p.getJointInfo(robot, i)
        if info[2] != p.JOINT_FIXED:
            joint_info.append({
                'index': info[0],
                'type': info[2],
                'name': info[1].decode('utf-8'),
                'lower_limit': info[8],
                'upper_limit': info[9],
                'initial_position': p.getJointState(robot, info[0])[0]
            })

    joint_positions = {}
    for joint in joint_info:
        start = joint['lower_limit']
        end = joint['upper_limit']
        joint_positions[joint['index']] = np.concatenate((np.linspace(start, end, num_frames), np.linspace(end, start, num_frames)))

    gif_frames = []
    for frame in range(num_frames * 2):
        for joint in joint_info:
            p.resetJointState(robot, joint['index'], joint_positions[joint['index']][frame])
        p.stepSimulation()
        viewMatrix = p.computeViewMatrixFromYawPitchRoll(
            cameraTargetPosition=[0, 0, 0],
            distance=3.0,
            yaw=-150,
            pitch=-10,
            roll=0,
            upAxisIndex=2
        )
        projectionMatrix = p.computeProjectionMatrixFOV(
            fov=fov,
            aspect=1.0,
            nearVal=0.1, farVal=100)

        width, height, rgbPixels, depthBuffer, segMask = p.getCameraImage(
            width=1024, height=1024, viewMatrix=viewMatrix,
            projectionMatrix=projectionMatrix,
            renderer=p.ER_BULLET_HARDWARE_OPENGL)

        # get the RGBA image
        rgba_image = np.reshape(rgbPixels, (height, width, 4))
        # rgba_image[np.all(rgba_image[:, :, :3] == 255, axis=-1)] = [0, 0, 0, 0]
        gif_frames.append(rgba_image[:, :, :3])

    p.disconnect()
    imageio.mimsave(f'{target_dir}/animation.gif', gif_frames, fps=8, loop=0)
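A hedged sketch of how the two helpers above chain together: convert a predicted object.json into a URDF, then render the articulation as a GIF. The directory layout here is an assumption based on conventions used elsewhere in this repo:

import json
import os
from scripts.json2urdf import create_urdf_from_json, pybullet_render

src_dir = "results/example"  # hypothetical folder holding object.json and the part meshes
with open(os.path.join(src_dir, "object.json")) as f:
    obj = json.load(f)

urdf_path = os.path.join(src_dir, "object.urdf")
create_urdf_from_json(obj, urdf_path)  # writes the URDF and converts part PLYs to OBJs
pybullet_render(urdf_path, src_dir, num_frames=16)  # saves animation.gif in src_dir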
scripts/mesh_retrieval/retrieve.py
ADDED
@@ -0,0 +1,97 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import argparse
import numpy as np
from retrieval.obj_retrieval import find_obj_candidates, pick_and_rescale_parts
import trimesh
import shutil

def _retrieve_part_meshes(info_dict, save_dir, gt_data_root):
    mesh_save_dir = os.path.join(save_dir, "plys")
    obj_save_dir = os.path.join(save_dir, "objs")
    os.makedirs(mesh_save_dir, exist_ok=True)
    os.makedirs(obj_save_dir, exist_ok=True)
    print(save_dir)
    if os.path.exists(os.path.join(save_dir, "object.ply")):
        return
    HASHBOOK_PATH = "retrieval/retrieval_hash_no_handles.json"

    obj_candidates = find_obj_candidates(
        info_dict,
        gt_data_root,
        HASHBOOK_PATH,
        gt_file_name="object.json",
        num_states=5,
        metric_num_samples=4096,
        keep_top=3,
    )

    retrieved_mesh_specs = pick_and_rescale_parts(
        info_dict, obj_candidates, gt_data_root, gt_file_name="object.json"
    )

    scene = trimesh.Scene()
    for i, mesh_spec in enumerate(retrieved_mesh_specs):
        part_spec = info_dict["diffuse_tree"][i]
        current_part_meshes = []
        file_paths = []
        for file in mesh_spec["files"]:
            file = os.path.join(mesh_spec["dir"], file).replace("ply", "obj")
            file_paths.append(file)
            m = trimesh.load(file, force="mesh")
            current_part_meshes.append(m)

        if not current_part_meshes:
            continue

        bounds = np.array([m.bounds for m in current_part_meshes])
        min_extents = bounds[:, 0, :].min(axis=0)
        max_extents = bounds[:, 1, :].max(axis=0)
        group_centroid = (min_extents + max_extents) / 2.0

        transformation = trimesh.transformations.compose_matrix(
            scale=mesh_spec["scale_factor"],
            angles=[0, 0, np.radians(90) if mesh_spec["z_rotate_90"] else 0],
            translate=part_spec["aabb"]["center"],
        )

        part_scene = trimesh.Scene()
        for mesh in current_part_meshes:
            mesh.vertices -= group_centroid
            mesh.apply_transform(transformation)
            part_scene.add_geometry(mesh)
            scene.add_geometry(mesh)

        obj_path = os.path.join(obj_save_dir, f"part_{i}/part_{i}.obj")
        os.makedirs(os.path.dirname(obj_path), exist_ok=True)
        part_scene.export(obj_path, include_texture=True)
        info_dict["diffuse_tree"][i]["objs"] = [f"objs/part_{i}/part_{i}.obj"]

    scene.export(os.path.join(save_dir, "object.ply"))
    del mesh, scene
    return info_dict

def main(args):
    with open(os.path.join(args.src_dir, args.json_name), "r") as f:
        info_dict = json.load(f)

    if 'meta' not in info_dict.keys():
        info_dict['meta'] = {
            'obj_cat': 'StorageFurniture'  # default category; matches the hash book key
        }

    updated_json = _retrieve_part_meshes(info_dict, args.src_dir, args.gt_data_root)

    if updated_json is not None:
        with open(os.path.join(args.src_dir, args.json_name), "w") as f:
            json.dump(updated_json, f)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_dir', type=str, required=True, help='path to the directory containing object.json')
    parser.add_argument('--json_name', type=str, default='object.json', help='name of the json file')
    parser.add_argument('--gt_data_root', type=str, default='./', help='path to the ground truth data')
    args = parser.parse_args()
    main(args)
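The placement transform above uses trimesh.transformations.compose_matrix, which applies scale first, then rotation, then translation. A minimal numeric check of that ordering (values are illustrative):

import numpy as np
import trimesh

T = trimesh.transformations.compose_matrix(
    scale=[2.0, 2.0, 2.0],
    angles=[0, 0, np.radians(90)],  # the optional z-rotation, as above
    translate=[1.0, 0.0, 0.0],
)
pt = T @ np.array([1.0, 0.0, 0.0, 1.0])
# scale -> (2, 0, 0); rotate about z -> (0, 2, 0); translate -> (1, 2, 0)
print(np.round(pt[:3], 6))  # [1. 2. 0.]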
scripts/mesh_retrieval/retrieve_gpt.py
ADDED
@@ -0,0 +1,29 @@
import os
import subprocess
import argparse
from tqdm.contrib.concurrent import process_map
from functools import partial

def run_retrieve(src_dir, json_name, data_root):
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
    try:
        subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    return ' '.join(fn_call)

if __name__ == '__main__':
    root_path = '/home/users/ruiqi.wu/manipulate_3d_generate/data/gpt_blender/'
    for class_name in os.listdir(root_path):
        if class_name == 'StroageFurniture':
            for model_id in os.listdir(os.path.join(root_path, class_name)):
                json_path = os.path.join(root_path, class_name, model_id, 'object.json')
                object_path = os.path.join(root_path, class_name, model_id, 'object.ply')
                if os.path.exists(json_path):
                    if not os.path.exists(object_path):
                        print(json_path)
                        src_dir = os.path.join(root_path, class_name, model_id)
                        json_name = 'object.json'
                        data_root = '../singapo'
                        run_retrieve(src_dir, json_name, data_root)
scripts/mesh_retrieval/run_retrieve.py
ADDED
@@ -0,0 +1,68 @@
import os
import subprocess
import argparse
from tqdm.contrib.concurrent import process_map
from functools import partial

def run_retrieve(src_dir, json_name, data_root):
    if 'StorageFurniture' not in src_dir and 'Table' not in src_dir:
        data_root = '../acd_data/merged-data'
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
    try:
        subprocess.run(fn_call, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    return ' '.join(fn_call)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
    parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
    parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
    parser.add_argument("--max_workers", type=int, default=6, help="number of parallel worker processes")
    args = parser.parse_args()

    assert os.path.exists(args.src), f"Src path does not exist: {args.src}"
    assert os.path.exists(args.gt_data_root), f"GT data root does not exist: {args.gt_data_root}"

    exp_path = args.src
    print('----------- Retrieve Part Meshes -----------')
    # collect every <exp>/<model_id>/<model_id_id> directory that has the json
    # but no object.ply yet
    src_dirs = []
    for model_id in os.listdir(exp_path):
        model_id_path = os.path.join(exp_path, model_id)
        if '.' in model_id:
            continue
        for model_id_id in os.listdir(model_id_path):
            if '.' not in model_id_id:
                model_id_id_path = os.path.join(model_id_path, model_id_id)
                for json_file in os.listdir(model_id_id_path):
                    if json_file.endswith(args.json_name):
                        if os.path.exists(os.path.join(model_id_id_path, 'object.ply')):
                            print(f"Found {model_id_id_path} with object.ply")
                        else:
                            src_dirs.append(model_id_id_path)
                            print(len(src_dirs), model_id_id_path)
    print(f"Found {len(src_dirs)} jsons to retrieve part meshes")

    process_map(partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root), src_dirs, max_workers=args.max_workers, chunksize=1)
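For reference, process_map from tqdm.contrib.concurrent runs a function over a process pool with a progress bar; the call above binds the fixed keyword arguments with functools.partial and fans out one retrieval subprocess per collected directory. A minimal standalone sketch of the same pattern:

from functools import partial
from tqdm.contrib.concurrent import process_map

def work(x, scale):
    return x * scale

if __name__ == '__main__':
    # one task per item, fixed kwargs bound via partial (as in the script above)
    results = process_map(partial(work, scale=10), [1, 2, 3], max_workers=2, chunksize=1)
    print(results)  # [10, 20, 30]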