xinjie.wang committed on
Commit
c28dddb
·
1 Parent(s): 93b8c4c

init commit

Files changed (48)
  1. app.py +133 -0
  2. configs/config.yaml +93 -0
  3. dataset/__init__.py +13 -0
  4. dataset/base_dataset.py +404 -0
  5. dataset/data_module.py +82 -0
  6. dataset/mydataset.py +282 -0
  7. dataset/utils.py +194 -0
  8. inference.py +450 -0
  9. lightning_logs/version_0/hparams.yaml +1 -0
  10. lightning_logs/version_1/hparams.yaml +1 -0
  11. lightning_logs/version_2/hparams.yaml +1 -0
  12. lightning_logs/version_3/hparams.yaml +1 -0
  13. lightning_logs/version_4/hparams.yaml +1 -0
  14. lightning_logs/version_5/hparams.yaml +1 -0
  15. lightning_logs/version_6/hparams.yaml +111 -0
  16. lightning_logs/version_6/metrics.csv +4 -0
  17. metrics/__init__.py +0 -0
  18. metrics/aor.py +44 -0
  19. metrics/cd.py +284 -0
  20. metrics/giou.py +142 -0
  21. metrics/iou.py +220 -0
  22. metrics/iou_cdist.py +227 -0
  23. models/__init__.py +19 -0
  24. models/denoiser.py +415 -0
  25. models/utils.py +199 -0
  26. my_utils/__init__.py +0 -0
  27. my_utils/callbacks.py +36 -0
  28. my_utils/lr_schedulers.py +104 -0
  29. my_utils/misc.py +35 -0
  30. my_utils/plot.py +122 -0
  31. my_utils/refs.py +122 -0
  32. my_utils/render.py +482 -0
  33. my_utils/savermixins.py +55 -0
  34. objects/__init__.py +0 -0
  35. objects/dict_utils.py +299 -0
  36. objects/motions.py +99 -0
  37. requirements.txt +21 -0
  38. retrieval/__init__.py +0 -0
  39. retrieval/obj_retrieval.py +509 -0
  40. retrieval/retrieval_hash_acd.json +329 -0
  41. retrieval/retrieval_hash_no_handles.json +722 -0
  42. scripts/graph_pred/api.py +210 -0
  43. scripts/graph_pred/eval.py +62 -0
  44. scripts/graph_pred/prompt_workflow_new.py +363 -0
  45. scripts/json2urdf.py +160 -0
  46. scripts/mesh_retrieval/retrieve.py +97 -0
  47. scripts/mesh_retrieval/retrieve_gpt.py +29 -0
  48. scripts/mesh_retrieval/run_retrieve.py +68 -0
app.py ADDED
@@ -0,0 +1,133 @@
1
+ import gradio as gr
2
+ import os
3
+ import shutil
4
+ import zipfile
5
+ from types import SimpleNamespace
6
+ from inference import run_demo, load_config
7
+ import random
8
+ import string
9
+ from gradio.themes import Soft
10
+ from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc
11
+
12
+ custom_theme = Soft(
13
+ primary_hue=stone,
14
+ secondary_hue=gray,
15
+ radius_size="md",
16
+ text_size="sm",
17
+ spacing_size="sm",
18
+ )
19
+
20
+
21
+ def inference_ui(img1, img2, omega, n_denoise_steps):
22
+ tmpdir = 'results'
23
+ random_str = ''.join(random.choices(string.ascii_letters, k=16))
24
+ tmpdir = tmpdir + "_" + random_str
25
+
26
+ # remove all existing "results*" directories
27
+ for dir in os.listdir('.'):
28
+ if dir.startswith('results') and os.path.isdir(dir):
29
+ shutil.rmtree(dir)
30
+ os.makedirs(os.path.join(tmpdir, "0"), exist_ok=True)
31
+
32
+ args = SimpleNamespace(
33
+ img_path_1=img1,
34
+ img_path_2=img2,
35
+ ckpt_path='ckpts/dipo.ckpt',
36
+ config_path='configs/config.yaml',
37
+ use_example_graph=False,
38
+ save_dir=tmpdir,
39
+ gt_data_root='./data/PartnetMobility',
40
+ n_samples=3,
41
+ omega=omega,
42
+ n_denoise_steps=n_denoise_steps,
43
+ )
44
+ args.config = load_config(args.config_path)
45
+ run_demo(args)
46
+
47
+ gif_path = os.path.join(tmpdir, "0", "animation.gif")
48
+ ply_path = os.path.join(tmpdir, "0", "object.ply")
49
+ glb_path = os.path.join(tmpdir, "0", "object.glb")
50
+
51
+ # compress the results into a ZIP archive
52
+ zip_path = os.path.join(tmpdir, "output.zip")
53
+ folder_to_zip = os.path.join(tmpdir, "0")
54
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
55
+ for root, dirs, files in os.walk(folder_to_zip):
56
+ for file in files:
57
+ abs_path = os.path.join(root, file)
58
+ rel_path = os.path.relpath(abs_path, folder_to_zip)
59
+ zipf.write(abs_path, arcname=rel_path)
60
+
61
+ return (
62
+ gif_path if os.path.exists(gif_path) else None,
63
+ zip_path if os.path.exists(zip_path) else None
64
+ )
65
+
66
+ def prepare_data():
67
+ if not os.path.exists("data") or not os.path.exists("saved_model"):
68
+ print("Downloading data.tar from Hugging Face Datasets...")
69
+ os.system("wget https://huggingface.co/datasets/wuruiqi0722/DIPO_data/resolve/main/data/data.tar -O data.tar")
70
+ os.system("tar -xf data.tar")
71
+
72
+ with gr.Blocks(theme=custom_theme) as demo:
73
+ gr.Markdown("## DIPO: Dual-State Images Controlled Articulated Object Generation Powered by Diverse Data")
74
+ gr.Markdown(
75
+ """
76
+ <p style="display: flex; gap: 10px; flex-wrap: nowrap;">
77
+ <a href="https://rq-wu.github.io/projects/DIPO">
78
+ <img alt="📖 Project Page" src="https://img.shields.io/badge/📖-Project_Page-blue">
79
+ </a>
80
+ <a href="https://arxiv.org/abs/2505.20460">
81
+ <img alt="📄 arXiv" src="https://img.shields.io/badge/📄-arXiv-b31b1b">
82
+ </a>
83
+ <a href="https://github.com/RQ-Wu/DIPO">
84
+ <img alt="💻 GitHub" src="https://img.shields.io/badge/GitHub-000000?logo=github">
85
+ </a>
86
+ </p>
87
+ """
88
+ )
89
+ gr.Markdown("Currently, only the articulated object in following categories are supported: Table, Dishwasher, StorageFurniture, Refrigerator, WashingMachine, Microwave, Oven.")
90
+
91
+ with gr.Row():
92
+ with gr.Column(scale=1):
93
+ img1_input = gr.Image(label="Image: Closed State", type="filepath", height=250)
94
+ img2_input = gr.Image(label="Image: Opened State", type="filepath", height=250)
95
+ omega = gr.Slider(0.0, 1.0, step=0.1, value=0.5, label="Omega (CFG Guidance)")
96
+ n_denoise = gr.Slider(10, 200, step=10, value=100, label="Denoising Steps")
97
+ run_button = gr.Button("🚀 Run Generation (~2mins)")
98
+
99
+ with gr.Column(scale=1):
100
+ output_gif = gr.Image(label="GIF Animation", type="filepath", height=678, width=10000)
101
+ zip_download_btn = gr.DownloadButton(label="📦 Download URDF folder", interactive=False)
102
+
103
+ gr.Examples(
104
+ examples=[
105
+ ["examples/1.png", "examples/1_open_1.png"],
106
+ ["examples/1.png", "examples/1_open_2.png"],
107
+ ["examples/close1.png", "examples/open1.png"],
108
+ # ["examples/close2.png", "examples/open2.png"],
109
+ ["examples/close3.png", "examples/open3.png"],
110
+ # ["examples/close4.png", "examples/open4.png"],
111
+ ["examples/close5.png", "examples/open5.png"],
112
+ ["examples/close6.png", "examples/open6.png"],
113
+ ["examples/close7.png", "examples/open7.png"],
114
+ ["examples/close8.png", "examples/open8.png"],
115
+ ["examples/close9.jpg", "examples/open9.jpg"],
116
+ ["examples/close10.png", "examples/open10.png"],
117
+ ],
118
+ inputs=[img1_input, img2_input],
119
+ label="📂 Example Inputs"
120
+ )
121
+
122
+ run_button.click(
123
+ fn=inference_ui,
124
+ inputs=[img1_input, img2_input, omega, n_denoise],
125
+ outputs=[output_gif, zip_download_btn]
126
+ ).success(
127
+ lambda: gr.DownloadButton(interactive=True),
128
+ outputs=[zip_download_btn]
129
+ )
130
+
131
+ if __name__ == "__main__":
132
+ prepare_data()
133
+ demo.launch()
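
For reference, the pipeline wrapped by `inference_ui` can also be driven without the Gradio UI by building the same argument namespace directly. A minimal sketch, assuming the default checkpoint and example images referenced above (`ckpts/dipo.ckpt`, `examples/1.png`, `examples/1_open_1.png`) are present:

```python
# Hypothetical headless driver mirroring inference_ui(); paths are assumptions taken
# from the defaults and examples used in app.py above.
import os
from types import SimpleNamespace
from inference import run_demo, load_config

args = SimpleNamespace(
    img_path_1="examples/1.png",          # closed-state image
    img_path_2="examples/1_open_1.png",   # opened-state image
    ckpt_path="ckpts/dipo.ckpt",
    config_path="configs/config.yaml",
    use_example_graph=False,
    save_dir="results_headless",
    gt_data_root="./data/PartnetMobility",
    n_samples=3,
    omega=0.5,                            # CFG guidance weight
    n_denoise_steps=100,
)
args.config = load_config(args.config_path)
os.makedirs(os.path.join(args.save_dir, "0"), exist_ok=True)
run_demo(args)   # writes animation.gif / object.json under results_headless/<i>/
```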
configs/config.yaml ADDED
@@ -0,0 +1,93 @@
1
+ name: dipo
2
+ version: denoiser
3
+
4
+ data:
5
+ name: dm_dipo
6
+ json_root: data_path
7
+ root: data_path # root directory of the dataset
8
+ batch_size: 20 # batch size for training
9
+ num_workers: 8 # number of workers for data loading
10
+ K: 32 # maximum number of nodes (parts) in the graph (object)
11
+ split_file: split_file_path
12
+ n_views_per_model: 20
13
+ frame_mode: last_frame
14
+ test_which: pm
15
+ mode_num: 5
16
+
17
+ system:
18
+ name: sys_origin
19
+ exp_dir: ./exps/${name}/${version}
20
+ data_root: ${data.root}
21
+ n_time_samples: 16
22
+ loss_fg_weight: 0.01
23
+ img_drop_prob: 0.1 # image dropout probability, for classifier free training
24
+ guidance_scaler: 0.5 # scaling factor for guidance on the image during inference
25
+ graph_drop_prob: 0.5 # graph dropout probability, for classifier free training
26
+
27
+ model:
28
+ name: denoiser
29
+ in_ch: 6
30
+ attn_dim: 128
31
+ n_head: 4
32
+ n_layers: 6
33
+ dropout: 0.1
34
+ K: ${data.K}
35
+ mode_num: 5
36
+ img_emb_dims: [768, 128]
37
+ cat_drop_prob: 0.5 # object category dropout probability, for classifier free training
38
+
39
+ scheduler: # scheduler for the diffusion model
40
+ name: ddpm
41
+ config:
42
+ num_train_timesteps: 1000
43
+ beta_schedule: linear
44
+ prediction_type: epsilon
45
+
46
+ lr_scheduler_adapter: # lr scheduler for the new modules on top of the base model
47
+ name: LinearWarmupCosineAnnealingLR
48
+ warmup_epochs: 3
49
+ max_epochs: ${trainer.max_epochs}
50
+ warmup_start_lr: 1e-6
51
+ eta_min: 1e-5
52
+
53
+ optimizer_adapter: # optimizer for the new modules on top of the base model
54
+ name: AdamW
55
+ args:
56
+ lr: 5e-4
57
+ betas: [0.9, 0.99]
58
+ eps: 1.e-15
59
+
60
+ lr_scheduler_cage: # lr scheduler for modules in the base model
61
+ name: LinearWarmupCosineAnnealingLR
62
+ warmup_epochs: 3
63
+ max_epochs: ${trainer.max_epochs}
64
+ warmup_start_lr: 1e-6
65
+ eta_min: 1e-5
66
+
67
+ optimizer_cage: # optimizer for modules in the base model
68
+ name: AdamW
69
+ args:
70
+ lr: 5e-5
71
+ betas: [0.9, 0.99]
72
+ eps: 1.e-15
73
+
74
+ checkpoint:
75
+ dirpath: ${system.exp_dir}/ckpts
76
+ save_top_k: -1
77
+ every_n_epochs: 50
78
+
79
+ logger: # wandb logger
80
+ save_dir: ${system.exp_dir}/logs # directory to save logs
81
+ name: ${name}_${version}
82
+ project: SINGAPO
83
+
84
+ trainer:
85
+ max_epochs: 200
86
+ log_every_n_steps: 100
87
+ limit_train_batches: 1.0
88
+ limit_val_batches: 1.0
89
+ check_val_every_n_epoch: 10
90
+ precision: 16-mixed
91
+ profiler: simple
92
+ num_sanity_val_steps: -1
93
+
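
The `${...}` references above (e.g. `${name}`, `${data.K}`, `${trainer.max_epochs}`) are interpolations that resolve when the config is loaded. `load_config` lives in `my_utils/misc.py`, which is not part of this hunk; a minimal sketch of how the interpolations resolve, assuming an OmegaConf-style loader:

```python
# Sketch only: assumes the config is loaded with OmegaConf (the loader itself is not shown here).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/config.yaml")
print(cfg.system.exp_dir)                   # ./exps/dipo/denoiser   (${name}/${version})
print(cfg.model.K)                          # 32                     (${data.K})
print(cfg.lr_scheduler_adapter.max_epochs)  # 200                    (${trainer.max_epochs})
```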
dataset/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ datamodules = {}
2
+
3
+ def register(name):
4
+ def decorator(cls):
5
+ datamodules[name] = cls
6
+ return cls
7
+ return decorator
8
+
9
+ def make(name, config):
10
+ dm = datamodules[name](config)
11
+ return dm
12
+
13
+ from . import data_module
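
This module is a tiny name-to-class registry: `register` stores a datamodule class under a string key and `make` instantiates it from a config. `dataset/data_module.py` below registers itself as `dm_dipo`, which matches `data.name` in `configs/config.yaml`. A self-contained usage sketch (the `toy_dm` class is hypothetical, for illustration only):

```python
# Registry round trip: the decorator stores the class, make() builds it by name.
import dataset

@dataset.register("toy_dm")        # hypothetical name, not used by the real code
class ToyDataModule:
    def __init__(self, config):
        self.config = config

dm = dataset.make("toy_dm", {"batch_size": 4})
print(type(dm).__name__, dm.config)   # ToyDataModule {'batch_size': 4}
```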
dataset/base_dataset.py ADDED
@@ -0,0 +1,404 @@
1
+ import os, sys
2
+ import json
3
+ import numpy as np
4
+ # import collections.abc
5
+ # sys.modules['collections'].Mapping = collections.abc.Mapping
6
+
7
+ import networkx as nx
8
+ from torch.utils.data import Dataset
9
+ from my_utils.refs import cat_ref, sem_ref, joint_ref, data_mode_ref
10
+ from collections import deque
11
+
12
+ def build_graph(tree, K=32):
13
+ '''
14
+ Function to build graph from the node list.
15
+
16
+ Args:
17
+ nodes: list of nodes
18
+ K: the maximum number of nodes in the graph
19
+ Returns:
20
+ adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes
21
+ edge_list: list of edges, for visualization
22
+ '''
23
+ adj = np.zeros((K, K), dtype=np.float32)
24
+ parents = []
25
+ tree_list = []
26
+ for node in tree:
27
+ tree_list.append(
28
+ {
29
+ 'id': node['id'],
30
+ 'parent_id': node['parent'],
31
+ }
32
+ )
33
+ # 1-ring relationship
34
+ if node['parent'] != -1:
35
+ adj[node['id'], node['parent']] = 1
36
+ parents.append(node['parent'])
37
+ else:
38
+ adj[node['id'], node['id']] = 1
39
+ parents.append(-1)
40
+ for child_id in node['children']:
41
+ adj[node['id'], child_id] = 1
42
+ return {
43
+ 'adj': adj,
44
+ 'parents': np.array(parents, dtype=np.int8),
45
+ 'tree_list': tree_list
46
+ }
47
+
48
+ from collections import defaultdict
49
+ from functools import cmp_to_key
50
+
51
+ def bfs_tree_simple(tree_list):
52
+ order = [0] * len(tree_list)
53
+ queue = []
54
+ current_node_idx = 0
55
+ for node_idx, node in enumerate(tree_list):
56
+ if node['parent_id'] == -1:
57
+ queue.append(node['id'])
58
+ order[node_idx] = current_node_idx
59
+ current_node_idx += 1
60
+ break
61
+ while len(queue) > 0:
62
+ current_node = queue.pop(0)
63
+ for node_idx, node in enumerate(tree_list):
64
+ if node['parent_id'] == current_node:
65
+ queue.append(node['id'])
66
+ order[node_idx] = current_node_idx
67
+ current_node_idx += 1
68
+
69
+ return order
70
+
71
+ def bfs_tree(tree_list, aabb_list, epsilon=1e-3):
72
+ # 初始化遍历顺序列表
73
+ order = [0] * len(tree_list)
74
+ current_order = 0
75
+
76
+ # 构建父节点到子节点的索引映射
77
+ parent_map = defaultdict(list)
78
+ for idx, node in enumerate(tree_list):
79
+ parent_map[node['parent_id']].append(idx)
80
+
81
+ # 查找根节点
82
+ root_indices = [idx for idx, node in enumerate(tree_list) if node['parent_id'] == -1]
83
+ if not root_indices:
84
+ return order
85
+
86
+ # 初始化队列(存储节点索引)
87
+ queue = [root_indices[0]]
88
+ order[root_indices[0]] = current_order
89
+ current_order += 1
90
+
91
+ # 比较函数:按中心坐标排序
92
+ def compare_centers(a, b):
93
+ # 获取两个节点的中心坐标
94
+ center_a = [(aabb_list[a][i] + aabb_list[a][i+3])/2 for i in range(3)]
95
+ center_b = [(aabb_list[b][i] + aabb_list[b][i+3])/2 for i in range(3)]
96
+
97
+ # 逐级比较坐标(考虑epsilon阈值)
98
+ for coord in range(3):
99
+ delta = abs(center_a[coord] - center_b[coord])
100
+ if delta > epsilon:
101
+ return -1 if center_a[coord] < center_b[coord] else 1
102
+ return 0 # 所有坐标差均小于阈值时保持原顺序
103
+
104
+ # BFS遍历
105
+ while queue:
106
+ current_idx = queue.pop(0)
107
+ current_id = tree_list[current_idx]['id']
108
+
109
+ # 获取子节点索引并排序
110
+ children = parent_map.get(current_id, [])
111
+ sorted_children = sorted(children, key=cmp_to_key(compare_centers))
112
+
113
+ # 处理子节点
114
+ for child_idx in sorted_children:
115
+ order[child_idx] = current_order
116
+ current_order += 1
117
+ queue.append(child_idx)
118
+
119
+ return order
120
+
121
+ class BaseDataset(Dataset):
122
+ def __init__(self, hparams):
123
+ super().__init__()
124
+ self.hparams = hparams
125
+
126
+ def _filter_models(self, models_ids):
127
+ '''
128
+ Filter out models that have more than K nodes.
129
+ '''
130
+ json_data_root = self.hparams.json_root
131
+ filtered = []
132
+ for i, model_id in enumerate(models_ids):
133
+ if i % 100 == 0:
134
+ print(f'Checking model {i}/{len(models_ids)}')
135
+ path = os.path.join(json_data_root, model_id, self.json_name)
136
+ with open(path, 'r') as f:
137
+ json_file = json.load(f)
138
+ if len(json_file['diffuse_tree']) <= self.hparams.K:
139
+ filtered.append(model_id)
140
+ return filtered
141
+
142
+ def get_acd_mapping(self):
143
+ self.category_mapping = {
144
+ 'armoire': 'StorageFurniture',
145
+ 'bookcase': 'StorageFurniture',
146
+ 'chest_of_drawers': 'StorageFurniture',
147
+ 'desk': 'Table',
148
+ 'dishwasher': 'Dishwasher',
149
+ 'hanging_cabinet': 'StorageFurniture',
150
+ 'kitchen_cabinet': 'StorageFurniture',
151
+ 'microwave': 'Microwave',
152
+ 'nightstand': 'StorageFurniture',
153
+ 'oven': 'Oven',
154
+ 'refrigerator': 'Refrigerator',
155
+ 'sink_cabinet': 'StorageFurniture',
156
+ 'tv_stand': 'StorageFurniture',
157
+ 'washer': 'WashingMachine',
158
+ 'table': 'Table',
159
+ 'cabinet': 'StorageFurniture',
160
+ 'hanging_cabinet': 'StorageFurniture',
161
+ }
162
+
163
+ def _random_permute(self, graph, nodes):
164
+ '''
165
+ Function to randomly permute the nodes and update the graph and node attribute info.
166
+
167
+ Args:
168
+ graph: a dictionary containing the adjacency matrix, edge list, and root node
169
+ nodes: a list of nodes
170
+ Returns:
171
+ graph_permuted: a dictionary containing the updated adjacency matrix, edge list, and root node
172
+ nodes_permuted: a list of permuted nodes
173
+ '''
174
+ N = len(nodes)
175
+ order = np.random.permutation(N)
176
+ graph_permuted = self._reorder_nodes(graph, order)
177
+ exchange = [0] * len(order)
178
+ for i in range(len(order)):
179
+ exchange[order[i]] = i
180
+ nodes_permuted = nodes[exchange, :]
181
+ return graph_permuted, nodes_permuted
182
+
183
+ def _permute_by_order(self, graph, nodes, order):
184
+ '''
185
+ Function to permute the nodes and update the graph and node attribute info by order.
186
+
187
+ Args:
188
+ graph: a dictionary containing the adjacency matrix, edge list, and root node
189
+ nodes: a list of nodes
190
+ order: a list of indices for reordering
191
+ Returns:
192
+ graph_permuted: a dictionary containing the updated adjacency matrix, edge list, and root node
193
+ nodes_permuted: a list of permuted nodes
194
+ '''
195
+ graph_permuted = self._reorder_nodes(graph, order)
196
+ if nodes is None:
197
+ return graph_permuted, None
198
+ else:
199
+ exchange = [0] * len(order)
200
+ for i in range(len(order)):
201
+ exchange[order[i]] = i
202
+ nodes_permuted = nodes[exchange, :]
203
+ return graph_permuted, nodes_permuted
204
+
205
+ def _prepare_node_data(self, node):
206
+ # semantic label
207
+ label = np.array([sem_ref['fwd'][node['name']]], dtype=np.float32) / 5. - 0.8 # (1,), range from -0.8 to 0.8
208
+ # joint type
209
+ joint_type = np.array([joint_ref['fwd'][node['joint']['type']] / 5.], dtype=np.float32) - 0.5 # (1,), range from -0.8 to 0.8
210
+ # aabb
211
+ aabb_center = np.array(node['aabb']['center'], dtype=np.float32) # (3,), range from -1 to 1
212
+ aabb_size = np.array(node['aabb']['size'], dtype=np.float32) # (3,), range from -1 to 1
213
+ aabb_max = aabb_center + aabb_size / 2
214
+ aabb_min = aabb_center - aabb_size / 2
215
+ # joint axis and range
216
+ if node['joint']['type'] == 'fixed':
217
+ axis_dir = np.zeros((3,), dtype=np.float32)
218
+ axis_ori = aabb_center
219
+ joint_range = np.zeros((2,), dtype=np.float32)
220
+ else:
221
+ if node['joint']['type'] == 'revolute' or node['joint']['type'] == 'continuous':
222
+ joint_range = np.array([node['joint']['range'][1]], dtype=np.float32) / 360.
223
+ joint_range = np.concatenate([joint_range, np.zeros((1,), dtype=np.float32)], axis=0) # (2,)
224
+ elif node['joint']['type'] == 'prismatic' or node['joint']['type'] == 'screw':
225
+ joint_range = np.array([node['joint']['range'][1]], dtype=np.float32)
226
+ joint_range = np.concatenate([np.zeros((1,), dtype=np.float32), joint_range], axis=0) # (2,)
227
+ axis_dir = np.array(node['joint']['axis']['direction'], dtype=np.float32) * 0.7 # (3,), range from -0.7 to 0.7
228
+ # make sure the axis is pointing to the positive direction
229
+ if np.sum(axis_dir > 0) < np.sum(-axis_dir > 0):
230
+ axis_dir = -axis_dir
231
+ joint_range = -joint_range
232
+ axis_ori = np.array(node['joint']['axis']['origin'], dtype=np.float32) # (3,), range from -1 to 1
233
+ if (node['joint']['type'] == 'prismatic' or node['joint']['type'] == 'screw') and node['name'] != 'door':
234
+ axis_ori = aabb_center
235
+ # prepare node data by given mod name
236
+ # aabb = np.concatenate([aabb_max, aabb_min], axis=0)
237
+ # axis = np.concatenate([axis_dir, axis_ori], axis=0)
238
+ # node_data_all = [aabb, joint_type.repeat(6), axis, joint_range.repeat(3), label.repeat(6)]
239
+ # node_data_list = [node_data_all[data_mode_ref[mod_name]] for mod_name in self.hparams.data_mode]
240
+ # node_data = np.concatenate(node_data_list, axis=0)
241
+ node_label = np.ones(6, dtype=np.float32)
242
+
243
+ node_data = np.concatenate([aabb_max, aabb_min, joint_type.repeat(6), axis_dir, axis_ori, joint_range.repeat(3), label.repeat(6), node_label], axis=0)
244
+ if self.hparams.mode_num == 5:
245
+ node_data = np.concatenate([aabb_max, aabb_min, joint_type.repeat(6), axis_dir, axis_ori, joint_range.repeat(3), label.repeat(6)], axis=0)
246
+ return node_data
247
+
248
+
249
+ def _reorder_nodes(self, graph, order):
250
+ '''
251
+ Function to reorder nodes in the graph and
252
+ update the adjacency matrix, edge list, and root node.
253
+
254
+ Args:
255
+ graph: a dictionary containing the adjacency matrix, edge list, and root node
256
+ order: a list of indices for reordering
257
+ Returns:
258
+ new_graph: a dictionary containing the updated adjacency matrix, edge list, and root node
259
+ '''
260
+ N = len(order)
261
+ mapping = {i: order[i] for i in range(N)}
262
+ mapping.update({i: i for i in range(N, self.hparams.K)})
263
+ G = nx.from_numpy_array(graph['adj'], create_using=nx.Graph)
264
+ G_ = nx.relabel_nodes(G, mapping)
265
+ new_adj = nx.adjacency_matrix(G_, G.nodes).todense()
266
+
267
+ exchange = [0] * len(order)
268
+ for i in range(len(order)):
269
+ exchange[order[i]] = i
270
+ return {
271
+ 'adj': new_adj.astype(np.float32),
272
+ 'parents': graph['parents'][exchange]
273
+ }
274
+
275
+
276
+ def _prepare_input_GT(self, file, model_id):
277
+ '''
278
+ Function to parse input item from a json file for the CAGE training.
279
+ '''
280
+ tree = file['diffuse_tree']
281
+ K = self.hparams.K # max number of nodes
282
+ cond = {} # conditional information and auxiliary data
283
+ cond['parents'] = np.zeros(K, dtype=np.int8)
284
+
285
+ # prepare node data
286
+ nodes = []
287
+ for node in tree:
288
+ node_data = self._prepare_node_data(node) # (36,)
289
+ nodes.append(node_data)
290
+ nodes = np.array(nodes, dtype=np.float32)
291
+ n_nodes = len(nodes)
292
+
293
+ # prepare graph
294
+ graph = build_graph(tree, self.hparams.K)
295
+ if self.mode == 'train': # perturb the node order for training
296
+ graph, nodes = self._random_permute(graph, nodes)
297
+
298
+ # pad the nodes to K with empty nodes
299
+ if n_nodes < K:
300
+ empty_node = np.zeros((nodes[0].shape[0],))
301
+ data = np.concatenate([nodes, [empty_node] * (K - n_nodes)], axis=0, dtype=np.float32) # (K, 36)
302
+ else:
303
+ data = nodes
304
+ mode_num = data.shape[1] // 6
305
+ data = data.reshape(K*mode_num, 6) # (K * n_attr, 6)
306
+
307
+ # attr mask (for Local Attention)
308
+ attr_mask = np.eye(K, K, dtype=bool)
309
+ attr_mask = attr_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
310
+ cond['attr_mask'] = attr_mask
311
+
312
+ # key padding mask (for Global Attention)
313
+ pad_mask = np.zeros((K*mode_num, K*mode_num), dtype=bool)
314
+ pad_mask[:, :n_nodes*mode_num] = 1
315
+ cond['key_pad_mask'] = pad_mask
316
+
317
+ # adj mask (for Graph Relation Attention)
318
+ adj_mask = graph['adj'][:].astype(bool)
319
+ adj_mask = adj_mask.repeat(mode_num, axis=0).repeat(mode_num, axis=1)
320
+ adj_mask[n_nodes*mode_num:, :] = 1
321
+ cond['adj_mask'] = adj_mask
322
+
323
+ # object category
324
+ if self.map_cat: # for ACD dataset
325
+ category = file['meta']['obj_cat']
326
+ category = self.category_mapping[category]
327
+ cond['cat'] = cat_ref[category]
328
+ else:
329
+ cond['cat'] = cat_ref.get(file['meta']['obj_cat'], None)
330
+ if cond['cat'] is None:
331
+ cond['cat'] = self.category_mapping.get(file['meta']['obj_cat'], None)
332
+ if cond['cat'] is None:
333
+ cond['cat'] = 2
334
+ else:
335
+ cond['cat'] = cat_ref.get(cond['cat'], None)
336
+ # cond['cat'] = cat_ref[file['meta']['obj_cat']]
337
+ if cond['cat'] is None:
338
+ cond['cat'] = 2
339
+ # auxiliary info
340
+ cond['name'] = model_id
341
+ cond['adj'] = graph['adj']
342
+ cond['parents'][:n_nodes] = graph['parents']
343
+ cond['n_nodes'] = n_nodes
344
+ cond['obj_cat'] = file['meta']['obj_cat']
345
+
346
+ return data, cond
347
+
348
+ def _prepare_input(self, model_id, pred_file, gt_file=None):
349
+ '''
350
+ Function to parse input item from pred_file, and parse GT from gt_file if available.
351
+ '''
352
+ K = self.hparams.K # max number of nodes
353
+ cond = {} # conditional information and auxiliary data
354
+ # prepare node data
355
+ n_nodes = len(pred_file['diffuse_tree'])
356
+ # prepare graph
357
+ pred_graph = build_graph(pred_file['diffuse_tree'], K)
358
+ # dummy GT data
359
+ data = np.zeros((K*5, 6), dtype=np.float32)
360
+
361
+ # attr mask (for Local Attention)
362
+ attr_mask = np.eye(K, K, dtype=bool)
363
+ attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
364
+ cond['attr_mask'] = attr_mask
365
+
366
+ # key padding mask (for Global Attention)
367
+ pad_mask = np.zeros((K*5, K*5), dtype=bool)
368
+ pad_mask[:, :n_nodes*5] = 1
369
+ cond['key_pad_mask'] = pad_mask
370
+
371
+ # adj mask (for Graph Relation Attention)
372
+ adj_mask = pred_graph['adj'][:].astype(bool)
373
+ adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1)
374
+ adj_mask[n_nodes*5:, :] = 1
375
+ cond['adj_mask'] = adj_mask
376
+
377
+ # placeholder category, won't be used if category is given (below)
378
+ cond['cat'] = cat_ref['StorageFurniture']
379
+ cond['obj_cat'] = 'StorageFurniture'
380
+ # if object category is given as input
381
+ if not self.hparams.get('test_label_free', False):
382
+ assert 'meta' in pred_file, 'meta not found in the json file.'
383
+ assert 'obj_cat' in pred_file['meta'], 'obj_cat not found in the metadata of the json file.'
384
+ category = pred_file['meta']['obj_cat']
385
+ if self.map_cat: # for ACD dataset
386
+ category = self.category_mapping[category]
387
+ cond['cat'] = cat_ref[category]
388
+ cond['obj_cat'] = category
389
+
390
+ # auxiliary info
391
+ cond['name'] = model_id
392
+ cond['adj'] = pred_graph['adj']
393
+ cond['parents'] = np.zeros(K, dtype=np.int8)
394
+ cond['parents'][:n_nodes] = pred_graph['parents']
395
+ cond['n_nodes'] = n_nodes
396
+
397
+ return data, cond
398
+
399
+ def __getitem__(self, index):
400
+ raise NotImplementedError
401
+
402
+ def __len__(self):
403
+ raise NotImplementedError
404
+
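
The structure returned by `build_graph` is easiest to see on a toy two-part object (a base with one door), using the same node schema as the `diffuse_tree` entries it consumes; a minimal sketch:

```python
# Toy example of build_graph(): the root gets a self-loop, parent/child links fill the 1-ring adjacency.
from dataset.base_dataset import build_graph

tree = [
    {"id": 0, "parent": -1, "children": [1]},   # base (root)
    {"id": 1, "parent": 0,  "children": []},    # door
]
g = build_graph(tree, K=4)
print(g["parents"])       # [-1  0]
print(g["adj"][:2, :2])   # [[1. 1.]
                          #  [1. 0.]]
```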
dataset/data_module.py ADDED
@@ -0,0 +1,82 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ import json
4
+ import dataset
5
+ import lightning.pytorch as pl
6
+ from torch.utils.data import DataLoader
7
+ from dataset.mydataset import MyDataset
8
+
9
+ @dataset.register("dm_dipo")
10
+ class DIPODataModule(pl.LightningDataModule):
11
+
12
+ def __init__(self, hparams):
13
+ super().__init__()
14
+ self.hparams.update(hparams)
15
+
16
+ def _prepare_split(self):
17
+ with open(self.hparams.split_file , "r") as f:
18
+ splits = json.load(f)
19
+
20
+ train_ids = splits["train"]
21
+ val_ids = [i for i in train_ids if "data" not in i]
22
+ return train_ids, val_ids
23
+
24
+ def _prepare_test_ids(self):
25
+ if "acd" in self.hparams.get('test_which'):
26
+ with open("/home/users/ruiqi.wu/singapo/data/data_acd.json", "r") as f:
27
+ file = json.load(f)
28
+ elif 'pm' in self.hparams.get('test_which'):
29
+ with open(self.hparams.split_file, "r") as f:
30
+ file = json.load(f)
31
+ else:
32
+ raise NotImplementedError(f"Dataset {self.hparams.get('test_which')} not implemented for SingapoDataModule")
33
+ ids = file['test']
34
+ return ids
35
+
36
+ def setup(self, stage=None):
37
+
38
+ if stage == "fit" or stage is None:
39
+ train_ids, val_ids = self._prepare_split()
40
+ val_ids = val_ids
41
+ self.train_dataset = MyDataset(self.hparams, model_ids=train_ids[:10], mode="train")
42
+ self.val_dataset = MyDataset(self.hparams, model_ids=val_ids[:50], mode="val")
43
+ elif stage == "validate":
44
+ val_ids = self._prepare_test_ids()
45
+ val_ids = val_ids
46
+ self.val_dataset = MyDataset(self.hparams, model_ids=val_ids, mode="val")
47
+ elif stage == "test":
48
+ test_ids = self._prepare_test_ids()
49
+ self.test_dataset = MyDataset(self.hparams, model_ids=test_ids, mode="test")
50
+ else:
51
+ raise NotImplementedError(f"Stage {stage} not implemented for SingapoDataModule")
52
+
53
+
54
+ def train_dataloader(self):
55
+ return DataLoader(
56
+ self.train_dataset,
57
+ batch_size=self.hparams.batch_size,
58
+ num_workers=self.hparams.num_workers,
59
+ pin_memory=True,
60
+ shuffle=True,
61
+ persistent_workers=True
62
+ )
63
+
64
+ def val_dataloader(self):
65
+ return DataLoader(
66
+ self.val_dataset,
67
+ batch_size=128,
68
+ num_workers=self.hparams.num_workers,
69
+ pin_memory=True,
70
+ shuffle=False,
71
+ persistent_workers=True
72
+ )
73
+
74
+ def test_dataloader(self):
75
+ return DataLoader(
76
+ self.test_dataset,
77
+ batch_size=1,
78
+ num_workers=self.hparams.num_workers,
79
+ pin_memory=True,
80
+ shuffle=False,
81
+ persistent_workers=True
82
+ )
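
Putting the pieces together, the datamodule is normally built through the registry from the `data` block of the config and then handed to Lightning or iterated directly. A minimal sketch, assuming the `data_path`/`split_file_path` placeholders in `configs/config.yaml` have been pointed at real data and an OmegaConf-style loader is used:

```python
# Sketch: build the datamodule by name and pull one validation batch.
from omegaconf import OmegaConf
import dataset

cfg = OmegaConf.load("configs/config.yaml")
dm = dataset.make(cfg.data.name, cfg.data)   # "dm_dipo" -> DIPODataModule
dm.setup(stage="validate")                   # uses the test ids from the split file
for data, cond, feat in dm.val_dataloader():
    print(data.shape, feat.shape)            # (B, K*5, 6), (B, 512, 768)
    break
```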
dataset/mydataset.py ADDED
@@ -0,0 +1,282 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
3
+ import json
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torchvision.transforms as T
7
+ from dataset.base_dataset import BaseDataset
8
+ import random
9
+ from tqdm import tqdm
10
+ import imageio
11
+ import torch
12
+
13
+ def make_white_background(src_img):
14
+ '''Make the white background for the input RGBA image.'''
15
+ src_img.load()
16
+ background = Image.new("RGB", src_img.size, (255, 255, 255))
17
+ background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
18
+ return background
19
+
20
+ class MyDataset(BaseDataset):
21
+
22
+ """
23
+ Dataset for training and testing on the PartNet-Mobility and ACD datasets (with our preprocessing).
24
+ The GT graph is given.
25
+ """
26
+
27
+ def __init__(self, hparams, model_ids, mode="train", json_name="object.json"):
28
+ self.hparams = hparams
29
+ self.json_name = json_name
30
+ self.model_ids = self._filter_models(model_ids)
31
+ self.mode = mode
32
+ self.map_cat = False
33
+ self.get_acd_mapping()
34
+
35
+ self.no_GT = (
36
+ True if self.hparams.get("test_no_GT", False) and self.hparams.get("test_pred_G", False)
37
+ else False
38
+ )
39
+ self.pred_G = (
40
+ False
41
+ if mode in ["train", "val"]
42
+ else self.hparams.get("test_pred_G", False)
43
+ )
44
+
45
+ if mode == 'test':
46
+ if "acd" in hparams.test_which:
47
+ self.map_cat = True
48
+
49
+ self.files = self._cache_data()
50
+ print(f"[INFO] {mode} dataset: {len(self)} data samples loaded.")
51
+
52
+ def _cache_data_train(self):
53
+ json_data_root = self.hparams.json_root
54
+ data_root = self.hparams.root
55
+ # number of views per model and in total
56
+ n_views_per_model = self.hparams.n_views_per_model
57
+ n_views = n_views_per_model * len(self.model_ids)
58
+ # json files for each model
59
+ json_files = []
60
+ # mapping to the index of the corresponding model in json_files
61
+ model_mappings = []
62
+ # space for dinov2 patch features
63
+ feats = np.empty((n_views, 512, 768), dtype=np.float16)
64
+ # space for object masks on image patches
65
+ obj_masks = np.empty((n_views, 256), dtype=bool)
66
+ # input images (not required in training)
67
+ imgs = None
68
+ # load data for non-aug views
69
+ i = 0 # index for views
70
+ for j, model_id in enumerate(self.model_ids):
71
+ print(model_id)
72
+ # if j % 10 == 0 and torch.distributed.get_rank() == 0:
73
+ # print(f"\rLoading training data: {j}/{len(self.model_ids)}")
74
+ # 3D data
75
+ with open(os.path.join(json_data_root, model_id, self.json_name), "r") as f:
76
+ json_file = json.load(f)
77
+ json_files.append(json_file)
78
+ filenames = os.listdir(os.path.join(data_root, model_id, 'features'))
79
+ filenames = [f for f in filenames if 'high_res' not in f]
80
+ filenames = filenames[:self.hparams.n_views_per_model]
81
+ for filename in filenames:
82
+ view_feat = np.load(os.path.join(data_root, model_id, 'features', filename))
83
+ first_frame_feat = view_feat[0]
84
+ if self.hparams.frame_mode == 'last_frame':
85
+ second_frame_feat = view_feat[-2]
86
+ elif self.hparams.frame_mode == 'random_state_frame':
87
+ second_frame_feat = view_feat[-1]
88
+ else:
89
+ raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
90
+ feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
91
+ feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)
92
+ i = i + 1
93
+ model_mappings += [j] * n_views_per_model
94
+ # object masks for all views
95
+ # all_obj_masks = np.load(
96
+ # os.path.join(json_data_root, model_id, "features/patch_obj_masks.npy")
97
+ # ) # (20, Np)
98
+ # obj_masks[i : i + n_views_per_model] = all_obj_masks[:n_views_per_model]
99
+ return {
100
+ "len": n_views,
101
+ "gt_files": json_files,
102
+ "features": feats,
103
+ "obj_masks": None,
104
+ "model_mappings": model_mappings,
105
+ "imgs": imgs,
106
+ }
107
+
108
+ def _cache_data_non_train(self):
109
+ # number of views per model and in total
110
+ n_views_per_model = 2
111
+ n_views = n_views_per_model * len(self.model_ids)
112
+ # json files for each model
113
+ gt_files = []
114
+ pred_files = [] # for predicted graphs
115
+ # mapping to the index of the corresponding model in json_files
116
+ model_mappings = []
117
+ # space for dinov2 patch features
118
+ feats = np.empty((n_views, 512, 768), dtype=np.float16)
119
+ # space for input images
120
+ first_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
121
+ second_imgs = np.empty((n_views, 128, 128, 3), dtype=np.uint8)
122
+ # transformation for input images
123
+ transform = T.Compose(
124
+ [
125
+ T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
126
+ T.CenterCrop(224),
127
+ T.Resize(128, interpolation=T.InterpolationMode.BICUBIC),
128
+ ]
129
+ )
130
+
131
+ i = 0 # index for views
132
+ desc = f'Loading {self.mode} data'
133
+ for j, model_id in tqdm(enumerate(self.model_ids), total=len(self.model_ids), desc=desc):
134
+ with open(os.path.join(self.hparams.json_root, model_id, self.json_name), "r") as f:
135
+ json_file = json.load(f)
136
+ gt_files.append(json_file)
137
+ # filename_dir = os.path.join(self.hparams.root, model_id, 'features')
138
+ for filename in ['18.npy', '19.npy']:
139
+ view_feat = np.load(os.path.join(self.hparams.root, model_id, 'features', filename))
140
+ first_frame_feat = view_feat[0]
141
+ if self.hparams.frame_mode == 'last_frame':
142
+ second_frame_feat = view_feat[-2]
143
+ elif self.hparams.frame_mode == 'random_state_frame':
144
+ second_frame_feat = view_feat[-1]
145
+ else:
146
+ raise NotImplementedError("Please provide correct frame mode: last_frame | random_state_frame")
147
+ feats[i : i + 1, :256, :] = first_frame_feat.astype(np.float16)
148
+ feats[i : i + 1, 256:, :] = second_frame_feat.astype(np.float16)
149
+
150
+ video_path = os.path.join(self.hparams.root, model_id, 'imgs', 'animation_' + filename.replace('.npy', '.mp4'))
151
+ reader = imageio.get_reader(video_path)
152
+ frames = []
153
+ for frame in reader:
154
+ frames.append(frame)
155
+ reader.close()
156
+
157
+ first_img = Image.fromarray(frames[0])
158
+ if first_img.mode == 'RGBA':
159
+ first_img = make_white_background(first_img)
160
+
161
+
162
+ first_img = np.asarray(transform(first_img), dtype=np.int8)
163
+ first_imgs[i] = first_img
164
+
165
+ if self.hparams.frame_mode == 'last_frame':
166
+ second_img = Image.fromarray(frames[-1])
167
+ elif self.hparams.frame_mode == 'random_state_frame':
168
+ second_img_path = video_path.replace('animation', 'random').replace('.mp4', '.png')
169
+ second_img = Image.open(second_img_path)
170
+ if second_img.mode == 'RGBA':
171
+ second_img = make_white_background(second_img)
172
+ second_img = np.asarray(transform(second_img), dtype=np.int8)
173
+ second_imgs[i] = second_img
174
+
175
+ i = i + 1
176
+ # mapping to json file
177
+ model_mappings += [j] * n_views_per_model
178
+
179
+ return {
180
+ "len": n_views,
181
+ "gt_files": gt_files,
182
+ "pred_files": pred_files,
183
+ "features": feats,
184
+ "model_mappings": model_mappings,
185
+ "imgs": [first_imgs, second_imgs],
186
+ }
187
+
188
+ def _cache_data(self):
189
+ """
190
+ Function to cache data from disk.
191
+ """
192
+ if self.mode == "train":
193
+ return self._cache_data_train()
194
+ else:
195
+ return self._cache_data_non_train()
196
+
197
+ def _get_item_train_val(self, index):
198
+ model_i = self.files["model_mappings"][index]
199
+ gt_file = self.files["gt_files"][model_i]
200
+ data, cond = self._prepare_input_GT(
201
+ file=gt_file, model_id=self.model_ids[model_i]
202
+ )
203
+ if self.mode == "val":
204
+ # input image for visualization
205
+ img_first = self.files["imgs"][0][index]
206
+ img_last = self.files["imgs"][1][index]
207
+ cond["img"] = np.concatenate([img_first, img_last], axis=1)
208
+ # else:
209
+ # # object masks on patches
210
+ # # obj_mask = self.files["obj_masks"][index][None, ...].repeat(self.hparams.K * 5, axis=0)
211
+ # cond["img_obj_mask"] = [None]
212
+ return data, cond
213
+
214
+ def _get_item_test(self, index):
215
+ model_i = self.files["model_mappings"][index]
216
+
217
+ gt_file = None if self.no_GT else self.files["gt_files"][model_i]
218
+
219
+ if self.hparams.get('G_dir', None) is None:
220
+ data, cond = self._prepare_input_GT(file=gt_file, model_id=self.model_ids[model_i])
221
+ else:
222
+ if index % 2 == 0:
223
+ filename = '18.json'
224
+ else:
225
+ filename = '19.json'
226
+ pred_file_path = os.path.join(self.hparams.G_dir, self.model_ids[model_i], filename)
227
+ with open(pred_file_path, "r") as f:
228
+ pred_file = json.load(f)
229
+ data, cond = self._prepare_input(model_id=self.model_ids[model_i], pred_file=pred_file, gt_file=gt_file)
230
+ # input image for visualization
231
+ img_first = self.files["imgs"][0][index]
232
+ img_last = self.files["imgs"][1][index]
233
+ cond["img"] = np.concatenate([img_first, img_last], axis=1)
234
+ return data, cond
235
+
236
+ def __getitem__(self, index):
237
+ # input image features
238
+ feat = self.files["features"][index]
239
+
240
+ # prepare input, GT data and other axillary info
241
+ if self.mode == "test":
242
+ data, cond = self._get_item_test(index)
243
+ else:
244
+ data, cond = self._get_item_train_val(index)
245
+
246
+ return data, cond, feat
247
+
248
+ def __len__(self):
249
+ return self.files["len"]
250
+
251
+ if __name__ == '__main__':
252
+ from types import SimpleNamespace
253
+
254
+ class EnhancedNamespace(SimpleNamespace):
255
+ def get(self, key, default=None):
256
+ return getattr(self, key, default)
257
+
258
+ hparams = {
259
+ "name": "dm_singapo",
260
+ "json_root": "/home/users/ruiqi.wu/singapo/", # root directory of the dataset
261
+ "batch_size": 20, # batch size for training
262
+ "num_workers": 8, # number of workers for data loading
263
+ "K": 32, # maximum number of nodes (parts) in the graph (object)
264
+ "split_file": "/home/users/ruiqi.wu/singapo/data/data_split.json",
265
+ "n_views_per_model": 5,
266
+ "root": "/home/users/ruiqi.wu/manipulate_3d_generate/data/blender_version",
267
+ "frame_mode": "last_frame"
268
+ }
269
+ hparams = EnhancedNamespace(**hparams)
270
+ with open(hparams.split_file , "r") as f:
271
+ splits = json.load(f)
272
+
273
+ train_ids = splits["train"]
274
+ val_ids = [i for i in train_ids if "augmented" not in i]
275
+
276
+ val_ids = [val_id for val_id in val_ids if os.path.exists(os.path.join(hparams.root, val_id, "features"))]
277
+
278
+ dataset = MyDataset(hparams, model_ids=val_ids[:20], mode="valid")
279
+ for i in range(20):
280
+ data, cond, feat = dataset.__getitem__(i)
281
+ import ipdb
282
+ ipdb.set_trace()
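
`_cache_data_non_train` hard-codes the view files it reads (`18.npy`/`19.npy` and the matching `animation_*.mp4` videos), so each model directory is expected to follow a fixed layout. A small sketch of that layout check, with paths taken from the loader above:

```python
# Sketch of the per-model directory layout MyDataset expects at val/test time.
import os

def check_model_dir(root, json_root, model_id, json_name="object.json"):
    expected = [
        os.path.join(json_root, model_id, json_name),               # GT diffuse_tree + meta
        os.path.join(root, model_id, "features", "18.npy"),         # DINOv2 patch features, view 18
        os.path.join(root, model_id, "features", "19.npy"),         # DINOv2 patch features, view 19
        os.path.join(root, model_id, "imgs", "animation_18.mp4"),   # rendered articulation video
        os.path.join(root, model_id, "imgs", "animation_19.mp4"),
    ]
    return {path: os.path.exists(path) for path in expected}
```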
dataset/utils.py ADDED
@@ -0,0 +1,194 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
3
+ import numpy as np
4
+ from PIL import Image
5
+ from my_utils.refs import joint_ref, sem_ref
6
+
7
+ def rescale_axis(jtype, axis_d, axis_o, box_center):
8
+ '''
9
+ Function to rescale the axis for rendering
10
+
11
+ Args:
12
+ - jtype (int): joint type
13
+ - axis_d (np.array): axis direction
14
+ - axis_o (np.array): axis origin
15
+ - box_center (np.array): bounding box center
16
+
17
+ Returns:
18
+ - center (np.array): rescaled axis origin
19
+ - axis_d (np.array): rescaled axis direction
20
+ '''
21
+ if jtype == 0 or jtype == 1:
22
+ return [0., 0., 0.], [0., 0., 0.]
23
+ if jtype == 3 or jtype == 4:
24
+ center = box_center
25
+ else:
26
+ center = axis_o + np.dot(axis_d, box_center-axis_o) * axis_d
27
+ return center.tolist(), axis_d.tolist()
28
+
29
+ def make_white_background(src_img):
30
+ '''Make the white background for the input RGBA image.'''
31
+ src_img.load()
32
+ background = Image.new("RGB", src_img.size, (255, 255, 255))
33
+ background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
34
+ return background
35
+
36
+ def build_graph(tree, K=32):
37
+ '''
38
+ Function to build graph from the node list.
39
+
40
+ Args:
41
+ tree: list of part nodes (the diffuse_tree)
42
+ K: the maximum number of nodes in the graph
43
+ Returns:
44
+ adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes
45
+ parents: parent index of each node (-1 for the root)
46
+ '''
47
+ adj = np.zeros((K, K), dtype=np.float32)
48
+ parents = []
49
+ for node in tree:
50
+ # 1-ring relationship
51
+ if node['parent'] != -1:
52
+ adj[node['id'], node['parent']] = 1
53
+ parents.append(node['parent'])
54
+ else:
55
+ adj[node['id'], node['id']] = 1
56
+ parents.append(-1)
57
+ for child_id in node['children']:
58
+ adj[node['id'], child_id] = 1
59
+
60
+ return {
61
+ 'adj': adj,
62
+ 'parents': np.array(parents, dtype=np.int8)
63
+ }
64
+
65
+ def load_input_from(pred_file, K=32):
66
+ '''
67
+ Function to parse input item from a file containing the predicted graph
68
+ '''
69
+
70
+ cond = {} # conditional information and axillary data
71
+ # prepare node data
72
+ n_nodes = len(pred_file['diffuse_tree'])
73
+ # prepare graph
74
+ pred_graph = build_graph(pred_file['diffuse_tree'], K)
75
+
76
+ # attr mask (for Local Attention)
77
+ attr_mask = np.eye(K, K, dtype=bool)
78
+ attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
79
+ cond['attr_mask'] = attr_mask
80
+
81
+ # key padding mask (for Global Attention)
82
+ pad_mask = np.zeros((K*5, K*5), dtype=bool)
83
+ pad_mask[:, :n_nodes*5] = 1
84
+ cond['key_pad_mask'] = pad_mask
85
+
86
+ # adj mask (for Graph Relation Attention)
87
+ adj_mask = pred_graph['adj'][:].astype(bool)
88
+ adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1)
89
+ adj_mask[n_nodes*5:, :] = 1
90
+ cond['adj_mask'] = adj_mask
91
+
92
+ # placeholder
93
+ data = np.zeros((K*5, 6), dtype=bool)
94
+ cond['cat'] = 2
95
+
96
+ # axillary info
97
+ cond['adj'] = pred_graph['adj']
98
+ cond['parents'] = np.zeros(K, dtype=np.int8)
99
+ cond['parents'][:n_nodes] = pred_graph['parents']
100
+ cond['n_nodes'] = n_nodes
101
+
102
+ return data, cond
103
+
104
+ def convert_data_range(x):
105
+ '''postprocessing: convert the raw model output to the original range, following CAGE'''
106
+ x = x.reshape(-1, 30) # (K, 36)
107
+ aabb_max = x[:, 0:3]
108
+ aabb_min = x[:, 3:6]
109
+ center = (aabb_max + aabb_min) / 2.0
110
+ size = (aabb_max - aabb_min).clip(min=5e-3)
111
+
112
+ j_type = np.mean(x[:, 6:12], axis=1)
113
+ j_type = ((j_type + 0.5) * 5).clip(min=1.0, max=5.0).round()
114
+
115
+ axis_d = x[:, 12:15]
116
+ axis_d = axis_d / (
117
+ np.linalg.norm(axis_d, axis=1, keepdims=True) + np.finfo(float).eps
118
+ )
119
+ axis_o = x[:, 15:18]
120
+
121
+ j_range = (x[:, 18:20] + x[:, 20:22] + x[:, 22:24]) / 3
122
+ j_range = j_range.clip(min=-1.0, max=1.0)
123
+ j_range[:, 0] = j_range[:, 0] * 360
124
+ j_range[:, 1] = j_range[:, 1]
125
+
126
+ label = np.mean(x[:, 24:30], axis=1)
127
+ label = ((label + 0.8) * 5).clip(min=0.0, max=7.0).round()
128
+ return {
129
+ "center": center,
130
+ "size": size,
131
+ "type": j_type,
132
+ "axis_d": axis_d,
133
+ "axis_o": axis_o,
134
+ "range": j_range,
135
+ "label": label,
136
+ }
137
+
138
+ def parse_tree(data, n_nodes, par, adj):
139
+ tree = []
140
+ # convert to json format
141
+ for i in range(n_nodes):
142
+ node = {"id": i}
143
+ node["name"] = sem_ref["bwd"][int(data["label"][i].item())]
144
+ node["parent"] = int(par[i])
145
+ node["children"] = [
146
+ int(child) for child in np.where(adj[i] == 1)[0] if child != par[i]
147
+ ]
148
+ node["aabb"] = {}
149
+ node["aabb"]["center"] = data["center"][i].tolist()
150
+ node["aabb"]["size"] = data["size"][i].tolist()
151
+ node["joint"] = {}
152
+ if node['name'] == 'base':
153
+ node["joint"]["type"] = 'fixed'
154
+ else:
155
+ node["joint"]["type"] = joint_ref["bwd"][int(data["type"][i].item())]
156
+ if node["joint"]["type"] == "fixed":
157
+ node["joint"]["range"] = [0.0, 0.0]
158
+ elif node["joint"]["type"] == "revolute":
159
+ node["joint"]["range"] = [0.0, float(data["range"][i][0])]
160
+ elif node["joint"]["type"] == "continuous":
161
+ node["joint"]["range"] = [0.0, 360.0]
162
+ elif (
163
+ node["joint"]["type"] == "prismatic" or node["joint"]["type"] == "screw"
164
+ ):
165
+ node["joint"]["range"] = [0.0, float(data["range"][i][1])]
166
+ node["joint"]["axis"] = {}
167
+ # relocate the axis to visualize well
168
+ axis_o, axis_d = rescale_axis(
169
+ int(data["type"][i].item()),
170
+ data["axis_d"][i],
171
+ data["axis_o"][i],
172
+ data["center"][i],
173
+ )
174
+ node["joint"]["axis"]["direction"] = axis_d
175
+ node["joint"]["axis"]["origin"] = axis_o
176
+ # append node to the tree
177
+ tree.append(node)
178
+ return tree
179
+
180
+ def convert_json(x, c, prefix=''):
181
+ out = {"meta": {}, "diffuse_tree": []}
182
+ n_nodes = c[f"{prefix}n_nodes"][0].item()
183
+ par = c[f"{prefix}parents"][0].cpu().numpy().tolist()
184
+ adj = c[f"{prefix}adj"][0].cpu().numpy()
185
+ np.fill_diagonal(adj, 0) # remove self-loop for the root node
186
+ if f"{prefix}obj_cat" in c:
187
+ out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][0]
188
+
189
+ # convert the data to original range
190
+ data = convert_data_range(x)
191
+ # parse the tree
192
+ tree = parse_tree(data, n_nodes, par, adj)
193
+ out["diffuse_tree"] = tree
194
+ return out
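
`convert_data_range` undoes the per-node encoding used in `_prepare_node_data`: each node is 30 numbers laid out as `[aabb_max(3), aabb_min(3), joint_type*6, axis_dir(3), axis_origin(3), joint_range*3, label*6]`. A minimal decoding sketch with illustrative values (the integer joint-type and label codes themselves come from `my_utils.refs`, not shown here):

```python
# Decoding sketch for convert_data_range(); the numbers are illustrative only.
import numpy as np
from dataset.utils import convert_data_range

node = np.concatenate([
    [0.3, 0.5, 0.1],       # aabb_max
    [-0.3, -0.5, -0.1],    # aabb_min
    [0.0] * 6,             # joint-type channel, decoded as round((v + 0.5) * 5)
    [0.0, 0.0, 0.7],       # axis direction (renormalized to unit length on decode)
    [0.0, 0.0, 0.0],       # axis origin
    [0.25, 0.0] * 3,       # joint range; the first entry is rescaled by 360 on decode
    [0.0] * 6,             # semantic-label channel, decoded as round((v + 0.8) * 5)
]).astype(np.float32)

out = convert_data_range(node)
print(out["center"][0])    # -> [0. 0. 0.]
print(out["size"][0])      # -> [0.6 1.  0.2]
print(out["range"][0])     # -> [90.  0.]
```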
inference.py ADDED
@@ -0,0 +1,450 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
3
+ import json
4
+ import torch
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image, ImageOps
8
+ import imageio
9
+ # from my_utils.plot import viz_graph
10
+ from my_utils.misc import load_config
11
+ import torchvision.transforms as T
12
+ from diffusers import DDPMScheduler
13
+ from models.denoiser import Denoiser
14
+ from scripts.json2urdf import create_urdf_from_json, pybullet_render
15
+ from dataset.utils import make_white_background, load_input_from, convert_data_range, parse_tree
16
+ import models
17
+ import torch.nn.functional as F
18
+ from io import BytesIO
19
+ import base64
20
+ from scripts.graph_pred.api import predict_graph_twomode, gpt_infer_image_category
21
+ import subprocess
22
+ import spaces
23
+ import time
24
+
25
+
26
+ cat_ref = {
27
+ "Table": 0,
28
+ "Dishwasher": 1,
29
+ "StorageFurniture": 2,
30
+ "Refrigerator": 3,
31
+ "WashingMachine": 4,
32
+ "Microwave": 5,
33
+ "Oven": 6,
34
+ }
35
+
36
+ def run_retrieve(src_dir, json_name, data_root):
37
+ fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
38
+ try:
39
+ subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
40
+ except subprocess.CalledProcessError as e:
41
+ print(f'Error from run_retrieve: {src_dir}')
42
+ print(f'Error: {e}')
43
+
44
+ def make_white_background(src_img):
45
+ '''Make the white background for the input RGBA image.'''
46
+ src_img.load()
47
+ background = Image.new("RGB", src_img.size, (255, 255, 255))
48
+ background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
49
+ return background
50
+
51
+ def pad_to_square(img, fill=0):
52
+ """Pad image to square with given fill value (default: 0 = black)."""
53
+ width, height = img.size
54
+ if width == height:
55
+ return img
56
+ max_side = max(width, height)
57
+ delta_w = max_side - width
58
+ delta_h = max_side - height
59
+ padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
60
+ return ImageOps.expand(img, padding, fill=fill)
61
+
62
+ def load_img(img_path):
63
+ transform = T.Compose([
64
+ T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
65
+ T.ToTensor(),
66
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
67
+ ])
68
+ with Image.open(img_path) as img:
69
+ if img.mode == 'RGBA':
70
+ img = make_white_background(img)
71
+ img = img.convert('RGB') # Ensure it's 3-channel for normalization
72
+ img = pad_to_square(img, fill=0)
73
+ img = transform(img)
74
+ img_batch = img.unsqueeze(0).cuda()
75
+
76
+ return img_batch
77
+
78
+
79
+ def load_frame_with_imageio(frame):
80
+ """
81
+ Process a single frame into the input format expected by the DINO model.
82
+ """
83
+ transform = T.Compose([
84
+ T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
85
+ T.ToTensor(),
86
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
87
+ ])
88
+ img = Image.fromarray(frame) # convert to a PIL image
89
+ if img.mode == 'RGBA':
90
+ img = make_white_background(img)
91
+ img = transform(img) # apply the preprocessing transforms
92
+ return img.unsqueeze(0).cuda() # add the batch dimension
93
+
94
+ def read_video_as_batch_with_imageio(video_path):
95
+ """
96
+ Read a video with imageio and process all frames into a batch of shape (B, C, H, W).
97
+ """
98
+ reader = imageio.get_reader(video_path)
99
+ batch_frames = []
100
+
101
+ try:
102
+ for frame in reader:
103
+ # load the frame and process it to shape (1, C, H, W)
104
+ processed_frame = load_frame_with_imageio(frame)
105
+ batch_frames.append(processed_frame)
106
+
107
+ reader.close()
108
+ if batch_frames:
109
+ return torch.cat(batch_frames, dim=0).cuda() # stack along the batch dimension and move to the GPU
110
+ else:
111
+ print("视频没有有效帧")
112
+ return None
113
+ except Exception as e:
114
+ print(f"处理视频时出错: {e}")
115
+ return None
116
+
117
+ def extract_dino_feature(img_path_1, img_path_2):
118
+ print('Extracting DINO feature...')
119
+ feat_1 = load_img(img_path_1)
120
+ feat_2 = load_img(img_path_2)
121
+ frames = torch.cat([feat_1, feat_2], dim=0)
122
+ dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg', pretrained=True).cuda()
123
+ print('step4')
124
+ with torch.no_grad():
125
+ feat = dinov2_vitb14_reg.forward_features(frames)["x_norm_patchtokens"]
126
+ # release the GPU memory of the model
127
+ feat_input = torch.cat([feat[0], feat[-1]], dim=0).unsqueeze(0)
128
+ print('Extracting DINO feature over')
129
+ torch.cuda.empty_cache()
130
+ return feat_input
131
+
132
+ def set_scheduler(n_steps=100):
133
+ scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='linear', prediction_type='epsilon')
134
+ scheduler.set_timesteps(n_steps)
135
+ return scheduler
136
+
137
+ def prepare_model_input(data, cond, feat, n_samples):
138
+ # attention masks
139
+ attr_mask = torch.from_numpy(cond['attr_mask']).unsqueeze(0).repeat(n_samples, 1, 1)
140
+ key_pad_mask = torch.from_numpy(cond['key_pad_mask'])
141
+ graph_mask = torch.from_numpy(cond['adj_mask'])
142
+ # input image feature
143
+ f = feat.repeat(n_samples, 1, 1)
144
+ # input noise
145
+ B, C = data.shape
146
+ noise = torch.randn([n_samples, B, C], dtype=torch.float32)
147
+ # dummy image feature (used for guided diffusion)
148
+ dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
149
+ dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
150
+ # dummy object category
151
+ cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
152
+ return {
153
+ "noise": noise.cuda(),
154
+ "attr_mask": attr_mask.cuda(),
155
+ "key_pad_mask": key_pad_mask.cuda(),
156
+ "graph_mask": graph_mask.cuda(),
157
+ "dummy_f": dummy_feat.cuda(),
158
+ 'cat': cat.cuda(),
159
+ 'f': f.cuda(),
160
+ }
161
+
162
+ def prepare_model_input_nocond(feat, n_samples):
163
+ # attention masks
164
+ cond_example = np.zeros((32*5, 32*5), dtype=bool)
165
+ attr_mask = np.eye(32, 32, dtype=bool)
166
+ attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
167
+ attr_mask = torch.from_numpy(attr_mask).unsqueeze(0).repeat(n_samples, 1, 1)
168
+ key_pad_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
169
+ graph_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
170
+ # input image feature
171
+ f = feat.repeat(n_samples, 1, 1)
172
+ # input noise
173
+ data = np.zeros((32*5, 6), dtype=bool)
174
+ noise = torch.randn(data.shape, dtype=torch.float32).repeat(n_samples, 1, 1)
175
+ # dummy image feature (used for guided diffusion)
176
+ dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
177
+ dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
178
+ # dummy object category
179
+ cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
180
+ return {
181
+ "noise": noise.cuda(),
182
+ "attr_mask": attr_mask.cuda(),
183
+ "key_pad_mask": key_pad_mask.cuda(),
184
+ "graph_mask": graph_mask.cuda(),
185
+ "dummy_f": dummy_feat.cuda(),
186
+ 'cat': cat.cuda(),
187
+ 'f': f.cuda(),
188
+ }
189
+
190
+ def save_graph(pred_graph, save_dir):
191
+ print(f'Saving the predicted graph to {save_dir}/pred_graph.json')
192
+ # save the response
193
+ with open(os.path.join(save_dir, "pred_graph.json"), "w") as f:
194
+ json.dump(pred_graph, f, indent=4)
195
+ # Visualize the graph
196
+ # img_graph = Image.fromarray(viz_graph(pred_graph))
197
+ # img_graph.save(os.path.join(save_dir, "pred_graph.png"))
198
+
199
+ def forward(model, scheduler, inputs, omega=0.5):
200
+ print('Running inference...')
201
+ noisy_x = inputs['noise']
202
+ for t in scheduler.timesteps:
203
+ timesteps = torch.tensor([t], device=inputs['noise'].device)
204
+ outputs_cond = model(
205
+ x=noisy_x,
206
+ cat=inputs['cat'],
207
+ timesteps=timesteps,
208
+ feat=inputs['f'],
209
+ key_pad_mask=inputs['key_pad_mask'],
210
+ graph_mask=inputs['graph_mask'],
211
+ attr_mask=inputs['attr_mask'],
212
+ label_free=True,
213
+ ) # take the conditional image features as input
214
+ if omega != 0:
215
+ outputs_free = model(
216
+ x=noisy_x,
217
+ cat=inputs['cat'],
218
+ timesteps=timesteps,
219
+ feat=inputs['dummy_f'],
220
+ key_pad_mask=inputs['key_pad_mask'],
221
+ graph_mask=inputs['graph_mask'],
222
+ attr_mask=inputs['attr_mask'],
223
+ label_free=True,
224
+ ) # take the dummy DINO features for the condition-free mode
225
+ noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
226
+ else:
227
+ noise_pred = outputs_cond['noise_pred']
228
+ noisy_x = scheduler.step(noise_pred, t, noisy_x).prev_sample
229
+ return noisy_x
230
+
231
+ def _convert_json(x, c):
232
+ out = {"meta": {}, "diffuse_tree": []}
233
+ n_nodes = c["n_nodes"]
234
+ par = c["parents"].tolist()
235
+ adj = c["adj"]
236
+ np.fill_diagonal(adj, 0) # remove self-loop for the root node
237
+ if "obj_cat" in c:
238
+ out["meta"]["obj_cat"] = c["obj_cat"]
239
+
240
+ # convert the data to original range
241
+ data = convert_data_range(x)
242
+ # parse the tree
243
+ out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
244
+ return out
245
+
246
+ def post_process(output, cond, save_root, gt_data_root, visualize=False):
247
+ print('Post-processing...')
248
+ N = output.shape[0]
249
+ for i in range(N):
250
+ cond_n = {}
251
+ cond_n['n_nodes'] = cond['n_nodes'][i]
252
+ cond_n['parents'] = cond['parents'][i]
253
+ cond_n['adj'] = cond['adj'][i]
254
+ cond_n['obj_cat'] = cond['cat']
255
+ # convert the raw model output to the json format
256
+ out_json = _convert_json(output, cond_n)
257
+ save_dir = os.path.join(save_root, str(i))
258
+ os.makedirs(save_dir, exist_ok=True)
259
+ with open(os.path.join(save_dir, "object.json"), "w") as f:
260
+ json.dump(out_json, f, indent=4)
261
+
262
+
263
+ # retrieve part meshes (call python script)
264
+ # print(f"Retrieving part meshes for the object {i}...")
265
+ # os.system(f"python scripts/mesh_retrieval/retrieve.py --src_dir {save_dir} --json_name object.json --gt_data_root {gt_data_root}")
266
+
267
+
268
+
269
+
270
+ def load_model(ckpt_path, config):
271
+ print('Loading model from checkpoint...')
272
+ model = models.make(config.name, config)
273
+ state_dict = torch.load(ckpt_path)
274
+ state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
275
+ model.load_state_dict(state_dict)
276
+ model.eval()
277
+ return model.cuda()
278
+
279
+ def convert_pred_graph(pred_graph):
280
+ cond = {}
281
+ B, K = pred_graph.shape[:2]
282
+ adj = np.zeros((B, K, K), dtype=np.float32)
283
+ padding = np.zeros((B, 5 * K, 5 * K), dtype=bool)
284
+ parents = np.zeros((B, K), dtype=np.int32)
285
+ n_nodes = np.zeros((B,), dtype=np.int32)
286
+ for b in range(B):
287
+ node_len = 0
288
+ for k in range(K):
289
+ if pred_graph[b, k] == k and k > 0:
290
+ node_len = k
291
+ break
292
+ node = pred_graph[b, k]
293
+ adj[b, k, node] = 1
294
+ adj[b, node, k] = 1
295
+ parents[b, k] = node
296
+ adj[b, node_len:] = 1
297
+ padding[b, :, :5 * node_len] = 1
298
+ parents[b, 0] = -1
299
+ n_nodes[b] = node_len
300
+ adj_mask = adj.astype(bool).repeat(5, axis=1).repeat(5, axis=2)
301
+ attr_mask = np.eye(32, 32, dtype=bool)
302
+ attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
303
+
304
+ cond['adj_mask'] = adj_mask
305
+ cond['attr_mask'] = attr_mask
306
+ cond['key_pad_mask'] = padding
307
+
308
+ cond['adj'] = adj
309
+ cond['parents'] = parents
310
+ cond['n_nodes'] = n_nodes
311
+ cond['cat'] = 'StorageFurniture'
312
+
313
+ data = np.zeros((32*5, 6), dtype=bool)
314
+
315
+ return data, cond
316
+
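For reference, a small shape check of convert_pred_graph (with the function above in scope) under an assumed three-part object, a base with two children; the first index k > 0 whose value equals its own position marks the padded end of the graph:

import numpy as np

row = np.arange(32)
row[1] = 0  # part 1 is a child of the base (part 0)
row[2] = 0  # part 2 is a child of the base as well
data, cond = convert_pred_graph(row[None, :])  # batch of one; the demo passes a torch tensor, but NumPy behaves the same for this check
# cond['n_nodes'][0] == 3 and cond['parents'][0, :3] == [-1, 0, 0]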
317
+ def bfs_tree_simple(tree_list):
318
+ order = [0] * len(tree_list)
319
+ queue = []
320
+ current_node_idx = 0
321
+ for node_idx, node in enumerate(tree_list):
322
+ if node['parent'] == -1:
323
+ queue.append(node['id'])
324
+ order[current_node_idx] = node_idx
325
+ current_node_idx += 1
326
+ break
327
+ while len(queue) > 0:
328
+ current_node = queue.pop(0)
329
+ for node_idx, node in enumerate(tree_list):
330
+ if node['parent'] == current_node:
331
+ queue.append(node['id'])
332
+ order[current_node_idx] = node_idx
333
+ current_node_idx += 1
334
+
335
+ return order
336
+
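A quick illustration of the BFS ordering on a hypothetical three-node tree (ids and parents are made up for the example):

toy_tree = [
    {"id": 0, "parent": 2},   # child of the root
    {"id": 1, "parent": 0},   # grandchild
    {"id": 2, "parent": -1},  # root
]
print(bfs_tree_simple(toy_tree))  # [2, 0, 1]: root first, then its child, then the grandchild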
337
+ def get_graph_from_gpt(img_path_1, img_path_2):
338
+ first_img = Image.open(img_path_1)
339
+ first_img_data = first_img.resize((1024, 1024))
340
+ buffer = BytesIO()
341
+ first_img_data.save(buffer, format="PNG")
342
+ buffer.seek(0)
343
+ # encode the image as base64
344
+ first_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
345
+
346
+
347
+ second_img = Image.open(img_path_2)
348
+ second_img_data = second_img.resize((1024, 1024))
349
+ buffer = BytesIO()
350
+ second_img_data.save(buffer, format="PNG")
351
+ buffer.seek(0)
352
+ # encode the image as base64
353
+ second_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
354
+
355
+ pred_gpt = predict_graph_twomode('', first_img_data=first_encoded_image, second_img_data=second_encoded_image)
356
+ print(pred_gpt)
357
+ pred_graph = pred_gpt['diffuse_tree']
358
+ # order = bfs_tree_simple(pred_graph)
359
+ # pred_graph = [pred_graph[i] for i in order]
360
+
361
+
362
+ # generate array [0, 1, 2, ..., 31] for init
363
+ graph_array = np.array([i for i in range(32)])
364
+ for node_idx, node in enumerate(pred_graph):
365
+ if node['parent'] == -1:
366
+ graph_array[node_idx] = node_idx
367
+ else:
368
+ graph_array[node_idx] = node['parent']
369
+
370
+ # new axis for batch
371
+ graph_array = np.expand_dims(graph_array, axis=0)
372
+
373
+ cat_str = gpt_infer_image_category(first_encoded_image, second_encoded_image)
374
+
375
+ return torch.from_numpy(graph_array).cuda().repeat(3, 1), cat_str
376
+
377
+ @spaces.GPU
378
+ def run_demo(args):
379
+ # extract DINOV2 feature from the input image
380
+ t1 = time.time()
381
+ feat = extract_dino_feature(args.img_path_1, args.img_path_2)
382
+ t2 = time.time()
383
+ print(f'Extracted DINO feature in {t2 - t1:.2f} seconds')
384
+ scheduler = set_scheduler(args.n_denoise_steps)
385
+ # load the checkpoint of the model
386
+ model = load_model(args.ckpt_path, args.config.system.model)
387
+
388
+ # inference
389
+ with torch.no_grad():
390
+ t3 = time.time()
391
+ pred_graph, cat_str = get_graph_from_gpt(args.img_path_1, args.img_path_2)
392
+ t4 = time.time()
393
+ print(f'Got the predicted graph in {t4 - t3:.2f} seconds')
394
+ print(pred_graph)
395
+ data, cond = convert_pred_graph(pred_graph)
396
+ inputs = prepare_model_input(data, cond, feat, n_samples=args.n_samples)
397
+
398
+ # Update the object category
399
+ cond['cat'] = cat_str
400
+ inputs['cat'][:] = cat_ref[cat_str]
401
+ print(f'Object category predicted by GPT: {cat_str}, {cat_ref[cat_str]}')
402
+
403
+ output = forward(model, scheduler, inputs, omega=args.omega).cpu().numpy()
404
+ t5 = time.time()
405
+ print(f'Forwarded the model in {t5 - t4:.2f} seconds')
406
+
407
+ # post-process
408
+ post_process(output, cond, args.save_dir, args.gt_data_root, visualize=True)
409
+
410
+ # retrieve
411
+ for sample in os.listdir(args.save_dir):
412
+ sample_dir = os.path.join(args.save_dir, sample)
413
+ t6 = time.time()
414
+ run_retrieve(sample_dir, 'object.json', args.gt_data_root)
415
+ t7 = time.time()
416
+ print(f'Retrieved part meshes for {sample} in {t7 - t6:.2f} seconds')
417
+
418
+ save_json_path = os.path.join(args.save_dir, "0", "object.json")
419
+ with open(save_json_path, 'r') as file:
420
+ json_data = json.load(file)
421
+ create_urdf_from_json(json_data, save_json_path.replace('.json', '.urdf'))
422
+ pybullet_render(save_json_path.replace('.json', '.urdf'), os.path.join(args.save_dir, "0"), 8)
423
+
424
+
425
+ if __name__ == '__main__':
426
+ '''
427
+ Script for running the inference on an example image input.
428
+ '''
429
+ parser = argparse.ArgumentParser()
430
+ parser.add_argument("--img_path_1", type=str, default='examples/1.png', help="path to the first input image")
431
+ parser.add_argument("--img_path_2", type=str, default='examples/1_open_1.png', help="path to the second input image (the same object in a different articulation state)")
432
+ parser.add_argument("--ckpt_path", type=str, default='exps/singapo/final/ckpts/last.ckpt', help="path to the checkpoint of the model")
433
+ parser.add_argument("--config_path", type=str, default='exps/singapo/final/config/parsed.yaml', help="path to the config file")
434
+ parser.add_argument("--use_example_graph", action="store_true", default=False, help="if you don't have an OpenAI API key yet, turn this on to use the example graph for inference")
435
+ parser.add_argument("--save_dir", type=str, default='results', help="path to save the output")
436
+ parser.add_argument("--gt_data_root", type=str, default='./', help="the root directory of the original data, used for part mesh retrieval")
437
+ parser.add_argument("--n_samples", type=int, default=3, help="number of samples to generate given the input")
438
+ parser.add_argument("--omega", type=float, default=0.5, help="the weight of the condition-free mode in the inference")
439
+ parser.add_argument("--n_denoise_steps", type=int, default=100, help="number of denoising steps")
440
+ args = parser.parse_args()
441
+
442
+ assert os.path.exists(args.img_path_1), "The input image does not exist"
443
+ # assert os.path.exists(args.ckpt_path), "The checkpoint does not exist"
444
+ assert os.path.exists(args.config_path), "The config file does not exist"
445
+ os.makedirs(args.save_dir, exist_ok=True)
446
+
447
+ config = load_config(args.config_path)
448
+ args.config = config
449
+
450
+ run_demo(args)
lightning_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_1/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_2/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_3/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_4/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_5/hparams.yaml ADDED
@@ -0,0 +1 @@
1
+ {}
lightning_logs/version_6/hparams.yaml ADDED
@@ -0,0 +1,111 @@
1
+ name: sys_origin
2
+ exp_dir: ./exps/dipo/denoiser
3
+ data_root: /horizon-bucket/robot_lab/users/ruiqi.wu/robot/dataset/blender
4
+ n_time_samples: 16
5
+ loss_fg_weight: 0.01
6
+ img_drop_prob: 0.1
7
+ guidance_scaler: 0.5
8
+ graph_drop_prob: 0.5
9
+ model:
10
+ name: denoiser
11
+ in_ch: 6
12
+ attn_dim: 128
13
+ n_head: 4
14
+ n_layers: 6
15
+ dropout: 0.1
16
+ K: 32
17
+ mode_num: 5
18
+ img_emb_dims:
19
+ - 768
20
+ - 128
21
+ cat_drop_prob: 0.5
22
+ scheduler:
23
+ name: ddpm
24
+ config:
25
+ num_train_timesteps: 1000
26
+ beta_schedule: linear
27
+ prediction_type: epsilon
28
+ lr_scheduler_adapter:
29
+ name: LinearWarmupCosineAnnealingLR
30
+ warmup_epochs: 3
31
+ max_epochs: 200
32
+ warmup_start_lr: 1.0e-06
33
+ eta_min: 1.0e-05
34
+ optimizer_adapter:
35
+ name: AdamW
36
+ args:
37
+ lr: 0.0005
38
+ betas:
39
+ - 0.9
40
+ - 0.99
41
+ eps: 1.0e-15
42
+ lr_scheduler_cage:
43
+ name: LinearWarmupCosineAnnealingLR
44
+ warmup_epochs: 3
45
+ max_epochs: 200
46
+ warmup_start_lr: 1.0e-06
47
+ eta_min: 1.0e-05
48
+ optimizer_cage:
49
+ name: AdamW
50
+ args:
51
+ lr: 5.0e-05
52
+ betas:
53
+ - 0.9
54
+ - 0.99
55
+ eps: 1.0e-15
56
+ hparams:
57
+ name: sys_origin
58
+ exp_dir: ./exps/dipo/denoiser
59
+ data_root: /horizon-bucket/robot_lab/users/ruiqi.wu/robot/dataset/blender
60
+ n_time_samples: 16
61
+ loss_fg_weight: 0.01
62
+ img_drop_prob: 0.1
63
+ guidance_scaler: 0.5
64
+ graph_drop_prob: 0.5
65
+ model:
66
+ name: denoiser
67
+ in_ch: 6
68
+ attn_dim: 128
69
+ n_head: 4
70
+ n_layers: 6
71
+ dropout: 0.1
72
+ K: 32
73
+ mode_num: 5
74
+ img_emb_dims:
75
+ - 768
76
+ - 128
77
+ cat_drop_prob: 0.5
78
+ scheduler:
79
+ name: ddpm
80
+ config:
81
+ num_train_timesteps: 1000
82
+ beta_schedule: linear
83
+ prediction_type: epsilon
84
+ lr_scheduler_adapter:
85
+ name: LinearWarmupCosineAnnealingLR
86
+ warmup_epochs: 3
87
+ max_epochs: 200
88
+ warmup_start_lr: 1.0e-06
89
+ eta_min: 1.0e-05
90
+ optimizer_adapter:
91
+ name: AdamW
92
+ args:
93
+ lr: 0.0005
94
+ betas:
95
+ - 0.9
96
+ - 0.99
97
+ eps: 1.0e-15
98
+ lr_scheduler_cage:
99
+ name: LinearWarmupCosineAnnealingLR
100
+ warmup_epochs: 3
101
+ max_epochs: 200
102
+ warmup_start_lr: 1.0e-06
103
+ eta_min: 1.0e-05
104
+ optimizer_cage:
105
+ name: AdamW
106
+ args:
107
+ lr: 5.0e-05
108
+ betas:
109
+ - 0.9
110
+ - 0.99
111
+ eps: 1.0e-15
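The dump above can be inspected programmatically; a minimal sketch, assuming OmegaConf for the YAML loading (the repo's own load_config may differ):

from omegaconf import OmegaConf

cfg = OmegaConf.load("lightning_logs/version_6/hparams.yaml")
print(cfg.model.attn_dim)             # 128
print(cfg.optimizer_adapter.args.lr)  # 0.0005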
lightning_logs/version_6/metrics.csv ADDED
@@ -0,0 +1,4 @@
1
+ lr-AdamW,lr-AdamW-1,step
2
+ 1e-06,1e-06,0
3
+ 0.0002505,2.5500000000000003e-05,10
4
+ 0.0005,5.000000000000001e-05,20
metrics/__init__.py ADDED
File without changes
metrics/aor.py ADDED
@@ -0,0 +1,44 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ import numpy as np
4
+ from copy import deepcopy
5
+ from metrics.iou import sampling_iou
6
+ from objects.motions import transform_all_parts
7
+
8
+ from objects.dict_utils import get_bbox_vertices
9
+
10
+ '''
11
+ This file computes the Average Overlap Ratio (AOR) metric\n
12
+ '''
13
+
14
+ def AOR(tgt, num_states=10, transform_use_plucker=False):
15
+ tree = tgt["diffuse_tree"]
16
+ states = np.linspace(0, 1, num_states)
17
+ original_bbox_vertices = np.array([get_bbox_vertices(tgt, i) for i in range(len(tgt["diffuse_tree"]))], dtype=np.float32)
18
+
19
+ ious = []
20
+ for state_idx, state in enumerate(states):
21
+ ious_per_state = []
22
+ bbox_vertices = deepcopy(original_bbox_vertices)
23
+ part_trans = transform_all_parts(bbox_vertices, tgt, state, transform_use_plucker)
24
+ for node in tree:
25
+ children = node['children']
26
+ num_children = len(children)
27
+ if num_children < 2:
28
+ continue
29
+ for i in range(num_children-1):
30
+ for j in range(i+1, num_children):
31
+ child_id = children[i]
32
+ sibling_id = children[j]
33
+ bbox_v_0 = deepcopy(bbox_vertices[child_id])
34
+ bbox_v_1 = deepcopy(bbox_vertices[sibling_id])
35
+ iou = sampling_iou(bbox_v_0, bbox_v_1, part_trans[child_id], part_trans[sibling_id], num_samples=10000)
36
+ if np.isnan(iou):
37
+ continue
38
+ ious_per_state.append(iou)
39
+ if len(ious_per_state) > 0:
40
+ ious.append(np.mean(ious_per_state))
41
+ if len(ious) == 0:
42
+ return -1
43
+ return float(np.mean(ious))
44
+
metrics/cd.py ADDED
@@ -0,0 +1,284 @@
1
+ """
2
+ This script computes the Chamfer Distance (CD) between two objects\n
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+ import torch
9
+ import numpy as np
10
+ import trimesh
11
+ from copy import deepcopy
12
+ from pytorch3d.structures import Meshes
13
+ from pytorch3d.ops import sample_points_from_meshes
14
+ from pytorch3d.loss import chamfer_distance
15
+ from objects.motions import transform_all_parts
16
+ from objects.dict_utils import (
17
+ zero_center_object,
18
+ rescale_object,
19
+ compute_overall_bbox_size,
20
+ get_base_part_idx,
21
+ find_part_mapping
22
+ )
23
+
24
+ def _load_and_combine_plys(dir, ply_files, scale=None, z_rotate=None, translate=None):
25
+ """
26
+ Load and combine the ply files into one PyTorch3D mesh
27
+
28
+ - dir: the directory of the object in which the ply files are from\n
29
+ - ply_files: the list of ply files\n
30
+ - scale: the scale factor to apply to the vertices\n
31
+ - z_rotate: whether to rotate the object around the z-axis by 90 degrees\n
32
+ - translate: the translation to apply to the vertices\n
33
+
34
+ Return:\n
35
+ - mesh: one PyTorch3D mesh of the combined ply files
36
+ """
37
+
38
+ # Combine the ply files into one
39
+ meshes = []
40
+ for ply_file in ply_files:
41
+ meshes.append(trimesh.load(os.path.join(dir, ply_file), force="mesh"))
42
+ full_part_mesh = trimesh.util.concatenate(meshes)
43
+
44
+ # Apply the transformations
45
+ full_part_mesh.vertices -= full_part_mesh.bounding_box.centroid
46
+ transformation = trimesh.transformations.compose_matrix(
47
+ scale=scale,
48
+ angles=[0, 0, np.radians(90) if z_rotate else 0],
49
+ translate=translate,
50
+ )
51
+ full_part_mesh.apply_transform(transformation)
52
+
53
+ # Create the PyTorch3D mesh
54
+ mesh = Meshes(
55
+ verts=torch.as_tensor(full_part_mesh.vertices, dtype=torch.float32, device='cuda').unsqueeze(
56
+ 0
57
+ ),
58
+ faces=torch.as_tensor(full_part_mesh.faces, dtype=torch.int32, device='cuda').unsqueeze(0),
59
+ )
60
+
61
+ return mesh
62
+
63
+
64
+ def _compute_chamfer_distance(
65
+ obj1_part_points, obj2_part_points, part_mapping=None, exclude_id=-1
66
+ ):
67
+ """
68
+ Compute the chamfer distance between the two set of points representing the two objects
69
+
70
+ - obj1_part_points: the set of points representing the first object\n
71
+ - obj2_part_points: the set of points representing the second object\n
72
+ - part_mapping (optional): the part mapping from the first object to the second object, if provided, the chamfer distance will be computed between the corresponding parts\n
73
+ - exclude_id (optional): the part id to exclude from the chamfer distance computation; when provided, this is typically the base part id\n
74
+
75
+ Return:\n
76
+ - distance: the chamfer distance between the two objects
77
+ """
78
+ if part_mapping is not None:
79
+ n_parts = part_mapping.shape[0]
80
+ distance = 0
81
+ for i in range(n_parts):
82
+ if i == exclude_id:
83
+ continue
84
+ obj1_part_points_i = obj1_part_points[i]
85
+ obj2_part_points_i = obj2_part_points[int(part_mapping[i, 0])]
86
+ with torch.no_grad():
87
+ obj1_part_points_i = obj1_part_points_i.cuda()
88
+ obj2_part_points_i = obj2_part_points_i.cuda()
89
+ # symmetric chamfer distance
90
+ forward_distance, _ = chamfer_distance(
91
+ obj1_part_points_i[None, :],
92
+ obj2_part_points_i[None, :],
93
+ batch_reduction=None,
94
+ )
95
+ backward_distance, _ = chamfer_distance(
96
+ obj2_part_points_i[None, :],
97
+ obj1_part_points_i[None, :],
98
+ batch_reduction=None,
99
+ )
100
+ distance += (forward_distance.item() + backward_distance.item()) * 0.5
101
+ distance /= n_parts
102
+ else:
103
+ # Merge the points of all parts into one tensor
104
+ obj1_part_points = obj1_part_points.reshape(-1, 3)
105
+ obj2_part_points = obj2_part_points.reshape(-1, 3)
106
+
107
+ # Compute the chamfer distance between the two objects
108
+ with torch.no_grad():
109
+ obj1_part_points = obj1_part_points.cuda()
110
+ obj2_part_points = obj2_part_points.cuda()
111
+ forward_distance, _ = chamfer_distance(
112
+ obj1_part_points[None, :],
113
+ obj2_part_points[None, :],
114
+ batch_reduction=None,
115
+ )
116
+ backward_distance, _ = chamfer_distance(
117
+ obj2_part_points[None, :],
118
+ obj1_part_points[None, :],
119
+ batch_reduction=None,
120
+ )
121
+ distance = (forward_distance.item() + backward_distance.item()) * 0.5
122
+
123
+ return distance
124
+
125
+
126
+ def _get_scores(
127
+ src_dict,
128
+ tgt_dict,
129
+ original_src_part_points,
130
+ original_tgt_part_points,
131
+ part_mapping,
132
+ num_states,
133
+ include_base,
134
+ src_base_idx,
135
+ ):
136
+
137
+ chamfer_distances = np.zeros(num_states, dtype=np.float32)
138
+ joint_states = np.linspace(0, 1, num_states)
139
+ for state_idx, state in enumerate(joint_states):
140
+
141
+ # Reset the part point clouds
142
+ src_part_points = deepcopy(original_src_part_points)
143
+ tgt_part_points = deepcopy(original_tgt_part_points)
144
+
145
+ # Transform the part point clouds to the current state using the joints
146
+ transform_all_parts(src_part_points.numpy(), src_dict, state, dry_run=False)
147
+ transform_all_parts(tgt_part_points.numpy(), tgt_dict, state, dry_run=False)
148
+
149
+ # Compute the chamfer distance between the two objects
150
+ chamfer_distances[state_idx] = _compute_chamfer_distance(
151
+ src_part_points,
152
+ tgt_part_points,
153
+ part_mapping=part_mapping,
154
+ exclude_id=-1 if include_base else src_base_idx,
155
+ )
156
+
157
+ # Aggregate into the two scores: AS-CD (mean over the sampled articulation states) and RS-CD (resting state only)
158
+ aid_cd = np.mean(chamfer_distances)
159
+ rid_cd = chamfer_distances[0]
160
+
161
+ return {
162
+ "AS-CD": float(aid_cd),
163
+ "RS-CD": float(rid_cd),
164
+ }
165
+
166
+
167
+ def CD(
168
+ gen_obj_dict,
169
+ gen_obj_path,
170
+ gt_obj_dict,
171
+ gt_obj_path,
172
+ num_states=5,
173
+ num_samples=2048,
174
+ include_base=False,
175
+ ):
176
+ """
177
+ Compute the Chamfer Distance\n
178
+ This metric is the average of per-part chamfer distance between the two objects over a number of articulation states\n
179
+
180
+ - gen_obj_dict: the generated object dictionary\n
181
+ - gen_obj_path: the directory to the predicted object\n
182
+ - gt_obj_dict: the ground truth object dictionary\n
183
+ - gt_obj_path: the directory to the ground truth object\n
184
+ - num_states (optional): the number of articulation states to compute the metric\n
185
+ - num_samples (optional): the number of samples to use\n
186
+ - include_base (optional): whether to include the base part in the chamfer distance computation\n
187
+
188
+ Return:\n
189
+ - aid_score: the score over the sampled articulated states\n
190
+ - rid_score: the score at the resting state\n
191
+ - The score is in the range of [0, inf), lower is better
192
+ """
193
+ # Make copies of the dictionaries to avoid modifying the original dictionaries
194
+ gen_dict = deepcopy(gen_obj_dict)
195
+ gt_dict = deepcopy(gt_obj_dict)
196
+
197
+ # Zero center the objects
198
+ zero_center_object(gen_dict)
199
+ zero_center_object(gt_dict)
200
+
201
+ # Compute the scale factor by comparing the overall bbox size and scale the candidate object as a whole
202
+ gen_bbox_size = compute_overall_bbox_size(gen_dict)
203
+ gt_bbox_size = compute_overall_bbox_size(gt_dict)
204
+ scale_factor = gt_bbox_size / gen_bbox_size
205
+ rescale_object(gen_dict, scale_factor)
206
+
207
+ # Record the indices of the base parts of the two objects
208
+ gen_base_idx = get_base_part_idx(gen_dict)
209
+ gt_base_idx = get_base_part_idx(gt_dict)
210
+
211
+ # Find mapping between the parts of the two objects based on closest bbox centers
212
+ mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
213
+ mapping_gt2gen = find_part_mapping(gt_dict, gen_dict, use_hungarian=True)
214
+
215
+ # Get the number of parts of the two objects
216
+ gen_tree = gen_dict["diffuse_tree"]
217
+ gt_tree = gt_dict["diffuse_tree"]
218
+ gen_num_parts = len(gen_tree)
219
+ gt_num_parts = len(gt_tree)
220
+
221
+ # Get the paths of the ply files of the two objects
222
+ gen_part_ply_paths = [
223
+ {"dir": gen_obj_path, "files": gen_tree[i]["plys"]}
224
+ for i in range(gen_num_parts)
225
+ ]
226
+ gt_part_ply_paths = [
227
+ {"dir": gt_obj_path, "files": gt_tree[i]["plys"]}
228
+ for i in range(gt_num_parts)
229
+ ]
230
+
231
+ # Load the ply files of the two objects and sample points from them
232
+ gen_part_points = torch.zeros(
233
+ (gen_num_parts, num_samples, 3), dtype=torch.float32
234
+ )
235
+ for i in range(gen_num_parts):
236
+ part_mesh = _load_and_combine_plys(
237
+ gen_part_ply_paths[i]["dir"],
238
+ gen_part_ply_paths[i]["files"],
239
+ scale=scale_factor,
240
+ translate=gen_tree[i]["aabb"]["center"],
241
+ )
242
+ gen_part_points[i] = sample_points_from_meshes(
243
+ part_mesh, num_samples=num_samples
244
+ ).squeeze(0).cpu()
245
+
246
+ gt_part_points = torch.zeros(
247
+ (gt_num_parts, num_samples, 3), dtype=torch.float32
248
+ )
249
+ for i in range(gt_num_parts):
250
+ part_mesh = _load_and_combine_plys(
251
+ gt_part_ply_paths[i]["dir"],
252
+ gt_part_ply_paths[i]["files"],
253
+ translate=gt_tree[i]["aabb"]["center"],
254
+ )
255
+ gt_part_points[i] = sample_points_from_meshes(
256
+ part_mesh, num_samples=num_samples
257
+ ).squeeze(0).cpu()
258
+
259
+ cd_gen2gt = _get_scores(
260
+ gen_dict,
261
+ gt_dict,
262
+ gen_part_points,
263
+ gt_part_points,
264
+ mapping_gen2gt,
265
+ num_states,
266
+ include_base,
267
+ gen_base_idx,
268
+ )
269
+
270
+ cd_gt2gen = _get_scores(
271
+ gt_dict,
272
+ gen_dict,
273
+ gt_part_points,
274
+ gen_part_points,
275
+ mapping_gt2gen,
276
+ num_states,
277
+ include_base,
278
+ gt_base_idx,
279
+ )
280
+
281
+ return {
282
+ "AS-CD": (cd_gen2gt["AS-CD"] + cd_gt2gen["AS-CD"]) / 2,
283
+ "RS-CD": (cd_gen2gt["RS-CD"] + cd_gt2gen["RS-CD"]) / 2,
284
+ }
metrics/giou.py ADDED
@@ -0,0 +1,142 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ import numpy as np
4
+ from metrics.iou import (
5
+ _sample_points_in_box3d,
6
+ _apply_backward_transformations,
7
+ _apply_forward_transformations,
8
+ _count_points_in_box3d,
9
+ )
10
+
11
+
12
+ def giou_aabb(bbox1_vertices, bbox2_verices):
13
+ """
14
+ Compute the generalized IoU between two axis-aligned bounding boxes\n
15
+ - bbox1_vertices: the vertices of the first bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
16
+ - bbox2_vertices: the vertices of the second bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
17
+
18
+ Return:\n
19
+ - giou: the gIoU between the two bounding boxes
20
+ """
21
+ volume1 = np.prod(np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0))
22
+ volume2 = np.prod(np.max(bbox2_verices, axis=0) - np.min(bbox2_verices, axis=0))
23
+
24
+ # Compute the intersection and union of the two bounding boxes
25
+ min_bbox = np.maximum(np.min(bbox1_vertices, axis=0), np.min(bbox2_verices, axis=0))
26
+ max_bbox = np.minimum(np.max(bbox1_vertices, axis=0), np.max(bbox2_verices, axis=0))
27
+ intersection = np.prod(np.clip(max_bbox - min_bbox, a_min=0, a_max=None))
28
+ union = volume1 + volume2 - intersection
29
+ # Compute IoU
30
+ iou = intersection / union if union > 0 else 0
31
+
32
+ # Compute the smallest enclosing box
33
+ min_enclosing_bbox = np.minimum(np.min(bbox1_vertices, axis=0), np.min(bbox2_verices, axis=0))
34
+ max_enclosing_bbox = np.maximum(np.max(bbox1_vertices, axis=0), np.max(bbox2_verices, axis=0))
35
+ volume3 = np.prod(max_enclosing_bbox - min_enclosing_bbox)
36
+
37
+ # Compute gIoU
38
+ giou = iou - (volume3 - union) / volume3 if volume3 > 0 else iou
39
+
40
+ return giou
41
+
42
+
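A quick numeric sanity check for giou_aabb (only the min/max corners matter, so two opposite corners per box are enough): two unit boxes offset by half a unit overlap in half their volume, the enclosing box equals their union, so the gIoU reduces to the plain IoU of 1/3.

import numpy as np
from metrics.giou import giou_aabb

box1 = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
box2 = np.array([[0.5, 0.0, 0.0], [1.5, 1.0, 1.0]])
print(round(giou_aabb(box1, box2), 3))  # 0.333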
43
+ def sampling_giou(
44
+ bbox1_vertices,
45
+ bbox2_vertices,
46
+ bbox1_transformations,
47
+ bbox2_transformations,
48
+ num_samples=10000,
49
+ ):
50
+ """
51
+ Compute the IoU between two bounding boxes\n
52
+ - bbox1_vertices: the vertices of the first bounding box\n
53
+ - bbox2_vertices: the vertices of the second bounding box\n
54
+ - bbox1_transformations: list of transformations applied to the first bounding box\n
55
+ - bbox2_transformations: list of transformations applied to the second bounding box\n
56
+ - num_samples (optional): the number of samples to use per bounding box\n
57
+
58
+ Return:\n
59
+ - iou: the IoU between the two bounding boxes after applying the transformations
60
+ """
61
+ # if no transformations are applied, use the axis-aligned bounding box IoU
62
+ if len(bbox1_transformations) == 0 and len(bbox2_transformations) == 0:
63
+ return giou_aabb(bbox1_vertices, bbox2_vertices)
64
+
65
+ # Volume of the two bounding boxes
66
+ bbox1_volume = np.prod(
67
+ np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0)
68
+ )
69
+ bbox2_volume = np.prod(
70
+ np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0)
71
+ )
72
+ # Volume of the smallest enclosing box
73
+ min_enclosing_bbox = np.minimum(np.min(bbox1_vertices, axis=0), np.min(bbox2_vertices, axis=0))
74
+ max_enclosing_bbox = np.maximum(np.max(bbox1_vertices, axis=0), np.max(bbox2_vertices, axis=0))
75
+ cbbox_volume = np.prod(max_enclosing_bbox - min_enclosing_bbox)
76
+
77
+ # Sample points in the two bounding boxes
78
+ bbox1_points = _sample_points_in_box3d(bbox1_vertices, num_samples)
79
+ bbox2_points = _sample_points_in_box3d(bbox2_vertices, num_samples)
80
+
81
+ # Transform the points
82
+ forward_bbox1_points = _apply_forward_transformations(
83
+ bbox1_points, bbox1_transformations
84
+ )
85
+ forward_bbox2_points = _apply_forward_transformations(
86
+ bbox2_points, bbox2_transformations
87
+ )
88
+
89
+ # Transform the forward points to the other box's rest pose frame
90
+ forward_bbox1_points_in_rest_bbox2_frame = _apply_backward_transformations(
91
+ forward_bbox1_points, bbox2_transformations
92
+ )
93
+ forward_bbox2_points_in_rest_bbox1_frame = _apply_backward_transformations(
94
+ forward_bbox2_points, bbox1_transformations
95
+ )
96
+
97
+ # Count the number of points in the other bounding box
98
+ num_bbox1_points_in_bbox2 = _count_points_in_box3d(
99
+ forward_bbox1_points_in_rest_bbox2_frame, bbox2_vertices
100
+ )
101
+ num_bbox2_points_in_bbox1 = _count_points_in_box3d(
102
+ forward_bbox2_points_in_rest_bbox1_frame, bbox1_vertices
103
+ )
104
+
105
+ # Compute the IoU
106
+ intersect = (
107
+ bbox1_volume * num_bbox1_points_in_bbox2
108
+ + bbox2_volume * num_bbox2_points_in_bbox1
109
+ ) / 2
110
+ union = bbox1_volume * num_samples + bbox2_volume * num_samples - intersect
111
+ iou = intersect / union
112
+
113
+ giou = iou - (cbbox_volume * num_samples - union) / (cbbox_volume * num_samples) if cbbox_volume > 0 else iou
114
+
115
+ return giou
116
+
117
+
118
+ def sampling_cDist(
119
+ part1,
120
+ part2,
121
+ bbox1_transformations,
122
+ bbox2_transformations,
123
+ ):
124
+ '''
125
+ Compute the centroid distance between two bounding boxes\n
126
+ - bbox1_vertices: the vertices of the first bounding box\n
127
+ - bbox2_vertices: the vertices of the second bounding box\n
128
+ - bbox1_transformations: list of transformations applied to the first bounding box\n
129
+ - bbox2_transformations: list of transformations applied to the second bounding box\n
130
+ '''
131
+
132
+ bbox1_centroid = np.array(part1['aabb']['center'], dtype=np.float32).reshape(1, 3)
133
+ bbox2_centroid = np.array(part2['aabb']['center'], dtype=np.float32).reshape(1, 3)
134
+
135
+ # Transform the centroids
136
+ bbox1_transformed_centroids = _apply_forward_transformations(bbox1_centroid, bbox1_transformations)
137
+ bbox2_transformed_centroids = _apply_forward_transformations(bbox2_centroid, bbox2_transformations)
138
+
139
+ # Compute the centroid distance
140
+ cDist = np.linalg.norm(bbox1_transformed_centroids - bbox2_transformed_centroids)
141
+
142
+ return cDist
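A corresponding check for sampling_cDist, using a hypothetical part dict with a centered aabb and a pure 1-unit translation applied to one side:

import numpy as np
from metrics.giou import sampling_cDist

T = np.eye(4)
T[0, 3] = 1.0  # translate 1 unit along x
part = {"aabb": {"center": [0.0, 0.0, 0.0]}}
print(sampling_cDist(part, part, [{"type": "translation", "matrix": T}], []))  # 1.0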
metrics/iou.py ADDED
@@ -0,0 +1,220 @@
1
+ import numpy as np
2
+
3
+
4
+ def _sample_points_in_box3d(bbox_vertices, num_samples):
5
+ """
6
+ Sample points uniformly in an axis-aligned 3D bounding box\n
7
+ - bbox_vertices: the vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
8
+ - num_samples: the number of samples to use\n
9
+
10
+ Return:\n
11
+ - points: the sampled points in the form: [[x0, y0, z0], [x1, y1, z1], ...]
12
+ """
13
+
14
+ # Compute the bounding box size
15
+ bbox_size = np.max(bbox_vertices, axis=0) - np.min(bbox_vertices, axis=0)
16
+
17
+ # Sample points in the bounding box
18
+ points = np.random.rand(num_samples, 3) * bbox_size + np.min(bbox_vertices, axis=0)
19
+
20
+ return points
21
+
22
+
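A small check of the sampler above: every sample should fall inside the box spanned by the given vertices.

import numpy as np
from metrics.iou import _sample_points_in_box3d

box = np.array([[0.0, 0.0, 0.0], [2.0, 1.0, 3.0]])
pts = _sample_points_in_box3d(box, 1000)
assert pts.shape == (1000, 3)
assert (pts >= box.min(axis=0)).all() and (pts <= box.max(axis=0)).all()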
23
+ def _apply_forward_transformations(points, transformations):
24
+ """
25
+ Apply forward transformations to the points\n
26
+ - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
27
+ - transformations: list of transformations to apply\n
28
+
29
+ Return:\n
30
+ - points_transformed: the transformed points in the form: [[x0, y0, z0], [x1, y1, z1], ...]
31
+ """
32
+ if len(transformations) == 0:
33
+ return points
34
+
35
+ # To homogeneous coordinates
36
+ points_transformed = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
37
+
38
+ # Apply the transformations one by one in order
39
+ for transformation in transformations:
40
+ if transformation["type"] == "translation":
41
+ points_transformed = np.matmul(
42
+ transformation["matrix"], points_transformed.T
43
+ ).T
44
+
45
+ elif transformation["type"] == "rotation":
46
+ axis_origin = np.append(transformation["rotation_axis_origin"], 0)
47
+ points_recentered = points_transformed - axis_origin
48
+
49
+ points_rotated = np.matmul(transformation["matrix"], points_recentered.T).T
50
+ points_transformed = points_rotated + axis_origin
51
+
52
+ elif transformation["type"] == "plucker":
53
+ points_transformed = np.matmul(
54
+ transformation["matrix"], points_transformed.T
55
+ ).T
56
+
57
+ else:
58
+ raise ValueError(f"Unknown transformation type: {transformation['type']}")
59
+
60
+ return points_transformed[..., :3]
61
+
62
+
63
+ def _apply_backward_transformations(points, transformations):
64
+ """
65
+ Apply backward transformations to the points\n
66
+ - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
67
+ - transformations: list of transformations to apply\n
68
+ - The inverse of the transformations are applied in reverse order\n
69
+
70
+ Return:\n
71
+ - points_transformed: the transformed points in the form: [[x0, y0, z0], [x1, y1, z1], ...]
72
+
73
+ Reference: https://mathematica.stackexchange.com/questions/106257/how-do-i-get-the-inverse-of-a-homogeneous-transformation-matrix
74
+ """
75
+ if len(transformations) == 0:
76
+ return points
77
+
78
+ # To homogeneous coordinates
79
+ points_transformed = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
80
+
81
+ # Apply the transformations one by one in reverse order
82
+ for transformation in transformations[::-1]:
83
+ inv_transformation = np.eye(4)
84
+ inv_transformation[:3, :3] = transformation["matrix"][:3, :3].T
85
+ inv_transformation[:3, 3] = -np.matmul(
86
+ transformation["matrix"][:3, :3].T, transformation["matrix"][:3, 3]
87
+ )
88
+
89
+ if transformation["type"] == "translation":
90
+ points_transformed = np.matmul(inv_transformation, points_transformed.T).T
91
+
92
+ elif transformation["type"] == "rotation":
93
+ axis_origin = np.append(transformation["rotation_axis_origin"], 0)
94
+ points_recentered = points_transformed - axis_origin
95
+
96
+ points_rotated = np.matmul(inv_transformation, points_recentered.T).T
97
+ points_transformed = points_rotated + axis_origin
98
+
99
+ elif transformation["type"] == "plucker":
100
+ points_transformed = np.matmul(inv_transformation, points_transformed.T).T
101
+
102
+ else:
103
+ raise ValueError(f"Unknown transformation type: {transformation['type']}")
104
+
105
+ return points_transformed[..., :3]
106
+
107
+
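As a sketch of the inverse-transform convention referenced above ([R|t]^-1 = [R^T | -R^T t]), a forward pass followed by a backward pass should round-trip the points; a pure translation is assumed for simplicity.

import numpy as np
from metrics.iou import _apply_forward_transformations, _apply_backward_transformations

T = np.eye(4)
T[:3, 3] = [1.0, 2.0, 3.0]
chain = [{"type": "translation", "matrix": T}]
pts = np.random.rand(5, 3)
back = _apply_backward_transformations(_apply_forward_transformations(pts, chain), chain)
assert np.allclose(pts, back)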
108
+ def _count_points_in_box3d(points, bbox_vertices):
109
+ """
110
+ Count the number of points in a 3D bounding box\n
111
+ - points: the points in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
112
+ - bbox_vertices: the vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
113
+ - The bbox is assumed to be axis-aligned\n
114
+
115
+ Return:\n
116
+ - num_points_in_bbox: the number of points in the bounding box
117
+ """
118
+
119
+ # Count the number of points in the bounding box
120
+ num_points_in_bbox = np.sum(
121
+ np.all(points >= np.min(bbox_vertices, axis=0), axis=1)
122
+ & np.all(points <= np.max(bbox_vertices, axis=0), axis=1)
123
+ )
124
+
125
+ return num_points_in_bbox
126
+
127
+
128
+ def iou_aabb(bbox1_vertices, bbox2_verices):
129
+ """
130
+ Compute the IoU between two axis-aligned bounding boxes\n
131
+ - bbox1_vertices: the vertices of the first bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
132
+ - bbox2_vertices: the vertices of the second bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]\n
133
+
134
+ Return:\n
135
+ - iou: the IoU between the two bounding boxes
136
+ """
137
+
138
+ # Compute the intersection and union of the two bounding boxes
139
+ min_bbox = np.maximum(np.min(bbox1_vertices, axis=0), np.min(bbox2_verices, axis=0))
140
+ max_bbox = np.minimum(np.max(bbox1_vertices, axis=0), np.max(bbox2_verices, axis=0))
141
+ intersection = np.prod(np.clip(max_bbox - min_bbox, a_min=0, a_max=None))
142
+ union = (
143
+ np.prod(np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0))
144
+ + np.prod(np.max(bbox2_verices, axis=0) - np.min(bbox2_verices, axis=0))
145
+ - intersection
146
+ )
147
+
148
+ # Compute the IoU
149
+ iou = intersection / union if union > 0 else 0
150
+
151
+ return iou
152
+
153
+
154
+ def sampling_iou(
155
+ bbox1_vertices,
156
+ bbox2_vertices,
157
+ bbox1_transformations,
158
+ bbox2_transformations,
159
+ num_samples=10000,
160
+ ):
161
+ """
162
+ Compute the IoU between two bounding boxes\n
163
+ - bbox1_vertices: the vertices of the first bounding box\n
164
+ - bbox2_vertices: the vertices of the second bounding box\n
165
+ - bbox1_transformations: list of transformations applied to the first bounding box\n
166
+ - bbox2_transformations: list of transformations applied to the second bounding box\n
167
+ - num_samples (optional): the number of samples to use per bounding box\n
168
+
169
+ Return:\n
170
+ - iou: the IoU between the two bounding boxes after applying the transformations
171
+ """
172
+ # if no transformations are applied, use the axis-aligned bounding box IoU
173
+ if len(bbox1_transformations) == 0 and len(bbox2_transformations) == 0:
174
+ return iou_aabb(bbox1_vertices, bbox2_vertices)
175
+
176
+ # Volume of the two bounding boxes
177
+ bbox1_volume = np.prod(
178
+ np.max(bbox1_vertices, axis=0) - np.min(bbox1_vertices, axis=0)
179
+ )
180
+ bbox2_volume = np.prod(
181
+ np.max(bbox2_vertices, axis=0) - np.min(bbox2_vertices, axis=0)
182
+ )
183
+
184
+ # Sample points in the two bounding boxes
185
+ bbox1_points = _sample_points_in_box3d(bbox1_vertices, num_samples)
186
+ bbox2_points = _sample_points_in_box3d(bbox2_vertices, num_samples)
187
+
188
+ # Transform the points
189
+ forward_bbox1_points = _apply_forward_transformations(
190
+ bbox1_points, bbox1_transformations
191
+ )
192
+ forward_bbox2_points = _apply_forward_transformations(
193
+ bbox2_points, bbox2_transformations
194
+ )
195
+
196
+ # Transform the forward points to the other box's rest pose frame
197
+ forward_bbox1_points_in_rest_bbox2_frame = _apply_backward_transformations(
198
+ forward_bbox1_points, bbox2_transformations
199
+ )
200
+ forward_bbox2_points_in_rest_bbox1_frame = _apply_backward_transformations(
201
+ forward_bbox2_points, bbox1_transformations
202
+ )
203
+
204
+ # Count the number of points in the other bounding box
205
+ num_bbox1_points_in_bbox2 = _count_points_in_box3d(
206
+ forward_bbox1_points_in_rest_bbox2_frame, bbox2_vertices
207
+ )
208
+ num_bbox2_points_in_bbox1 = _count_points_in_box3d(
209
+ forward_bbox2_points_in_rest_bbox1_frame, bbox1_vertices
210
+ )
211
+
212
+ # Compute the IoU
213
+ intersect = (
214
+ bbox1_volume * num_bbox1_points_in_bbox2
215
+ + bbox2_volume * num_bbox2_points_in_bbox1
216
+ ) / 2
217
+ union = bbox1_volume * num_samples + bbox2_volume * num_samples - intersect
218
+ iou = intersect / union
219
+
220
+ return iou
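A Monte-Carlo sanity check for sampling_iou: shifting one unit box by half a unit along x gives an analytic IoU of 1/3, which the estimate should approach up to sampling noise.

import numpy as np
from metrics.iou import sampling_iou

box = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
T = np.eye(4)
T[0, 3] = 0.5
est = sampling_iou(box, box, [{"type": "translation", "matrix": T}], [], num_samples=20000)
print(est)  # ~0.333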
metrics/iou_cdist.py ADDED
@@ -0,0 +1,227 @@
1
+ """
2
+ This file computes the IoU-based and centroid-distance-based metrics in a symmetric manner\n
3
+ """
4
+
5
+ import sys, os
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
7
+ import numpy as np
8
+ from copy import deepcopy
9
+ from objects.dict_utils import (
10
+ get_base_part_idx,
11
+ get_bbox_vertices,
12
+ remove_handles,
13
+ compute_overall_bbox_size,
14
+ rescale_object,
15
+ find_part_mapping,
16
+ zero_center_object,
17
+ )
18
+ from objects.motions import transform_all_parts
19
+ from metrics.giou import sampling_giou, sampling_cDist
20
+
21
+
22
+ def _get_scores(
23
+ src_dict,
24
+ tgt_dict,
25
+ original_src_bbox_vertices,
26
+ original_tgt_bbox_vertices,
27
+ mapping,
28
+ num_states,
29
+ rotation_fix_range,
30
+ num_samples,
31
+ iou_include_base,
32
+ ):
33
+ # Record the indices of the base parts of the src objects
34
+ src_base_idx = get_base_part_idx(src_dict)
35
+
36
+ # Compute the sum of IoU between the generated object and the candidate object over a number of articulation states
37
+ num_parts_in_src = len(src_dict["diffuse_tree"])
38
+ iou_per_part_and_state = np.zeros((num_parts_in_src, num_states), dtype=np.float32)
39
+ cDist_per_part_and_state = np.zeros(
40
+ (num_parts_in_src, num_states), dtype=np.float32
41
+ )
42
+
43
+ states = np.linspace(0, 1, num_states)
44
+ for state_idx, state in enumerate(states):
45
+
46
+ # Get a fresh copy of the bounding box vertices in rest pose
47
+ src_bbox_vertices = deepcopy(original_src_bbox_vertices)
48
+ tgt_bbox_vertices = deepcopy(original_tgt_bbox_vertices)
49
+
50
+ # Transform the objects to the current state using the joints
51
+ src_part_transfomrations = transform_all_parts(
52
+ src_bbox_vertices,
53
+ src_dict,
54
+ state,
55
+ rotation_fix_range=rotation_fix_range,
56
+ )
57
+
58
+
59
+ tgt_part_transfomrations = transform_all_parts(
60
+ tgt_bbox_vertices,
61
+ tgt_dict,
62
+ state,
63
+ rotation_fix_range=rotation_fix_range,
64
+ )
65
+
66
+ # Compute the IoU between the two objects using the transformed bounding boxes and the part mapping
67
+ for src_part_idx in range(num_parts_in_src):
68
+
69
+ # Get the index of the corresponding part in the candidate object
70
+ tgt_part_idx = int(mapping[src_part_idx, 0])
71
+
72
+ # Always use a fresh copy of the bounding box vertices in rest pose in case dry_run=False is incorrectly set
73
+ src_part_bbox_vertices = deepcopy(original_src_bbox_vertices)[src_part_idx]
74
+ tgt_part_bbox_vertices = deepcopy(original_tgt_bbox_vertices)[tgt_part_idx]
75
+
76
+ # Compute the sampling-based IoU between the two parts
77
+
78
+ iou_per_part_and_state[src_part_idx, state_idx] = sampling_giou(
79
+ src_part_bbox_vertices,
80
+ tgt_part_bbox_vertices,
81
+ src_part_transfomrations[src_part_idx],
82
+ tgt_part_transfomrations[tgt_part_idx],
83
+ num_samples=num_samples,
84
+ )
85
+ # Compute the centroid distance between the two matched parts
86
+ cDist_per_part_and_state[src_part_idx, state_idx] = sampling_cDist(
87
+ src_dict["diffuse_tree"][src_part_idx],
88
+ tgt_dict["diffuse_tree"][tgt_part_idx],
89
+ src_part_transfomrations[src_part_idx],
90
+ tgt_part_transfomrations[tgt_part_idx],
91
+ )
92
+
93
+ # IoU and cDist at the resting state
94
+ per_part_iou_avg_at_rest = iou_per_part_and_state[:, 0]
95
+ per_part_cDist_avg_at_rest = cDist_per_part_and_state[:, 0]
96
+
97
+ # Average the IoU over the states
98
+ per_part_iou_avg_over_states = np.sum(iou_per_part_and_state, axis=1) / num_states
99
+ # Average the cDist over the states
100
+ per_part_cDist_avg_over_states = (
101
+ np.sum(cDist_per_part_and_state, axis=1) / num_states
102
+ )
103
+
104
+ # Remove the base part if specified
105
+ if not iou_include_base:
106
+ per_part_iou_avg_over_states = np.delete(
107
+ per_part_iou_avg_over_states, src_base_idx
108
+ )
109
+ per_part_iou_avg_at_rest = np.delete(per_part_iou_avg_at_rest, src_base_idx)
110
+ per_part_cDist_avg_over_states = np.delete(
111
+ per_part_cDist_avg_over_states, src_base_idx
112
+ )
113
+ per_part_cDist_avg_at_rest = np.delete(per_part_cDist_avg_at_rest, src_base_idx)
114
+
115
+ aid_iou = float(np.mean(per_part_iou_avg_over_states)) if len(per_part_iou_avg_over_states) > 0 else 0
116
+ aid_cdist = float(np.mean(per_part_cDist_avg_over_states)) if len(per_part_cDist_avg_over_states) > 0 else 1
117
+ rid_iou = float(np.mean(per_part_iou_avg_at_rest)) if len(per_part_iou_avg_at_rest) > 0 else 0
118
+ rid_cdist = float(np.mean(per_part_cDist_avg_at_rest)) if len(per_part_cDist_avg_at_rest) > 0 else 1
119
+
120
+ return {
121
+ "AS-IoU": 1. - aid_iou,
122
+ "AS-cDist": aid_cdist,
123
+ "RS-IoU": 1. - rid_iou,
124
+ "RS-cDist": rid_cdist
125
+ }
126
+
127
+
128
+ def IoU_cDist(
129
+ gen_obj_dict,
130
+ gt_obj_dict,
131
+ num_states=2,
132
+ compare_handles=False,
133
+ iou_include_base=False,
134
+ rotation_fix_range=True,
135
+ num_samples=10000,
136
+ ):
137
+ """
138
+ Compute the IoU-based and centroid-distance-based metrics\n
139
+ This metric averages the per-part IoU between matched parts of the two objects over the sampled articulation states and at the resting state\n
140
+
141
+ - gen_obj_dict: the dictionary of the generated object\n
142
+ - gt_obj_dict: the dictionary of the gt object\n
143
+ - num_states: the number of articulation states to compute the metric\n
144
+ - compare_handles (optional): whether to compare the handles\n
145
+ - iou_include_base (optional): whether to include the base part in the IoU computation\n
146
+ - rotation_fix_range (optional): whether to fix the rotation range to 90 degrees for revolute joints\n
147
+ - num_samples (optional): the number of samples to use\n
148
+
149
+ Return:\n
150
+ - scores: a dictionary of the computed scores\n
151
+ - "AS-IoU": the average IoU over the articulation states\n
152
+ - "AS-cDist": the average centroid distance over the articulation states\n
153
+ - "RS-IoU": the average IoU at the resting state\n
154
+ - "RS-cDist": the average centroid distance at the resting state\n
155
+ """
156
+ # Make copies of the dictionaries to avoid modifying the original dictionaries
157
+ gen_dict = deepcopy(gen_obj_dict)
158
+ gt_dict = deepcopy(gt_obj_dict)
159
+
160
+ # Strip the handles from the object if not comparing them
161
+ if not compare_handles:
162
+ gen_dict = remove_handles(gen_dict)
163
+ gt_dict = remove_handles(gt_dict)
164
+
165
+ # Zero center the objects
166
+ zero_center_object(gen_dict)
167
+ zero_center_object(gt_dict)
168
+
169
+ # scale the generated object as a whole to match the size of the gt object
170
+ gen_bbox_size = compute_overall_bbox_size(gen_dict)
171
+ gt_bbox_size = compute_overall_bbox_size(gt_dict)
172
+ scale_factor = gt_bbox_size / gen_bbox_size
173
+ rescale_object(gen_dict, scale_factor)
174
+
175
+ mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
176
+ # for i in range(mapping_gen2gt.shape[0]):
177
+ # if mapping_gen2gt[i][0] < 100:
178
+ # gen_dict['diffuse_tree'][i]["parent"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["parent"]
179
+ # gen_dict['diffuse_tree'][i]["children"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["children"]
180
+ # gen_dict['diffuse_tree'][i]["id"] = gt_dict['diffuse_tree'][int(mapping_gen2gt[i][0])]["id"]
181
+ # mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
182
+ mapping_gt2gen = find_part_mapping(gt_dict, gen_dict, use_hungarian=True)
183
+
184
+ # Save the original bounding box vertices in rest pose
185
+ original_gen_bbox_vertices = np.array(
186
+ [get_bbox_vertices(gen_dict, i) for i in range(len(gen_dict["diffuse_tree"]))],
187
+ dtype=np.float32,
188
+ )
189
+ original_gt_bbox_vertices = np.array(
190
+ [get_bbox_vertices(gt_dict, i) for i in range(len(gt_dict["diffuse_tree"]))],
191
+ dtype=np.float32,
192
+ )
193
+ # import ipdb
194
+ # ipdb.set_trace()
195
+ scores_gen2gt = _get_scores(
196
+ gen_dict,
197
+ gt_dict,
198
+ original_gen_bbox_vertices,
199
+ original_gt_bbox_vertices,
200
+ mapping_gen2gt,
201
+ num_states,
202
+ rotation_fix_range,
203
+ num_samples,
204
+ iou_include_base,
205
+ )
206
+
207
+ scores_gt2gen = _get_scores(
208
+ gt_dict,
209
+ gen_dict,
210
+ original_gt_bbox_vertices,
211
+ original_gen_bbox_vertices,
212
+ mapping_gt2gen,
213
+ num_states,
214
+ rotation_fix_range,
215
+ num_samples,
216
+ iou_include_base,
217
+ )
218
+
219
+
220
+ scores = {
221
+ "AS-IoU": (scores_gen2gt["AS-IoU"] + scores_gt2gen["AS-IoU"]) / 2,
222
+ "AS-cDist": (scores_gen2gt["AS-cDist"] + scores_gt2gen["AS-cDist"]) / 2,
223
+ "RS-IoU": (scores_gen2gt["RS-IoU"] + scores_gt2gen["RS-IoU"]) / 2,
224
+ "RS-cDist": (scores_gen2gt["RS-cDist"] + scores_gt2gen["RS-cDist"]) / 2,
225
+ }
226
+
227
+ return scores
models/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ models = {}
2
+
3
+
4
+ def register(name):
5
+ def decorator(cls):
6
+ models[name] = cls
7
+ return cls
8
+
9
+ return decorator
10
+
11
+
12
+ def make(name, config):
13
+ if name == 'model_B9':
14
+ name = 'denoiser_singapo'
15
+ model = models[name](config)
16
+ return model
17
+
18
+
19
+ from . import denoiser
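Usage sketch of the registry above from elsewhere in the project (the 'toy' entry is hypothetical and only for illustration; real models register themselves on import, as denoiser.py does below):

import models

@models.register("toy")
class Toy:
    def __init__(self, config):
        self.config = config

toy = models.make("toy", config={"K": 32})  # looks up "toy" and instantiates it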
models/denoiser.py ADDED
@@ -0,0 +1,415 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ import torch
4
+ import models
5
+ from torch import nn
6
+ from diffusers.models.attention import Attention, FeedForward
7
+ from models.utils import (
8
+ PEmbeder,
9
+ FinalLayer,
10
+ VisAttnProcessor,
11
+ MyAdaLayerNormZero
12
+ )
13
+
14
+ class RAPCrossAttnBlock(nn.Module):
15
+ def __init__(self, dim, num_layers, num_heads, head_dim, dropout=0.0, img_emb_dims=None):
16
+ super().__init__()
17
+ self.layers = nn.ModuleList([
18
+ Attention(
19
+ query_dim=dim,
20
+ cross_attention_dim=dim,
21
+ heads=num_heads,
22
+ dim_head=head_dim,
23
+ dropout=dropout,
24
+ bias=True,
25
+ cross_attention_norm="layer_norm",
26
+ processor=VisAttnProcessor(),
27
+ )
28
+ for _ in range(num_layers)
29
+ ])
30
+ self.norms = nn.ModuleList([
31
+ nn.LayerNorm(dim) for _ in range(num_layers)
32
+ ])
33
+
34
+ img_emb_layers = []
35
+ for i in range(len(img_emb_dims) - 1):
36
+ img_emb_layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
37
+ img_emb_layers.append(nn.LeakyReLU(inplace=True))
38
+ img_emb_layers.pop(-1)
39
+ self.img_emb = nn.Sequential(*img_emb_layers)
40
+ self.init_img_emb_weights()
41
+
42
+ def init_img_emb_weights(self):
43
+ for m in self.img_emb.modules():
44
+ if isinstance(m, nn.Linear):
45
+ nn.init.kaiming_normal_(m.weight, mode="fan_in")
46
+ if m.bias is not None:
47
+ nn.init.constant_(m.bias, 0)
48
+
49
+ def forward(self, img_first, img_second):
50
+ """
51
+ Inputs:
52
+ img_first: (B, Np, D)
53
+ img_second: (B, Np, D)
54
+ Output:
55
+ fused_feat: (B, Np, D)
56
+ """
57
+ img_first = self.img_emb(img_first)
58
+ img_second = self.img_emb(img_second)
59
+ fused = img_second
60
+ for norm, attn in zip(self.norms, self.layers):
61
+ normed = norm(fused)
62
+ delta, _ = attn(normed, encoder_hidden_states=img_first, attention_mask=None)
63
+ fused = fused + delta # residual connection
64
+ return fused
65
+
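A shape sketch for the fusion block above (dimensions are assumed to mirror the config: 768-d DINOv2 patch features projected to a 128-d attention width; VisAttnProcessor from models.utils is assumed to return (output, attention_map) as the rest of this file expects):

import torch
from models.denoiser import RAPCrossAttnBlock

block = RAPCrossAttnBlock(dim=128, num_layers=2, num_heads=4, head_dim=32, img_emb_dims=[768, 128])
first = torch.randn(2, 256, 768)   # (B, Np, 768) patch features of the first image
second = torch.randn(2, 256, 768)  # (B, Np, 768) patch features of the second image
fused = block(first, second)       # (B, Np, 128): second-image features attended over the first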
66
+ class Attn_Block(nn.Module):
67
+ def __init__(
68
+ self,
69
+ dim: int,
70
+ num_attention_heads: int,
71
+ attention_head_dim: int,
72
+ dropout=0.0,
73
+ activation_fn: str = "geglu",
74
+ num_embeds_ada_norm: int = None,
75
+ attention_bias: bool = False,
76
+ norm_elementwise_affine: bool = True,
77
+ final_dropout: bool = False,
78
+ class_dropout_prob: float = 0.0, # for classifier-free
79
+ img_emb_dims=None,
80
+
81
+ ):
82
+ super().__init__()
83
+
84
+ self.norm1 = MyAdaLayerNormZero(dim, num_embeds_ada_norm, class_dropout_prob)
85
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
86
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
87
+ self.norm4 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
88
+ self.norm5 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
89
+ self.norm6 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
90
+
91
+ self.local_attn = Attention(
92
+ query_dim=dim,
93
+ heads=num_attention_heads,
94
+ dim_head=attention_head_dim,
95
+ dropout=dropout,
96
+ bias=attention_bias,
97
+ )
98
+
99
+ self.global_attn = Attention(
100
+ query_dim=dim,
101
+ heads=num_attention_heads,
102
+ dim_head=attention_head_dim,
103
+ dropout=dropout,
104
+ bias=attention_bias,
105
+ )
106
+
107
+ self.graph_attn = Attention(
108
+ query_dim=dim,
109
+ heads=num_attention_heads,
110
+ dim_head=attention_head_dim,
111
+ dropout=dropout,
112
+ bias=attention_bias,
113
+ )
114
+
115
+ self.img_attn = Attention(
116
+ query_dim=dim,
117
+ cross_attention_dim=dim,
118
+ heads=num_attention_heads,
119
+ dim_head=attention_head_dim,
120
+ dropout=dropout,
121
+ bias=attention_bias,
122
+ cross_attention_norm="layer_norm",
123
+ processor=VisAttnProcessor(), # to be removed for release model
124
+ )
125
+
126
+ self.img_attn_second = Attention(
127
+ query_dim=dim,
128
+ cross_attention_dim=dim,
129
+ heads=num_attention_heads,
130
+ dim_head=attention_head_dim,
131
+ dropout=dropout,
132
+ bias=attention_bias,
133
+ cross_attention_norm="layer_norm",
134
+ processor=VisAttnProcessor(), # to be removed for release model
135
+ )
136
+
137
+ self.ff = FeedForward(
138
+ dim,
139
+ dropout=dropout,
140
+ activation_fn=activation_fn,
141
+ final_dropout=final_dropout,
142
+ )
143
+
144
+ # image embedding layers
145
+ layers = []
146
+ for i in range(len(img_emb_dims) - 1):
147
+ layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
148
+ layers.append(nn.LeakyReLU(inplace=True))
149
+ layers.pop(-1)
150
+ self.img_emb = nn.Sequential(*layers)
151
+ self.init_img_emb_weights()
152
+
153
+ def init_img_emb_weights(self):
154
+ for m in self.img_emb.modules():
155
+ if isinstance(m, nn.Linear):
156
+ nn.init.kaiming_normal_(m.weight, mode="fan_in")
157
+ if m.bias is not None:
158
+ nn.init.constant_(m.bias, 0)
159
+
160
+ def forward(
161
+ self,
162
+ hidden_states,
163
+ img_patches,
164
+ fuse_feat,
165
+ pad_mask,
166
+ attr_mask,
167
+ graph_mask,
168
+ timestep,
169
+ class_labels,
170
+ label_free=False,
171
+ ):
172
+ # image patches embedding
173
+ img_emb = self.img_emb(img_patches)
174
+
175
+ # adaptive normalization, taken timestep and class_labels as input condition
176
+ norm_hidden_states, gate_1, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = (
177
+ self.norm1(
178
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype,
179
+ label_free=label_free
180
+ )
181
+ )
182
+
183
+ # local attribute self-attention
184
+ attr_out = self.local_attn(norm_hidden_states, attention_mask=attr_mask)
185
+ attr_out = gate_1.unsqueeze(1) * attr_out
186
+ hidden_states = hidden_states + attr_out
187
+
188
+ # global attribute self-attention
189
+ norm_hidden_states = self.norm2(hidden_states)
190
+ global_out = self.global_attn(norm_hidden_states, attention_mask=pad_mask)
191
+ global_out = gate_2.unsqueeze(1) * global_out
192
+ hidden_states = hidden_states + global_out
193
+
194
+ # graph relation self-attention
195
+ norm_hidden_states = self.norm3(hidden_states)
196
+ graph_out = self.graph_attn(norm_hidden_states, attention_mask=graph_mask)
197
+ graph_out = gate_3.unsqueeze(1) * graph_out
198
+ hidden_states = hidden_states + graph_out
199
+
200
+ img_first, img_second = img_emb.chunk(2, dim=1)
201
+
202
+ # cross attention with image patches
203
+ norm_hidden_states = self.norm4(hidden_states)
204
+ B, Na, D = norm_hidden_states.shape
205
+ Np = img_first.shape[1] # number of image patches
206
+ mode_num = Na // 32
207
+ reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
208
+ bboxes = reshaped[:, :, 0, :] # (B, K, D)
209
+ # cross attention between bbox attributes and image patches
210
+ bbox_img_out, bbox_cross_attn_map = self.img_attn(
211
+ bboxes,
212
+ encoder_hidden_states=img_first,
213
+ attention_mask=None,
214
+ ) # cross_attn_map: (B, n_head, K, Np)
215
+
216
+ # to reshape the cross_attn_map back to (B, n_head, Na, Np); redundant for other attributes, fix later
217
+ # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
218
+ # cross_attn_map_reshape[:, :, :, 0, :] = bbox_cross_attn_map
219
+ # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)
220
+
221
+ # assemble the output of cross attention with bbox attributes and other attributes
222
+ img_out = torch.empty(size=(B, Na // mode_num, mode_num, D), device=hidden_states.device, dtype=hidden_states.dtype)
223
+ img_out[:, :, 0, :] = bbox_img_out
224
+ img_out[:, :, 1:, :] = reshaped[:, :, 1:, :]
225
+ img_out = img_out.reshape(B, Na, D)
226
+ hidden_states = hidden_states + img_out
227
+
228
+ norm_hidden_states = self.norm6(hidden_states)
229
+ B, Na, D = norm_hidden_states.shape
230
+ Np = img_second.shape[1] # number of image patches
231
+ mode_num = Na // 32
232
+ reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
233
+ joints = reshaped # (B, K, 5, D)
234
+ joints = joints.reshape(B, Na // mode_num * 5, D)
235
+ # cross attention between all part-attribute tokens and the fused image features
236
+ joint_img_out, bbox_cross_attn_map = self.img_attn_second(
237
+ joints,
238
+ encoder_hidden_states=fuse_feat,
239
+ attention_mask=None,
240
+ ) # cross_attn_map: (B, n_head, K*5, Np)
241
+
242
+ # to reshape the cross_attn_map back to (B, n_head, Na, Np); redundant for other attributes, fix later
243
+ # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
244
+ # cross_attn_map_reshape[:, :, :, 1:5, :] = bbox_cross_attn_map.reshape(
245
+ # B, bbox_cross_attn_map.shape[1], Na // mode_num, 4, Np
246
+ # )
247
+ # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)
248
+
249
+ # assemble the output of cross attention with bbox attributes and other attributes
250
+ img_out = torch.empty(size=(B, Na // mode_num, mode_num, D), device=hidden_states.device, dtype=hidden_states.dtype)
251
+ img_out = joint_img_out.reshape(B, Na // mode_num, 5, D)
252
+ img_out = img_out.reshape(B, Na, D)
253
+ hidden_states = hidden_states + img_out
254
+
255
+ # feed-forward
256
+ norm_hidden_states = self.norm5(hidden_states)
257
+ norm_hidden_states = (
258
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
259
+ )
260
+ ff_output = self.ff(norm_hidden_states)
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+ hidden_states = ff_output + hidden_states
263
+
264
+ return hidden_states
265
+
266
+ @models.register("denoiser")
267
+ class Denoiser(nn.Module):
268
+ """
269
+ Denoiser based on CAGE's attribute attention block + our ICA module, with 4 sequential attentions: LA -> GA -> GRA -> ICA
270
+ Different image adapters for each layer.
271
+ The image cross attention uses key-padding masks (object mask, part mask).
272
+ *** The ICA only applies to the bbox attributes, not other attributes such as motion params.***
273
+ """
274
+
275
+ def __init__(self, hparams):
276
+ super().__init__()
277
+ self.hparams = hparams
278
+ self.K = self.hparams.get("K", 32)
279
+
280
+ in_ch = hparams.in_ch
281
+ attn_dim = hparams.attn_dim
282
+ mid_dim = attn_dim // 2
283
+ n_head = hparams.n_head
284
+ head_dim = attn_dim // n_head
285
+ num_embeds_ada_norm = 6 * attn_dim
286
+
287
+ # embedding layers for different node attributes
288
+ self.aabb_emb = nn.Sequential(
289
+ nn.Linear(in_ch, mid_dim),
290
+ nn.ReLU(inplace=True),
291
+ nn.Linear(mid_dim, attn_dim),
292
+ )
293
+ self.jaxis_emb = nn.Sequential(
294
+ nn.Linear(in_ch, mid_dim),
295
+ nn.ReLU(inplace=True),
296
+ nn.Linear(mid_dim, attn_dim),
297
+ )
298
+ self.range_emb = nn.Sequential(
299
+ nn.Linear(in_ch, mid_dim),
300
+ nn.ReLU(inplace=True),
301
+ nn.Linear(mid_dim, attn_dim),
302
+ )
303
+ self.label_emb = nn.Sequential(
304
+ nn.Linear(in_ch, mid_dim),
305
+ nn.ReLU(inplace=True),
306
+ nn.Linear(mid_dim, attn_dim),
307
+ )
308
+ self.jtype_emb = nn.Sequential(
309
+ nn.Linear(in_ch, mid_dim),
310
+ nn.ReLU(inplace=True),
311
+ nn.Linear(mid_dim, attn_dim),
312
+ )
313
+ # self.node_type_emb = nn.Sequential(
314
+ # nn.Linear(in_ch, mid_dim),
315
+ # nn.ReLU(inplace=True),
316
+ # nn.Linear(mid_dim, attn_dim),
317
+ # )
318
+ # positional encoding for nodes and attributes
319
+ self.pe_node = PEmbeder(self.K, attn_dim)
320
+ self.pe_attr = PEmbeder(self.hparams.mode_num, attn_dim)
321
+
322
+ # attention layers
323
+ self.attn_layers = nn.ModuleList(
324
+ [
325
+ Attn_Block(
326
+ dim=attn_dim,
327
+ num_attention_heads=n_head,
328
+ attention_head_dim=head_dim,
329
+ class_dropout_prob=hparams.get("cat_drop_prob", 0.0),
330
+ dropout=hparams.dropout,
331
+ activation_fn="geglu",
332
+ num_embeds_ada_norm=num_embeds_ada_norm,
333
+ attention_bias=False,
334
+ norm_elementwise_affine=True,
335
+ final_dropout=False,
336
+ img_emb_dims=hparams.get("img_emb_dims", None),
337
+ )
338
+ for d in range(hparams.n_layers)
339
+ ]
340
+ )
341
+
342
+ self.image_interaction = RAPCrossAttnBlock(
343
+ dim=attn_dim,
344
+ num_layers=6,
345
+ num_heads=n_head,
346
+ head_dim=head_dim,
347
+ dropout=hparams.dropout,
348
+ img_emb_dims=hparams.get("img_emb_dims", None),
349
+ )
350
+
351
+ self.final_layer = FinalLayer(attn_dim, in_ch)
352
+
353
+ def forward(
354
+ self,
355
+ x,
356
+ cat,
357
+ timesteps,
358
+ feat,
359
+ key_pad_mask=None,
360
+ graph_mask=None,
361
+ attr_mask=None,
362
+ label_free=False,
363
+ ):
364
+ B = x.shape[0]
365
+ x = x.view(B, self.K, 5 * 6)
366
+
367
+ # embedding layers for different attributes
368
+ x_aabb = self.aabb_emb(x[..., :6])
369
+ x_jtype = self.jtype_emb(x[..., 6:12])
370
+ x_jaxis = self.jaxis_emb(x[..., 12:18])
371
+ x_range = self.range_emb(x[..., 18:24])
372
+ x_label = self.label_emb(x[..., 24:30])
373
+ # x_node_type = self.node_type_emb(x[..., 30:36])
374
+
375
+ # concatenate all attribute embeddings
376
+ x_ = torch.cat(
377
+ [x_aabb, x_jtype, x_jaxis, x_range, x_label], dim=2
378
+ ) # (B, K, 5*attn_dim)
379
+ x_ = x_.view(B, self.K * self.hparams.mode_num, self.hparams.attn_dim)
380
+
381
+ # positional encoding for nodes and attributes
382
+ idx_attr = torch.tensor(
383
+ [0, 1, 2, 3, 4], device=x.device, dtype=torch.long
384
+ ).repeat(self.K)
385
+ idx_node = torch.arange(
386
+ self.K, device=x.device, dtype=torch.long
387
+ ).repeat_interleave(self.hparams.mode_num)
388
+ x_ = self.pe_attr(self.pe_node(x_, idx=idx_node), idx=idx_attr)
389
+
390
+
391
+ # total number of image patches across the two input views
392
+ Np = feat.shape[1]
393
+
394
+ img_first, img_second = feat.chunk(2, dim=1)
395
+ fused_img_feat = self.image_interaction(img_first, img_second) # (B, Np, D)
396
+
397
+ # attention layers
398
+ for i, attn_layer in enumerate(self.attn_layers):
399
+ x_ = attn_layer(
400
+ hidden_states=x_,
401
+ img_patches=feat,
402
+ fuse_feat=fused_img_feat,
403
+ timestep=timesteps,
404
+ class_labels=cat,
405
+ pad_mask=key_pad_mask,
406
+ graph_mask=graph_mask,
407
+ attr_mask=attr_mask,
408
+ label_free=label_free,
409
+ )
410
+
411
+ y = self.final_layer(x_, timesteps, cat)
412
+ return {
413
+ 'noise_pred': y,
414
+ 'attn_maps': None,
415
+ }
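As a reading aid, here is a minimal sketch of the per-part attribute packing that Denoiser.forward assumes (K parts, mode_num=5 attribute groups of 6 channels each). The tensor below is synthetic and the per-slice labels are read off the slicing in forward, so treat them as assumptions rather than a specification of the data format.

import torch

B, K, mode_num, ch = 2, 32, 5, 6
x = torch.randn(B, K, mode_num * ch)   # packed per-part attributes, as in x.view(B, K, 5 * 6)
x_aabb  = x[..., 0:6]      # bounding-box channels
x_jtype = x[..., 6:12]     # joint-type channels
x_jaxis = x[..., 12:18]    # joint-axis channels
x_range = x[..., 18:24]    # joint-range channels
x_label = x[..., 24:30]    # semantic-label channels
tokens = torch.stack([x_aabb, x_jtype, x_jaxis, x_range, x_label], dim=2)
tokens = tokens.reshape(B, K * mode_num, ch)   # one token per (part, attribute) pair
print(tokens.shape)                            # torch.Size([2, 160, 6])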
models/utils.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
5
+
6
+ class FinalLayer(nn.Module):
7
+ """
8
+ Final layer of the diffusion model that projects features back to the per-attribute output channels.
9
+ """
10
+ def __init__(self, in_ch, out_ch=None, dropout=0.0):
11
+ super().__init__()
12
+ out_ch = in_ch if out_ch is None else out_ch
13
+ self.linear = nn.Linear(in_ch, out_ch)
14
+ self.norm = AdaLayerNormTC(in_ch, 2 * in_ch, dropout)
15
+
16
+ def forward(self, x, t, cond=None):
17
+ assert cond is not None
18
+ x = self.norm(x, t, cond)
19
+ x = self.linear(x)
20
+ return x
21
+
22
+
23
+ class AdaLayerNormTC(nn.Module):
24
+ """
25
+ Norm layer modified to incorporate timestep and condition embeddings.
26
+ """
27
+
28
+ def __init__(self, embedding_dim, num_embeddings, dropout):
29
+ super().__init__()
30
+ self.emb = CombinedTimestepLabelEmbeddings(
31
+ num_embeddings, embedding_dim, dropout
32
+ )
33
+ self.silu = nn.SiLU()
34
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
35
+ self.norm = nn.LayerNorm(
36
+ embedding_dim, elementwise_affine=False, eps=torch.finfo(torch.float16).eps
37
+ )
38
+
39
+ def forward(self, x, timestep, cond):
40
+ emb = self.linear(self.silu(self.emb(timestep, cond, hidden_dtype=None)))
41
+ scale, shift = torch.chunk(emb, 2, dim=1)
42
+ x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
43
+ return x
44
+
45
+
46
+ class PEmbeder(nn.Module):
47
+ """
48
+ Positional embedding layer.
49
+ """
50
+ def __init__(self, vocab_size, d_model):
51
+ super().__init__()
52
+ self.embed = nn.Embedding(vocab_size, d_model)
53
+ self._init_embeddings()
54
+
55
+ def _init_embeddings(self):
56
+ nn.init.kaiming_normal_(self.embed.weight, mode="fan_in")
57
+
58
+ def forward(self, x, idx=None):
59
+ if idx is None:
60
+ idx = torch.arange(x.shape[1], device=x.device, dtype=torch.long)
61
+ return x + self.embed(idx)
62
+
63
+ class CombinedTimestepLabelEmbeddings(nn.Module):
64
+ '''Modified from diffusers.models.embeddings.CombinedTimestepLabelEmbeddings'''
65
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
66
+ super().__init__()
67
+
68
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
69
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
70
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
71
+
72
+ def forward(self, timestep, class_labels, hidden_dtype=None, label_free=False):
73
+ timesteps_proj = self.time_proj(timestep)
74
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
75
+ force_drop_ids = None # training mode
76
+ if label_free: # inference mode, force_drop_ids is set to all ones to be dropped in class_embedder
77
+ force_drop_ids = torch.ones_like(class_labels, dtype=torch.bool, device=class_labels.device)
78
+ class_labels = self.class_embedder(class_labels, force_drop_ids) # (N, D)
79
+ conditioning = timesteps_emb + class_labels # (N, D)
80
+ return conditioning
81
+
82
+
83
+ class MyAdaLayerNormZero(nn.Module):
84
+ """
85
+ Adaptive layer norm zero (adaLN-Zero), borrowed from diffusers.models.attention.AdaLayerNormZero.
86
+ Extended to incorporate scale parameters (gate_2, gate_3) for the intermediate attention layers.
87
+ """
88
+
89
+ def __init__(self, embedding_dim, num_embeddings, class_dropout_prob):
90
+ super().__init__()
91
+
92
+ self.emb = CombinedTimestepLabelEmbeddings(
93
+ num_embeddings, embedding_dim, class_dropout_prob
94
+ )
95
+ self.silu = nn.SiLU()
96
+ self.linear = nn.Linear(embedding_dim, 8 * embedding_dim, bias=True)
97
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
98
+
99
+ def forward(self, x, timestep, class_labels, hidden_dtype=None, label_free=False):
100
+ emb_t_cls = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype, label_free=label_free)
101
+ emb = self.linear(self.silu(emb_t_cls))
102
+ (
103
+ shift_msa,
104
+ scale_msa,
105
+ gate_msa,
106
+ shift_mlp,
107
+ scale_mlp,
108
+ gate_mlp,
109
+ gate_2,
110
+ gate_3,
111
+ ) = emb.chunk(8, dim=1)
112
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
113
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3
114
+
115
+
116
+ class VisAttnProcessor:
117
+ r"""
118
+ This code is adapted from diffusers.models.attention_processor.AttnProcessor.
119
+ Used for visualizing the attention maps when testing, NOT for training.
120
+ """
121
+
122
+ def __call__(
123
+ self,
124
+ attn,
125
+ hidden_states,
126
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
127
+ attention_mask: Optional[torch.FloatTensor] = None,
128
+ temb: Optional[torch.FloatTensor] = None,
129
+ *args,
130
+ **kwargs,
131
+ ) -> torch.Tensor:
132
+ # Removed
133
+ # if len(args) > 0 or kwargs.get("scale", None) is not None:
134
+ # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
135
+ # deprecate("scale", "1.0.0", deprecation_message)
136
+
137
+ residual = hidden_states
138
+
139
+ if attn.spatial_norm is not None:
140
+ hidden_states = attn.spatial_norm(hidden_states, temb)
141
+
142
+ input_ndim = hidden_states.ndim
143
+
144
+ if input_ndim == 4:
145
+ batch_size, channel, height, width = hidden_states.shape
146
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
147
+
148
+ batch_size, sequence_length, _ = (
149
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
150
+ )
151
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
152
+
153
+ if attn.group_norm is not None:
154
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
155
+
156
+ query = attn.to_q(hidden_states)
157
+
158
+ if encoder_hidden_states is None:
159
+ encoder_hidden_states = hidden_states
160
+ elif attn.norm_cross:
161
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
162
+
163
+ key = attn.to_k(encoder_hidden_states)
164
+ value = attn.to_v(encoder_hidden_states)
165
+
166
+ query = attn.head_to_batch_dim(query) # (40, 160, 16)
167
+ key = attn.head_to_batch_dim(key) # (40, 256, 16)
168
+ value = attn.head_to_batch_dim(value) # (40, 256, 16)
169
+
170
+ if attention_mask is not None:
171
+ if attention_mask.dtype == torch.bool:
172
+ attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype, device=query.device)
173
+ attn_mask = attn_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
174
+ else:
175
+ attn_mask = attention_mask
176
+ assert attn_mask.dtype == query.dtype, f"query and attention_mask must have the same dtype, but got {query.dtype} and {attention_mask.dtype}."
177
+ else:
178
+ attn_mask = None
179
+ attention_probs = attn.get_attention_scores(query, key, attn_mask) # (40, 160, 256)
180
+ hidden_states = torch.bmm(attention_probs, value) # (40, 160, 16)
181
+ hidden_states = attn.batch_to_head_dim(hidden_states)
182
+
183
+ # linear proj
184
+ hidden_states = attn.to_out[0](hidden_states)
185
+ # dropout
186
+ hidden_states = attn.to_out[1](hidden_states)
187
+
188
+ if input_ndim == 4:
189
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
190
+
191
+ if attn.residual_connection:
192
+ hidden_states = hidden_states + residual
193
+
194
+ hidden_states = hidden_states / attn.rescale_output_factor
195
+
196
+ attention_probs = attention_probs.reshape(batch_size, attn.heads, query.shape[1], sequence_length)
197
+
198
+ return hidden_states, attention_probs
199
+
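A small shape-check sketch for the helpers above. The dimensions and class count are illustrative assumptions, and it presumes diffusers is installed since CombinedTimestepLabelEmbeddings wraps its embedding modules.

import torch
from models.utils import PEmbeder, MyAdaLayerNormZero

B, K, mode_num, D = 2, 32, 5, 128              # illustrative sizes
x = torch.randn(B, K * mode_num, D)

# positional encodings over node index and attribute index, mirroring the denoiser's usage
pe_node = PEmbeder(K, D)
pe_attr = PEmbeder(mode_num, D)
idx_attr = torch.arange(mode_num).repeat(K)
idx_node = torch.arange(K).repeat_interleave(mode_num)
x = pe_attr(pe_node(x, idx=idx_node), idx=idx_attr)

# adaLN-Zero conditioning on timestep + category label
norm = MyAdaLayerNormZero(D, num_embeddings=7, class_dropout_prob=0.0)
t = torch.randint(0, 1000, (B,))
labels = torch.randint(0, 7, (B,))
x_norm, gate_1, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = norm(
    x, t, labels, hidden_dtype=x.dtype
)
print(x_norm.shape)                            # torch.Size([2, 160, 128])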
my_utils/__init__.py ADDED
File without changes
my_utils/callbacks.py ADDED
@@ -0,0 +1,36 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ import torch
4
+ from my_utils.misc import dump_config
5
+ from lightning.pytorch.callbacks.callback import Callback
6
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only
7
+
8
+ class ConfigSnapshotCallback(Callback):
9
+ def __init__(self, config):
10
+ super().__init__()
11
+ self.config = config
12
+
13
+ def setup(self, trainer, pl_module, stage) -> None:
14
+ self.savedir = os.path.join(pl_module.hparams.exp_dir, 'config')
15
+
16
+ @rank_zero_only
17
+ def save_config_snapshot(self):
18
+ os.makedirs(self.savedir, exist_ok=True)
19
+ dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config)
20
+
21
+ def on_fit_start(self, trainer, pl_module):
22
+ self.save_config_snapshot()
23
+
24
+
25
+ class GPUCacheCleanCallback(Callback):
26
+ def on_train_batch_start(self, *args, **kwargs):
27
+ torch.cuda.empty_cache()
28
+
29
+ def on_validation_batch_start(self, *args, **kwargs):
30
+ torch.cuda.empty_cache()
31
+
32
+ def on_test_batch_start(self, *args, **kwargs):
33
+ torch.cuda.empty_cache()
34
+
35
+ def on_predict_batch_start(self, *args, **kwargs):
36
+ torch.cuda.empty_cache()
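A hypothetical way these callbacks might be attached to a Lightning Trainer; the config path and trainer options are illustrative, and ConfigSnapshotCallback expects the LightningModule to expose hparams.exp_dir.

from omegaconf import OmegaConf
import lightning.pytorch as pl
from my_utils.callbacks import ConfigSnapshotCallback, GPUCacheCleanCallback

config = OmegaConf.load("configs/config.yaml")      # illustrative path
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[ConfigSnapshotCallback(config), GPUCacheCleanCallback()],
)
# trainer.fit(model, datamodule=dm)                 # model / datamodule defined elsewhere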
my_utils/lr_schedulers.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Code copied from lightning-bolts
3
+ """
4
+ import math
5
+ import warnings
6
+ from typing import List
7
+ from torch.optim import Optimizer
8
+ from torch.optim.lr_scheduler import _LRScheduler
9
+
10
+
11
+ class LinearWarmupCosineAnnealingLR(_LRScheduler):
12
+ """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and
13
+ base_lr followed by a cosine annealing schedule between base_lr and eta_min.
14
+
15
+ .. warning::
16
+ It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
17
+ after each iteration as calling it after each epoch will keep the starting lr at
18
+ warmup_start_lr for the first epoch which is 0 in most cases.
19
+
20
+ .. warning::
21
+ passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
22
+ It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
23
+ :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
24
+ epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
25
+ train and validation methods.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ optimizer: Optimizer,
31
+ warmup_epochs: int,
32
+ max_epochs: int,
33
+ warmup_start_lr: float = 0.0,
34
+ eta_min: float = 0.0,
35
+ last_epoch: int = -1,
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ optimizer (Optimizer): Wrapped optimizer.
40
+ warmup_epochs (int): Maximum number of iterations for linear warmup
41
+ max_epochs (int): Maximum number of iterations
42
+ warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
43
+ eta_min (float): Minimum learning rate. Default: 0.
44
+ last_epoch (int): The index of last epoch. Default: -1.
45
+ """
46
+ self.warmup_epochs = warmup_epochs
47
+ self.max_epochs = max_epochs
48
+ self.warmup_start_lr = warmup_start_lr
49
+ self.eta_min = eta_min
50
+
51
+ super().__init__(optimizer, last_epoch)
52
+
53
+ def get_lr(self) -> List[float]:
54
+ """Compute learning rate using chainable form of the scheduler."""
55
+ if not self._get_lr_called_within_step:
56
+ warnings.warn(
57
+ "To get the last learning rate computed by the scheduler; please use `get_last_lr()`.",
58
+ UserWarning,
59
+ )
60
+
61
+ if self.last_epoch == 0:
62
+ return [self.warmup_start_lr] * len(self.base_lrs)
63
+ if self.last_epoch < self.warmup_epochs:
64
+ return [
65
+ group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
66
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
67
+ ]
68
+ if self.last_epoch == self.warmup_epochs:
69
+ return self.base_lrs
70
+ if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
71
+ return [
72
+ group["lr"]
73
+ + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
74
+ for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
75
+ ]
76
+
77
+ return [
78
+ (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
79
+ / (
80
+ 1
81
+ + math.cos(
82
+ math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
83
+ )
84
+ )
85
+ * (group["lr"] - self.eta_min)
86
+ + self.eta_min
87
+ for group in self.optimizer.param_groups
88
+ ]
89
+
90
+ def _get_closed_form_lr(self) -> List[float]:
91
+ """Called when epoch is passed as a param to the `step` function of the scheduler."""
92
+ if self.last_epoch < self.warmup_epochs:
93
+ return [
94
+ self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
95
+ for base_lr in self.base_lrs
96
+ ]
97
+
98
+ return [
99
+ self.eta_min
100
+ + 0.5
101
+ * (base_lr - self.eta_min)
102
+ * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
103
+ for base_lr in self.base_lrs
104
+ ]
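A minimal usage sketch; as the docstring's warning notes, step() here is called once per "epoch" unit, so warmup_epochs/max_epochs should be given in the same unit. The optimizer and counts below are illustrative.

import torch
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer, warmup_epochs=10, max_epochs=100, warmup_start_lr=1e-6, eta_min=1e-6
)
for epoch in range(100):
    # ... run one epoch of optimizer.step() calls ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr()[0])   # linear warmup, then cosine decay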
my_utils/misc.py ADDED
@@ -0,0 +1,35 @@
1
+ from omegaconf import OmegaConf
2
+
3
+
4
+ # ============ Register OmegaConf Resolvers ============= #
5
+ OmegaConf.register_new_resolver('add', lambda a, b: a + b)
6
+ OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
7
+ OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
8
+ OmegaConf.register_new_resolver('div', lambda a, b: a / b)
9
+ # ======================================================= #
10
+
11
+
12
+ def prompt(question):
13
+ inp = input(f"{question} (y/n)").lower().strip()
14
+ if inp and inp == 'y':
15
+ return True
16
+ if inp and inp == 'n':
17
+ return False
18
+ return prompt(question)
19
+
20
+
21
+ def load_config(*yaml_files, cli_args=[]):
22
+ yaml_confs = [OmegaConf.load(f) for f in yaml_files]
23
+ cli_conf = OmegaConf.from_cli(cli_args)
24
+ conf = OmegaConf.merge(*yaml_confs, cli_conf)
25
+ OmegaConf.resolve(conf)
26
+ return conf
27
+
28
+
29
+ def config_to_primitive(config, resolve=True):
30
+ return OmegaConf.to_container(config, resolve=resolve)
31
+
32
+
33
+ def dump_config(path, config):
34
+ with open(path, 'w') as fp:
35
+ OmegaConf.save(config=config, f=fp)
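A hypothetical example of how load_config and the arithmetic resolvers registered above might be used; the YAML file and keys are illustrative, not the project's actual config schema.

from my_utils.misc import load_config, dump_config

# Suppose configs/example.yaml contains:
#   model:
#     attn_dim: 512
#     n_head: 8
#     head_dim: ${div:${model.attn_dim},${model.n_head}}
cfg = load_config("configs/example.yaml", cli_args=["model.n_head=16"])
print(cfg.model.head_dim)     # 32.0 -> the 'div' resolver applied after the CLI override
dump_config("parsed.yaml", cfg)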
my_utils/plot.py ADDED
@@ -0,0 +1,122 @@
1
+ # import os, sys
2
+ # sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ # import matplotlib
4
+ # matplotlib.use('Agg')
5
+ # import numpy as np
6
+ # import networkx as nx
7
+ # from io import BytesIO
8
+ # from PIL import Image, ImageDraw
9
+ # from matplotlib import pyplot as plt
10
+ # from sklearn.decomposition import PCA
11
+ # from my_utils.refs import graph_color_ref
12
+
13
+ # def add_text(text, imgarr):
14
+ # '''
15
+ # Function to add text to image
16
+
17
+ # Args:
18
+ # - text (str): text to add
19
+ # - imgarr (np.array): image array
20
+
21
+ # Returns:
22
+ # - img (np.array): image array with text
23
+ # '''
24
+ # img = Image.fromarray(imgarr)
25
+ # I = ImageDraw.Draw(img)
26
+ # I.text((10, 10), text, fill='black')
27
+ # return np.asarray(img)
28
+
29
+ # def get_color(ref, n_nodes):
30
+ # '''
31
+ # Function to color the nodes
32
+
33
+ # Args:
34
+ # - ref (list): list of color reference
35
+ # - n_nodes (int): number of nodes
36
+
37
+ # Returns:
38
+ # - colors (list): list of colors
39
+ # '''
40
+ # N = len(ref)
41
+ # colors = []
42
+ # for i in range(n_nodes):
43
+ # colors.append(np.array([[int(i) for i in ref[i%N][4:-1].split(',')]]) / 255.)
44
+ # return colors
45
+
46
+
47
+ # def make_grid(images, cols=5):
48
+ # """
49
+ # Arrange list of images into a N x cols grid.
50
+
51
+ # Args:
52
+ # - images (list): List of Numpy arrays representing the images.
53
+ # - cols (int): Number of columns for the grid.
54
+
55
+ # Returns:
56
+ # - grid (numpy array): Numpy array representing the image grid.
57
+ # """
58
+ # # Determine the dimensions of each image
59
+ # img_h, img_w, _ = images[0].shape
60
+ # rows = len(images) // cols
61
+
62
+ # # Initialize a blank canvas
63
+ # grid = np.zeros((rows * img_h, cols * img_w, 3), dtype=images[0].dtype)
64
+
65
+ # # Place each image onto the grid
66
+ # for idx, img in enumerate(images):
67
+ # y = (idx // cols) * img_h
68
+ # x = (idx % cols) * img_w
69
+ # grid[y: y + img_h, x: x + img_w] = img
70
+
71
+ # return grid
72
+
73
+ # def viz_graph(info_dict, res=256):
74
+ # '''
75
+ # Function to plot the directed graph
76
+
77
+ # Args:
78
+ # - info_dict (dict): output json containing the graph information
79
+ # - res (int): resolution of the image
80
+
81
+ # Returns:
82
+ # - img_arr (np.array): image array
83
+ # '''
84
+ # # build tree
85
+ # tree = info_dict['diffuse_tree']
86
+ # edges = []
87
+ # for node in tree:
88
+ # edges += [(node['id'], child) for child in node['children']]
89
+ # G = nx.DiGraph()
90
+ # G.add_edges_from(edges)
91
+
92
+ # # plot tree
93
+ # plt.figure(figsize=(res/100, res/100))
94
+
95
+ # colors = get_color(graph_color_ref, len(tree))
96
+ # pos = nx.nx_agraph.graphviz_layout(G, prog="twopi", args="")
97
+ # node_order = sorted(G.nodes())
98
+ # nx.draw(G, pos, node_color=colors, nodelist=node_order, edge_color='k', with_labels=False)
99
+
100
+ # buf = BytesIO()
101
+ # plt.savefig(buf, format="png", dpi=100)
102
+ # buf.seek(0)
103
+ # img = Image.open(buf)
104
+ # img_arr = np.asarray(img)
105
+ # buf.close()
106
+ # plt.clf()
107
+ # plt.close()
108
+ # return img_arr[:, :, :3]
109
+
110
+ # def viz_patch_feat_pca(feat):
111
+ # pca = PCA(n_components=3)
112
+ # pca.fit(feat)
113
+ # feat_pca = pca.transform(feat)
114
+
115
+ # t = np.array(feat_pca)
116
+ # t_min = t.min(axis=0, keepdims=True)
117
+ # t_max = t.max(axis=0, keepdims=True)
118
+ # normalized_t = (t - t_min) / (t_max - t_min)
119
+
120
+ # array = (normalized_t * 255).astype(np.uint8)
121
+ # img_array = array.reshape(16, 16, 3)
122
+ # return img_array
my_utils/refs.py ADDED
@@ -0,0 +1,122 @@
1
+ # reference of object categories
2
+ cat_ref = {
3
+ "Table": 0,
4
+ "Dishwasher": 1,
5
+ "StorageFurniture": 2,
6
+ "Refrigerator": 3,
7
+ "WashingMachine": 4,
8
+ "Microwave": 5,
9
+ "Oven": 6,
10
+ }
11
+
12
+ data_mode_ref = {
13
+ "aabb_max": 0,
14
+ "aabb_min": 1,
15
+ "joint_type": 2,
16
+ "axis_dir": 3,
17
+ "axis_ori": 4,
18
+ "joint_range": 5,
19
+ "label": 6
20
+ }
21
+
22
+ # reference of semantic labels for each part
23
+ sem_ref = {
24
+ "fwd": {
25
+ "door": 0,
26
+ "drawer": 1,
27
+ "base": 2,
28
+ "handle": 3,
29
+ "wheel": 4,
30
+ "knob": 5,
31
+ "shelf": 6,
32
+ "tray": 7,
33
+ },
34
+ "bwd": {
35
+ 0: "door",
36
+ 1: "drawer",
37
+ 2: "base",
38
+ 3: "handle",
39
+ 4: "wheel",
40
+ 5: "knob",
41
+ 6: "shelf",
42
+ 7: "tray",
43
+ },
44
+ }
45
+
46
+ # reference of joint types for each part
47
+ joint_ref = {
48
+ "fwd": {"fixed": 1, "revolute": 2, "prismatic": 3, "screw": 4, "continuous": 5},
49
+ "bwd": {1: "fixed", 2: "revolute", 3: "prismatic", 4: "screw", 5: "continuous"},
50
+ }
51
+
52
+
53
+ import plotly.express as px
54
+
55
+ # pallette for joint type color
56
+ joint_color_ref = px.colors.qualitative.Set1
57
+ # pallette for graph node color
58
+ # graph_color_ref = px.colors.qualitative.Bold + px.colors.qualitative.Prism
59
+ # graph_color_ref = [
60
+ # "rgb(200, 200, 200)", # 奶橙黄
61
+ # "rgb(255, 196, 200)", # 莓奶粉
62
+ # "rgb(154, 228, 186)", # 牛油果绿
63
+ # "rgb(252, 208, 140)", # 奶橙黄
64
+ # "rgb(217, 189, 250)", # 薄紫
65
+ # "rgb(203, 237, 164)", # 抹茶绿
66
+ # "rgb(188, 229, 235)", # 青蓝灰
67
+ # "rgb(179, 199, 243)", # 雾蓝
68
+ # "rgb(255, 224, 130)", # 淡柠黄
69
+ # "rgb(222, 179, 212)", # 粉紫
70
+ # "rgb(148, 212, 224)", # 冰蓝
71
+ # ]
72
+ graph_color_ref = [
73
+ "rgb(160, 160, 160)", # 奶橙灰 → 深灰白,对比提升
74
+ "rgb(255, 130, 145)", # 莓奶粉 → 更亮更红
75
+ "rgb(80, 200, 150)", # 牛油果绿 → 更深更绿
76
+ "rgb(255, 180, 60)", # 奶橙黄 → 更橙更亮
77
+ "rgb(180, 140, 255)", # 薄紫 → 更强饱和度紫
78
+ "rgb(130, 210, 50)", # 抹茶绿 → 偏亮偏黄的绿
79
+ "rgb(90, 190, 220)", # 青蓝灰 → 加蓝提升对比
80
+ "rgb(100, 150, 255)", # 雾蓝 → 饱和冷蓝
81
+ "rgb(255, 200, 0)", # 淡柠黄 → 纯柠黄
82
+ "rgb(200, 100, 190)", # 粉紫 → 更紫
83
+ "rgb(80, 180, 255)", # 冰蓝 → 更冷更亮的蓝
84
+ "rgb(255, 130, 145)", # 莓奶粉 → 更亮更红
85
+ "rgb(80, 200, 150)", # 牛油果绿 → 更深更绿
86
+ "rgb(255, 180, 60)", # 奶橙黄 → 更橙更亮
87
+ "rgb(180, 140, 255)", # 薄紫 → 更强饱和度紫
88
+ "rgb(130, 210, 50)", # 抹茶绿 → 偏亮偏黄的绿
89
+ "rgb(90, 190, 220)", # 青蓝灰 → 加蓝提升对比
90
+ "rgb(100, 150, 255)", # 雾蓝 → 饱和冷蓝
91
+ "rgb(255, 200, 0)", # 淡柠黄 → 纯柠黄
92
+ "rgb(200, 100, 190)", # 粉紫 → 更紫
93
+ "rgb(80, 180, 255)", # 冰蓝 → 更冷更亮的蓝
94
+ "rgb(255, 130, 145)", # 莓奶粉 → 更亮更红
95
+ "rgb(80, 200, 150)", # 牛油果绿 → 更深更绿
96
+ "rgb(255, 180, 60)", # 奶橙黄 → 更橙更亮
97
+ "rgb(180, 140, 255)", # 薄紫 → 更强饱和度紫
98
+ "rgb(130, 210, 50)", # 抹茶绿 → 偏亮偏黄的绿
99
+ "rgb(90, 190, 220)", # 青蓝灰 → 加蓝提升对比
100
+ "rgb(100, 150, 255)", # 雾蓝 → 饱和冷蓝
101
+ "rgb(255, 200, 0)", # 淡柠黄 → 纯柠黄
102
+ "rgb(200, 100, 190)", # 粉紫 → 更紫
103
+ "rgb(80, 180, 255)", # 冰蓝 → 更冷更亮的蓝
104
+ "rgb(255, 130, 145)", # 莓奶粉 → 更亮更红
105
+ "rgb(80, 200, 150)", # 牛油果绿 → 更深更绿
106
+ "rgb(255, 180, 60)", # 奶橙黄 → 更橙更亮
107
+ "rgb(180, 140, 255)", # 薄紫 → 更强饱和度紫
108
+ "rgb(130, 210, 50)", # 抹茶绿 → 偏亮偏黄的绿
109
+ "rgb(90, 190, 220)", # 青蓝灰 → 加蓝提升对比
110
+ "rgb(100, 150, 255)", # 雾蓝 → 饱和冷蓝
111
+ "rgb(255, 200, 0)", # 淡柠黄 → 纯柠黄
112
+ "rgb(200, 100, 190)", # 粉紫 → 更紫
113
+ "rgb(80, 180, 255)", # 冰蓝 → 更冷更亮的蓝
114
+ ]
115
+ # pallette for semantic label color
116
+ semantic_color_ref = px.colors.qualitative.Vivid_r
117
+ # attention map visulaization color
118
+ attn_color_ref = px.colors.sequential.Viridis
119
+
120
+ from matplotlib.colors import LinearSegmentedColormap
121
+
122
+ cmap_attn = LinearSegmentedColormap.from_list("mycmap", attn_color_ref, N=256)
my_utils/render.py ADDED
@@ -0,0 +1,482 @@
 
1
+ # import os, sys
2
+ # sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
3
+ # import trimesh
4
+ # import pyrender
5
+ # import numpy as np
6
+ # # import open3d as o3d
7
+ # from copy import deepcopy
8
+ # os.environ['PYOPENGL_PLATFORM'] = 'egl'
9
+ # from my_utils.refs import semantic_color_ref, graph_color_ref, joint_color_ref
10
+
11
+ # def get_rotation_axis_angle(k, theta):
12
+ # '''
13
+ # Rotation matrix converter from axis-angle using Rodrigues' rotation formula
14
+
15
+ # Args:
16
+ # k (np.ndarray): 3D unit vector representing the axis to rotate about.
17
+ # theta (float): Angle to rotate with in radians.
18
+
19
+ # Returns:
20
+ # R (np.ndarray): 3x3 rotation matrix.
21
+ # '''
22
+ # if np.linalg.norm(k) == 0.:
23
+ # return np.eye(3)
24
+ # k = k / np.linalg.norm(k)
25
+ # kx, ky, kz = k[0], k[1], k[2]
26
+ # cos, sin = np.cos(theta), np.sin(theta)
27
+ # R = np.zeros((3, 3), dtype=np.float32)
28
+ # R[0, 0] = cos + (kx**2) * (1 - cos)
29
+ # R[0, 1] = kx * ky * (1 - cos) - kz * sin
30
+ # R[0, 2] = kx * kz * (1 - cos) + ky * sin
31
+ # R[1, 0] = kx * ky * (1 - cos) + kz * sin
32
+ # R[1, 1] = cos + (ky**2) * (1 - cos)
33
+ # R[1, 2] = ky * kz * (1 - cos) - kx * sin
34
+ # R[2, 0] = kx * kz * (1 - cos) - ky * sin
35
+ # R[2, 1] = ky * kz * (1 - cos) + kx * sin
36
+ # R[2, 2] = cos + (kz**2) * (1 - cos)
37
+ # return R
38
+
39
+ # def rescale_axis(jtype, axis_d, axis_o, box_center):
40
+ # '''
41
+ # Function to rescale the axis for rendering
42
+
43
+ # Args:
44
+ # - jtype (int): joint type
45
+ # - axis_d (np.array): axis direction
46
+ # - axis_o (np.array): axis origin
47
+ # - box_center (np.array): bounding box center
48
+
49
+ # Returns:
50
+ # - center (np.array): rescaled axis origin
51
+ # - axis_d (np.array): rescaled axis direction
52
+ # '''
53
+ # if jtype == 0 or jtype == 1:
54
+ # return [0., 0., 0.], [0., 0., 0.]
55
+ # if jtype == 3 or jtype == 4:
56
+ # center = box_center
57
+ # else:
58
+ # center = axis_o + np.dot(axis_d, box_center-axis_o) * axis_d
59
+ # return center.tolist(), axis_d.tolist()
60
+
61
+ # # def get_axis_mesh(k, axis_o, bbox_center, joint_type):
62
+ # # '''
63
+ # # Function to get the axis mesh
64
+
65
+ # # Args:
66
+ # # - k (np.array): axis direction
67
+ # # - center (np.array): axis origin
68
+ # # - bbox_center (np.array): bounding box center
69
+ # # - joint_type (int): joint type
70
+ # # '''
71
+ # # if joint_type == 0 or joint_type == 1 or np.linalg.norm(k) == 0. :
72
+ # # return None
73
+
74
+ # # k = k / np.linalg.norm(k)
75
+
76
+ # # if joint_type == 3 or joint_type == 4: # prismatic or screw
77
+ # # axis_o = bbox_center
78
+ # # else: # revolute or continuous
79
+ # # axis_o = axis_o + np.dot(k, bbox_center-axis_o) * k
80
+ # # axis = o3d.geometry.TriangleMesh.create_arrow(cylinder_radius=0.015, cone_radius=0.03, cylinder_height=1.0, cone_height=0.08)
81
+ # # arrow = np.array([0., 0., 1.], dtype=np.float32)
82
+ # # n = np.cross(arrow, k)
83
+ # # rad = np.arccos(np.dot(arrow, k))
84
+ # # R_arrow = get_rotation_axis_angle(n, rad)
85
+ # # axis.rotate(R_arrow, center=(0, 0, 0))
86
+ # # axis.translate(axis_o[:3])
87
+ # # axis.compute_vertex_normals()
88
+ # # vertices = np.asarray(axis.vertices)
89
+ # # faces = np.asarray(axis.triangles)
90
+ # # trimesh_axis = trimesh.Trimesh(vertices=vertices, faces=faces)
91
+ # # # trimesh_axis.visual.vertex_colors = np.array([0, 0, 0, 1.0], dtype=np.float32)
92
+ # # trimesh_axis.visual.vertex_colors = np.repeat(np.array([0, 0, 0, 1.0]), vertices.shape[0], axis=0)
93
+ # # return trimesh_axis
94
+
95
+ # def get_camera_pose(eye, look_at, up):
96
+ # """
97
+ # Compute the 4x4 transformation matrix for a camera pose.
98
+
99
+ # Parameters:
100
+ # eye (np.ndarray): 3D position of the camera.
101
+ # look_at (np.ndarray): 3D point the camera is looking at.
102
+ # up (np.ndarray): Up vector.
103
+
104
+ # Returns:
105
+ # pose (np.ndarray): 4x4 transformation matrix representing the camera pose.
106
+ # """
107
+ # # Compute the forward, right, and new up vectors
108
+ # forward = (look_at - eye)
109
+ # forward = forward / np.linalg.norm(forward)
110
+
111
+ # right = np.cross(forward, up)
112
+ # right = right / np.linalg.norm(right)
113
+
114
+ # new_up = np.cross(right, forward)
115
+ # new_up = new_up / np.linalg.norm(new_up)
116
+
117
+ # # Create rotation matrix
118
+ # pose = np.eye(4)
119
+ # pose[0:3, 0] = right
120
+ # pose[0:3, 1] = new_up
121
+ # pose[0:3, 2] = -forward # Negative because the camera looks along the negative Z axis in its local coordinate
122
+ # pose[0:3, 3] = eye
123
+
124
+ # return pose
125
+
126
+ # def get_rotation_axis_angle_box(axis, angle):
127
+ # axis = axis / np.linalg.norm(axis)
128
+ # return trimesh.transformations.rotation_matrix(angle, axis)
129
+
130
+ # def get_colored_box(center, size, jtype=None, jrange=None, axis_d=None, axis_o=None):
131
+ # '''
132
+ # Create a solid color box and its animated state if joint info is provided
133
+
134
+ # Args:
135
+ # center (np.array): box center (3,)
136
+ # size (np.array): box size (3,)
137
+ # color (list or array): RGBA color, e.g. [255, 0, 0, 255]
138
+ # jtype (int): joint type (2=rot, 3=slide, 4=screw, 5=continuous)
139
+ # jrange (list): joint motion range
140
+ # axis_d (np.array): axis direction (3,)
141
+ # axis_o (np.array): axis origin (3,)
142
+
143
+ # Returns:
144
+ # box: trimesh.Trimesh at rest
145
+ # box_anim: trimesh.Trimesh after transformation
146
+ # '''
147
+ # size = np.clip(size, a_min=0.005, a_max=3.0)
148
+ # center = np.clip(center, a_min=-3.0, a_max=3.0)
149
+
150
+ # # Rest state box
151
+ # box = trimesh.creation.box(extents=size)
152
+ # box.apply_translation(center)
153
+
154
+ # # Animated state (deepcopy + transform)
155
+ # box_anim = deepcopy(box)
156
+
157
+ # if jtype is not None:
158
+ # if jtype == 2: # revolute
159
+ # theta = np.deg2rad(jrange[1])
160
+ # T = trimesh.transformations.translation_matrix(axis_o)
161
+ # R_3 = get_rotation_axis_angle(axis_d, theta)
162
+ # R = np.eye(4, dtype=np.float32)
163
+ # R[:3, :3] = R_3
164
+ # T_inv = trimesh.transformations.translation_matrix(-axis_o)
165
+ # box_anim.apply_transform(T @ R @ T_inv)
166
+
167
+ # elif jtype == 3: # prismatic
168
+ # dist = float(jrange[1])
169
+ # T = trimesh.transformations.translation_matrix(axis_d * dist)
170
+ # box_anim.apply_transform(T)
171
+
172
+ # elif jtype == 4: # screw
173
+ # theta = np.pi / 4
174
+ # dist = float(jrange[1])
175
+ # T1 = trimesh.transformations.translation_matrix(-axis_o)
176
+ # R = get_rotation_axis_angle(axis_d, theta)
177
+ # T2 = trimesh.transformations.translation_matrix(axis_o + axis_d * dist)
178
+ # box_anim.apply_transform(T1 @ R @ T2)
179
+
180
+ # elif jtype == 5: # continuous
181
+ # theta = np.pi / 4
182
+ # T = trimesh.transformations.translation_matrix(-axis_o)
183
+ # R_3 = get_rotation_axis_angle(axis_d, theta)
184
+ # R = np.eye(4, dtype=np.float32)
185
+ # R[:3, :3] = R_3
186
+ # T_inv = trimesh.transformations.translation_matrix(axis_o)
187
+ # box_anim.apply_transform(T @ R @ T_inv)
188
+
189
+ # return box, box_anim
190
+
191
+ # # def get_bbox_mesh_pair(center, size, radius=0.01, jtype=None, jrange=None, axis_d=None, axis_o=None):
192
+ # # '''
193
+ # # Function to get the bounding box mesh pair
194
+
195
+ # # Args:
196
+ # # - center (np.array): bounding box center
197
+ # # - size (np.array): bounding box size
198
+ # # - radius (float): radius of the cylinder
199
+ # # - jtype (int): joint type
200
+ # # - jrange (list): joint range
201
+ # # - axis_d (np.array): axis direction
202
+ # # - axis_o (np.array): axis origin
203
+
204
+ # # Returns:
205
+ # # - trimesh_box (trimesh object): trimesh object for the bbox at resting state
206
+ # # - trimesh_box_anim (trimesh object): trimesh object for the bbox at opening state
207
+ # # '''
208
+
209
+ # # size = np.clip(size, a_max=3, a_min=0.005)
210
+ # # center = np.clip(center, a_max=3, a_min=-3)
211
+
212
+ # # line_box = o3d.geometry.TriangleMesh()
213
+ # # z_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[2])
214
+ # # y_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[1])
215
+ # # R_y = get_rotation_axis_angle(np.array([1., 0., 0.], dtype=np.float32), np.pi / 2)
216
+ # # y_cylinder.rotate(R_y, center=(0, 0, 0))
217
+ # # x_cylinder = o3d.geometry.TriangleMesh.create_cylinder(radius=radius, height=size[0])
218
+ # # R_x = get_rotation_axis_angle(np.array([0., 1., 0.], dtype=np.float32), np.pi / 2)
219
+ # # x_cylinder.rotate(R_x, center=(0, 0, 0))
220
+
221
+
222
+ # # z1 = deepcopy(z_cylinder)
223
+ # # z1.translate(np.array([-size[0] / 2, size[1] / 2, 0.], dtype=np.float32))
224
+ # # line_box += z1.translate(center[:3])
225
+ # # z2 = deepcopy(z_cylinder)
226
+ # # z2.translate(np.array([size[0] / 2, size[1] / 2, 0.], dtype=np.float32))
227
+ # # line_box += z2.translate(center[:3])
228
+ # # z3 = deepcopy(z_cylinder)
229
+ # # z3.translate(np.array([-size[0] / 2, -size[1] / 2, 0.], dtype=np.float32))
230
+ # # line_box += z3.translate(center[:3])
231
+ # # z4 = deepcopy(z_cylinder)
232
+ # # z4.translate(np.array([size[0] / 2, -size[1] / 2, 0.], dtype=np.float32))
233
+ # # line_box += z4.translate(center[:3])
234
+
235
+ # # y1 = deepcopy(y_cylinder)
236
+ # # y1.translate(np.array([-size[0] / 2, 0., size[2] / 2], dtype=np.float32))
237
+ # # line_box += y1.translate(center[:3])
238
+ # # y2 = deepcopy(y_cylinder)
239
+ # # y2.translate(np.array([size[0] / 2, 0., size[2] / 2], dtype=np.float32))
240
+ # # line_box += y2.translate(center[:3])
241
+ # # y3 = deepcopy(y_cylinder)
242
+ # # y3.translate(np.array([-size[0] / 2, 0., -size[2] / 2], dtype=np.float32))
243
+ # # line_box += y3.translate(center[:3])
244
+ # # y4 = deepcopy(y_cylinder)
245
+ # # y4.translate(np.array([size[0] / 2, 0., -size[2] / 2], dtype=np.float32))
246
+ # # line_box += y4.translate(center[:3])
247
+
248
+ # # x1 = deepcopy(x_cylinder)
249
+ # # x1.translate(np.array([0., -size[1] / 2, size[2] / 2], dtype=np.float32))
250
+ # # line_box += x1.translate(center[:3])
251
+ # # x2 = deepcopy(x_cylinder)
252
+ # # x2.translate(np.array([0., size[1] / 2, size[2] / 2], dtype=np.float32))
253
+ # # line_box += x2.translate(center[:3])
254
+ # # x3 = deepcopy(x_cylinder)
255
+ # # x3.translate(np.array([0., -size[1] / 2, -size[2] / 2], dtype=np.float32))
256
+ # # line_box += x3.translate(center[:3])
257
+ # # x4 = deepcopy(x_cylinder)
258
+ # # x4.translate(np.array([0., size[1] / 2, -size[2] / 2]))
259
+ # # line_box += x4.translate(center[:3])
260
+
261
+ # # # transform
262
+ # # line_box_anim = deepcopy(line_box)
263
+ # # if jtype == 2: # revolute
264
+ # # theta = np.deg2rad(jrange[1])
265
+ # # line_box_anim.translate(-axis_o)
266
+ # # R = get_rotation_axis_angle(axis_d, theta)
267
+ # # line_box_anim.rotate(R, center=(0, 0, 0))
268
+ # # line_box_anim.translate(axis_o)
269
+ # # elif jtype == 3: # prismatic
270
+ # # dist = np.array(jrange[1], dtype=np.float32)
271
+ # # line_box_anim.translate(axis_d * dist)
272
+ # # elif jtype == 4: # screw
273
+ # # dist = np.array(jrange[1], dtype=np.float32)
274
+ # # theta = 0.25 * np.pi
275
+ # # R = get_rotation_axis_angle(axis_d, theta)
276
+ # # line_box_anim.translate(-axis_o)
277
+ # # line_box_anim.rotate(R, center=(0, 0, 0))
278
+ # # line_box_anim.translate(axis_o)
279
+ # # line_box_anim.translate(axis_d * dist)
280
+ # # elif jtype == 5: # continuous
281
+ # # theta = 0.25 * np.pi
282
+ # # R = get_rotation_axis_angle(axis_d, theta)
283
+ # # line_box_anim.translate(-axis_o)
284
+ # # line_box_anim.rotate(R, center=(0, 0, 0))
285
+ # # line_box_anim.translate(axis_o)
286
+
287
+ # # vertices = np.asarray(line_box.vertices)
288
+ # # faces = np.asarray(line_box.triangles)
289
+ # # trimesh_box = trimesh.Trimesh(vertices=vertices, faces=faces)
290
+ # # trimesh_box.visual.vertex_colors = np.array([0.0, 1.0, 1.0, 1.0], dtype=np.float32)
291
+
292
+ # # vertices_anim = np.asarray(line_box_anim.vertices)
293
+ # # faces_anim = np.asarray(line_box_anim.triangles)
294
+ # # trimesh_box_anim = trimesh.Trimesh(vertices=vertices_anim, faces=faces_anim)
295
+ # # trimesh_box_anim.visual.vertex_colors = np.array([0.0, 1.0, 1.0, 1.0], dtype=np.float32)
296
+
297
+ # # return trimesh_box, trimesh_box_anim
298
+
299
+
300
+ # def get_color_from_palette(palette, idx):
301
+ # '''
302
+ # Function to get the color from the palette
303
+
304
+ # Args:
305
+ # - palette (list): list of color reference
306
+ # - idx (int): index of the color
307
+
308
+ # Returns:
309
+ # - color (np.array): color in the index of idx
310
+ # '''
311
+ # ref = palette[idx % len(palette)]
312
+ # ref_list = [int(i) for i in ref[4:-1].split(',')]
313
+ # if idx % len(palette) == 0:
314
+ # ref_list.append(120)
315
+ # else:
316
+ # ref_list.append(255)
317
+ # color = np.array([ref_list], dtype=np.float32) / 255.
318
+ # return color
319
+
320
+
321
+
322
+ # def render_anim_parts(aabbs, axiss, resolution=256):
323
+ # '''
324
+ # Function to render the 3D bounding boxes and axes in the scene
325
+
326
+ # Args:
327
+ # aabbs: list of trimesh objects for the bounding box of each part
328
+ # axiss: list of trimesh objects for the axis of each part
329
+ # resolution: resolution of the rendered image
330
+
331
+ # Returns:
332
+ # color_img: rendered image
333
+ # '''
334
+ # n_parts = len(aabbs)
335
+ # # build mesh for each 3D bounding box
336
+ # scene = pyrender.Scene()
337
+ # for i in range(n_parts):
338
+ # scene.add(aabbs[i])
339
+ # if axiss[i] is not None:
340
+ # scene.add(axiss[i])
341
+
342
+ # # Add light to the scene
343
+ # scene.ambient_light = np.full(shape=3, fill_value=1.5, dtype=np.float32)
344
+ # light = pyrender.DirectionalLight(color=np.ones(2), intensity=5.0)
345
+
346
+ # # Add camera to the scene
347
+ # pose = get_camera_pose(eye=np.array([1.5, 1.2, 4.5]), look_at=np.array([0, 0, 0]), up=np.array([0, 1, 0]))
348
+ # camera = pyrender.PerspectiveCamera(yfov=np.pi / 5.0, aspectRatio=1.0)
349
+ # scene.add(light, pose=pose)
350
+ # scene.add(camera, pose=pose)
351
+
352
+ # # Offscreen Rendering
353
+ # offscreen_renderer = pyrender.OffscreenRenderer(resolution, resolution)
354
+
355
+ # # Render the scene
356
+ # color_img, _ = offscreen_renderer.render(scene)
357
+
358
+ # # Cleanup
359
+ # offscreen_renderer.delete()
360
+ # scene.clear()
361
+ # return color_img
362
+
363
+
364
+ # def draw_boxes_axiss_anim(aabbs_0, aabbs_1, axiss, mode='graph', resolution=256, types=None):
365
+ # '''
366
+ # Function to draw the 3D bounding boxes and axes of the two frames
367
+
368
+ # Args:
369
+ # aabbs_0: list of trimesh objects for the bounding box of each part in the resting state
370
+ # aabbs_1: list of trimesh objects for the bounding box of each part in the open state
371
+ # axiss: list of trimesh objects for the axis of each part
372
+ # mode:
373
+ # 'graph' using palette corresponding to graph node,
374
+ # 'jtype' using palette corresponding to joint type,
375
+ # 'semantic' using palette corresponding to semantic label
376
+ # resolution: resolution of the rendered image
377
+ # types: ids corresponding to each joint type or semantic label, if mode is 'jtype' or 'semantic'
378
+ # '''
379
+ # n_parts = len(aabbs_0)
380
+ # ren_aabbs_0 = []
381
+ # ren_aabbs_1 = []
382
+ # ren_axiss = []
383
+ # if mode == 'graph':
384
+ # palette = graph_color_ref
385
+ # # Add meshes to the scene
386
+ # for i in range(n_parts):
387
+ # color = get_color_from_palette(palette, i)
388
+ # aabb_0 = pyrender.Mesh.from_trimesh(aabbs_0[i], smooth=False)
389
+ # aabb_0.primitives[0].color_0 = color.repeat(aabb_0.primitives[0].positions.shape[0], axis=0)
390
+ # ren_aabbs_0.append(aabb_0)
391
+ # aabb_1 = pyrender.Mesh.from_trimesh(aabbs_1[i], smooth=False)
392
+ # aabb_1.primitives[0].color_0 = color.repeat(aabb_1.primitives[0].positions.shape[0], axis=0)
393
+ # ren_aabbs_1.append(aabb_1)
394
+ # if axiss[i] is not None:
395
+ # axis = pyrender.Mesh.from_trimesh(axiss[i], smooth=False)
396
+ # axis.primitives[0].color_0 = color.repeat(axis.primitives[0].positions.shape[0], axis=0)
397
+ # ren_axiss.append(axis)
398
+ # else:
399
+ # ren_axiss.append(None)
400
+ # elif mode == 'jtype' or mode == 'semantic':
401
+ # assert types is not None
402
+ # palette = joint_color_ref if mode == 'jtype' else semantic_color_ref
403
+ # # Add meshes to the scene
404
+ # for i in range(n_parts):
405
+ # color = get_color_from_palette(palette, types[i])
406
+ # aabb_0 = pyrender.Mesh.from_trimesh(aabbs_0[i], smooth=False)
407
+ # aabb_0.primitives[0].color_0 = color.repeat(aabb_0.primitives[0].positions.shape[0], axis=0)
408
+ # ren_aabbs_0.append(aabb_0)
409
+ # aabb_1 = pyrender.Mesh.from_trimesh(aabbs_1[i], smooth=False)
410
+ # aabb_1.primitives[0].color_0 = color.repeat(aabb_1.primitives[0].positions.shape[0], axis=0)
411
+ # ren_aabbs_1.append(aabb_1)
412
+
413
+ # if axiss[i] is not None:
414
+ # axis = pyrender.Mesh.from_trimesh(axiss[i], smooth=False)
415
+ # ren_axiss.append(axis)
416
+ # else:
417
+ # ren_axiss.append(None)
418
+ # else:
419
+ # raise ValueError('mode must be either graph or type')
420
+
421
+ # img0 = render_anim_parts(ren_aabbs_0, ren_axiss, resolution=resolution)
422
+ # img1 = render_anim_parts(ren_aabbs_1, ren_axiss, resolution=resolution)
423
+ # return np.concatenate([img0, img1], axis=1)
424
+
425
+ # def prepare_meshes(info_dict):
426
+ # """
427
+ # Function to prepare the bbox and axis meshes for visualization
428
+
429
+ # Args:
430
+ # - info_dict (dict): output json containing the graph information
431
+ # """
432
+ # from my_utils.refs import joint_ref, sem_ref
433
+ # tree = info_dict["diffuse_tree"]
434
+ # bbox_0, bbox_1, axiss, labels, jtypes = [], [], [], [], []
435
+ # root_id = 0
436
+ # # get root id
437
+ # for node in tree:
438
+ # if node["parent"] == -1:
439
+ # root_id = node["id"]
440
+ # for node in tree:
441
+ # # retrieve info
442
+ # box_cen = np.array(node["aabb"]["center"], dtype=np.float32)
443
+ # box_size = np.array(node["aabb"]["size"], dtype=np.float32)
444
+ # axis_d = np.array(node["joint"]["axis"]["direction"], dtype=np.float32)
445
+ # axis_o = np.array(node["joint"]["axis"]["origin"], dtype=np.float32)
446
+ # jtype = joint_ref["fwd"][node["joint"]["type"]]
447
+ # # construct meshes for bbox in two states (closed and fully open)
448
+ # if node["id"] == root_id or node["parent"] == root_id: # use the joint info directly
449
+ # bb_0, bb_1 = get_colored_box(
450
+ # box_cen,
451
+ # box_size,
452
+ # jtype=jtype,
453
+ # jrange= node["joint"]["range"],
454
+ # axis_d=axis_d,
455
+ # axis_o=axis_o,
456
+ # )
457
+ # else: # use the parent joint info
458
+ # parent_id = node["parent"]
459
+ # bb_0, bb_1 = get_colored_box(
460
+ # box_cen,
461
+ # box_size,
462
+ # jtype=joint_ref["fwd"][tree[parent_id]["joint"]["type"]],
463
+ # jrange=tree[parent_id]["joint"]["range"],
464
+ # axis_d=np.array(tree[parent_id]["joint"]["axis"]["direction"], dtype=np.float32),
465
+ # axis_o=np.array(tree[parent_id]["joint"]["axis"]["origin"], dtype=np.float32),
466
+ # )
467
+ # # construct mesh for joint axis
468
+ # axis_mesh = get_axis_mesh(axis_d, axis_o, box_cen, node["joint"]["type"])
469
+ # # append
470
+ # bbox_0.append(bb_0)
471
+ # bbox_1.append(bb_1)
472
+ # axiss.append(axis_mesh)
473
+ # labels.append(sem_ref["fwd"][node["name"]])
474
+ # jtypes.append(jtype)
475
+
476
+ # return {
477
+ # "bbox_0": bbox_0,
478
+ # "bbox_1": bbox_1,
479
+ # "axiss": axiss,
480
+ # "labels": labels,
481
+ # "jtypes": jtypes,
482
+ # }
my_utils/savermixins.py ADDED
@@ -0,0 +1,55 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import imageio
5
+ import numpy as np
6
+
7
+ class SaverMixin():
8
+
9
+ @property
10
+ def save_dir(self):
11
+ return self.hparams.save_dir
12
+
13
+ def convert_format(self, data):
14
+ if isinstance(data, np.ndarray):
15
+ return data
16
+ elif isinstance(data, torch.Tensor):
17
+ return data.cpu().numpy()
18
+ elif isinstance(data, list):
19
+ return [self.convert_format(d) for d in data]
20
+ elif isinstance(data, dict):
21
+ return {k: self.convert_format(v) for k, v in data.items()}
22
+ else:
23
+ raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data))
24
+
25
+ def get_save_path(self, filename):
26
+ save_path = os.path.join(self.save_dir, filename)
27
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
28
+ return save_path
29
+
30
+ def save_rgb_image(self, filename, img):
31
+ imageio.imwrite(self.get_save_path(filename), img)
32
+
33
+ def save_rgb_video(self, filename, stage='fit', filter=None):
34
+ img_dir = os.path.join(self.logger.log_dir, 'images', stage)
35
+
36
+ writer_graph = imageio.get_writer(os.path.join(img_dir, filename), fps=1)
37
+
38
+ for file in sorted(os.listdir(img_dir)):
39
+ if file.endswith('.png') and 'gt' not in file:
40
+ if filter is not None:
41
+ if filter in file:
42
+ writer_graph.append_data(imageio.imread(os.path.join(img_dir, file)))
43
+ else:
44
+ writer_graph.append_data(imageio.imread(os.path.join(img_dir, file)))
45
+
46
+ writer_graph.close()
47
+
48
+
49
+
50
+ def save_json(self, filename, data):
51
+ save_path = self.get_save_path(filename)
52
+ with open(save_path, 'w') as f:
53
+ json.dump(data, f)
54
+
55
+
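A hypothetical LightningModule using this mixin; save_dir is read from hparams, so it is passed through save_hyperparameters here. Class name, paths, and payloads are illustrative.

import numpy as np
import lightning.pytorch as pl
from my_utils.savermixins import SaverMixin

class DemoModule(pl.LightningModule, SaverMixin):
    def __init__(self, save_dir="results/demo"):
        super().__init__()
        self.save_hyperparameters()          # exposes self.hparams.save_dir for the mixin

    def on_test_end(self):
        img = np.zeros((64, 64, 3), dtype=np.uint8)
        self.save_rgb_image("renders/blank.png", img)        # written under save_dir
        self.save_json("meta.json", {"note": "demo output"})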
objects/__init__.py ADDED
File without changes
objects/dict_utils.py ADDED
@@ -0,0 +1,299 @@
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+
4
+
5
+ def get_base_part_idx(obj_dict):
6
+ """
7
+ Get the index of the base part in the object dictionary\n
8
+
9
+ - obj_dict: the object dictionary\n
10
+
11
+ Return:\n
12
+ - base_part_idx: the index of the base part
13
+ """
14
+
15
+ # Adjust for NAP's corner case
16
+ base_part_ids = np.where(
17
+ [part["parent"] == -1 for part in obj_dict["diffuse_tree"]]
18
+ )[0]
19
+ if len(base_part_ids) > 0:
20
+ return base_part_ids[0].item()
21
+ else:
22
+ raise ValueError("No base part found")
23
+
24
+
25
+ def get_bbox_vertices(obj_dict, part_idx):
26
+ """
27
+ Get the 8 vertices of the bounding box\n
28
+ The order of the vertices is the same as the order that pytorch3d.ops.box3d_overlap expects\n
29
+ (This order is not necessary since we are not using pytorch3d.ops.box3d_overlap anymore)\n
30
+
31
+ - bbox_center: the center of the bounding box in the form: [cx, cy, cz]\n
32
+ - bbox_size: the size of the bounding box in the form: [lx, ly, lz]\n
33
+
34
+ Return:\n
35
+ - bbox_vertices: the 8 vertices of the bounding box in the form: [[x0, y0, z0], [x1, y1, z1], ...]
36
+ """
37
+
38
+ part = obj_dict["diffuse_tree"][part_idx]
39
+ bbox_center = np.array(part["aabb"]["center"], dtype=np.float32)
40
+ bbox_size_half = np.array(part["aabb"]["size"], dtype=np.float32) / 2
41
+
42
+ bbox_vertices = np.zeros((8, 3), dtype=np.float32)
43
+
44
+ # Get the 8 vertices of the bounding box in the order that pytorch3d.ops.box3d_overlap expects:
45
+ # 0: (x0, y0, z0) # 1: (x1, y0, z0) # 2: (x1, y1, z0) # 3: (x0, y1, z0)
46
+ # 4: (x0, y0, z1) # 5: (x1, y0, z1) # 6: (x1, y1, z1) # 7: (x0, y1, z1)
47
+ bbox_vertices[0, :] = bbox_center - bbox_size_half
48
+ bbox_vertices[1, :] = bbox_center + np.array(
49
+ [bbox_size_half[0], -bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
50
+ )
51
+ bbox_vertices[2, :] = bbox_center + np.array(
52
+ [bbox_size_half[0], bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
53
+ )
54
+ bbox_vertices[3, :] = bbox_center + np.array(
55
+ [-bbox_size_half[0], bbox_size_half[1], -bbox_size_half[2]], dtype=np.float32
56
+ )
57
+ bbox_vertices[4, :] = bbox_center + np.array(
58
+ [-bbox_size_half[0], -bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
59
+ )
60
+ bbox_vertices[5, :] = bbox_center + np.array(
61
+ [bbox_size_half[0], -bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
62
+ )
63
+ bbox_vertices[6, :] = bbox_center + bbox_size_half
64
+ bbox_vertices[7, :] = bbox_center + np.array(
65
+ [-bbox_size_half[0], bbox_size_half[1], bbox_size_half[2]], dtype=np.float32
66
+ )
67
+
68
+ return bbox_vertices
69
+
70
+
71
+ def compute_overall_bbox_size(obj_dict):
72
+ """
73
+ Compute the overall bounding box size of the object\n
74
+
75
+ - obj_dict: the object dictionary\n
76
+
77
+ Return:\n
78
+ - bbox_size: the overall bounding box size in the form: [lx, ly, lz]
79
+ """
80
+
81
+ bbox_min = np.zeros((len(obj_dict["diffuse_tree"]), 3), dtype=np.float32)
82
+ bbox_max = np.zeros((len(obj_dict["diffuse_tree"]), 3), dtype=np.float32)
83
+
84
+ # For each part, compute the bounding box and store the min and max vertices
85
+ for part_idx, part in enumerate(obj_dict["diffuse_tree"]):
86
+ bbox_center = np.array(part["aabb"]["center"], dtype=np.float32)
87
+ bbox_size_half = np.array(part["aabb"]["size"], dtype=np.float32) / 2
88
+ bbox_min[part_idx] = bbox_center - bbox_size_half
89
+ bbox_max[part_idx] = bbox_center + bbox_size_half
90
+
91
+ # Compute the overall bounding box size
92
+ bbox_min = np.min(bbox_min, axis=0)
93
+ bbox_max = np.max(bbox_max, axis=0)
94
+ bbox_size = bbox_max - bbox_min
95
+ return bbox_size
96
+
97
+
98
+ def remove_handles(obj_dict):
99
+ """
100
+ Remove the handles from the object dictionary and adjust the id, parent, and children of the parts\n
101
+
102
+ - obj_dict: the object dictionary\n
103
+
104
+ Return:\n
105
+ - obj_dict: the object dictionary without the handles
106
+ """
107
+
108
+ # Find the indices of the handles
109
+ handle_idxs = np.array(
110
+ [
111
+ i
112
+ for i in range(len(obj_dict["diffuse_tree"]))
113
+ if obj_dict["diffuse_tree"][i]["name"] == "handle"
114
+ and obj_dict["diffuse_tree"][i]["parent"] != -1
115
+ ]
116
+ ) # Added to avoid corner case of NAP where the handle is the base part
117
+
118
+ # Remove the handles from the object dictionary and adjust the id, parent, and children of the parts
119
+ for handle_idx in handle_idxs:
120
+ handle = obj_dict["diffuse_tree"][handle_idx]
121
+ parent_idx = handle["parent"]
122
+ if handle_idx in obj_dict["diffuse_tree"][parent_idx]["children"]:
123
+ obj_dict["diffuse_tree"][parent_idx]["children"].remove(handle_idx)
124
+ obj_dict["diffuse_tree"].pop(handle_idx)
125
+
126
+ # Adjust the id, parent, and children of the parts
127
+ for part in obj_dict["diffuse_tree"]:
128
+ if part["id"] > handle_idx:
129
+ part["id"] -= 1
130
+ if part["parent"] > handle_idx:
131
+ part["parent"] -= 1
132
+ for i in range(len(part["children"])):
133
+ if part["children"][i] > handle_idx:
134
+ part["children"][i] -= 1
135
+
136
+ handle_idxs -= 1
137
+
138
+ return obj_dict
139
+
140
+
141
+ # def normalize_object(obj_dict):
142
+ # """
143
+ # Normalize the object as a whole\n
144
+ # Make the base part to be centered at the origin and have a size of 2\n
145
+
146
+ # obj_dict: the object dictionary
147
+ # """
148
+ # # Find the base part and compute the translation and scaling factors
149
+ # tree = obj_dict["diffuse_tree"]
150
+ # for part in tree:
151
+ # if part["parent"] == -1:
152
+ # translate = -np.array(part["aabb"]["center"], dtype=np.float32)
153
+ # scale = 2.0 / np.array(part["aabb"]["size"], dtype=np.float32)
154
+ # break
155
+
156
+ # for part in tree:
157
+ # part["aabb"]["center"] = (
158
+ # np.array(part["aabb"]["center"], dtype=np.float32) + translate
159
+ # ) * scale
160
+ # part["aabb"]["size"] = np.array(part["aabb"]["size"], dtype=np.float32) * scale
161
+ # if part["joint"]["type"] != "fixed":
162
+ # part["joint"]["axis"]["origin"] = (
163
+ # np.array(part["joint"]["axis"]["origin"], dtype=np.float32) + translate
164
+ # ) * scale
165
+
166
+ def zero_center_object(obj_dict):
167
+ """
168
+ Zero center the object as a whole\n
169
+
170
+ - obj_dict: the object dictionary
171
+ """
172
+
173
+ bbox_min = np.zeros((len(obj_dict["diffuse_tree"]), 3))
174
+ bbox_max = np.zeros((len(obj_dict["diffuse_tree"]), 3))
175
+
176
+ # For each part, compute the bounding box and store the min and max vertices
177
+ for part_idx, part in enumerate(obj_dict["diffuse_tree"]):
178
+ bbox_center = np.array(part["aabb"]["center"])
179
+ bbox_size_half = np.array(part["aabb"]["size"]) / 2
180
+ bbox_min[part_idx] = bbox_center - bbox_size_half
181
+ bbox_max[part_idx] = bbox_center + bbox_size_half
182
+
183
+ # Compute the overall bounding box size
184
+ bbox_min = np.min(bbox_min, axis=0)
185
+ bbox_max = np.max(bbox_max, axis=0)
186
+ bbox_center = (bbox_min + bbox_max) / 2
187
+
188
+ translate = -bbox_center
189
+
190
+ for part in obj_dict["diffuse_tree"]:
191
+ part["aabb"]["center"] = np.array(part["aabb"]["center"]) + translate
192
+ if part["joint"]["type"] != "fixed":
193
+ part["joint"]["axis"]["origin"] = np.array(part["joint"]["axis"]["origin"]) + translate
194
+
195
+
196
+ def rescale_object(obj_dict, scale_factor):
197
+ """
198
+ Rescale the object as a whole\n
199
+
200
+ - obj_dict: the object dictionary\n
201
+ - scale_factor: the scale factor to rescale the object
202
+ """
203
+
204
+ for part in obj_dict["diffuse_tree"]:
205
+ part["aabb"]["center"] = (
206
+ np.array(part["aabb"]["center"], dtype=np.float32) * scale_factor
207
+ )
208
+ part["aabb"]["size"] = (
209
+ np.array(part["aabb"]["size"], dtype=np.float32) * scale_factor
210
+ )
211
+ if part["joint"]["type"] != "fixed":
212
+ part["joint"]["axis"]["origin"] = (
213
+ np.array(part["joint"]["axis"]["origin"], dtype=np.float32)
214
+ * scale_factor
215
+ )
216
+
217
+
218
+ def find_part_mapping(obj1_dict, obj2_dict, use_hungarian=False):
219
+ """
220
+ Find the correspondences from the first object to the second object based on closest bbox centers\n
221
+
222
+ - obj1_dict: the first object dictionary\n
223
+ - obj2_dict: the second object dictionary\n
224
+
225
+ Return:\n
226
+ - mapping: the mapping from the first object to the second object in the form: [[obj_part_idx, distance], ...]
227
+ """
228
+ if use_hungarian:
229
+ return hungarian_matching(obj1_dict, obj2_dict)
230
+
231
+ # Initialize the distances to be +inf
232
+ mapping = np.ones((len(obj1_dict["diffuse_tree"]), 2)) * np.inf
233
+
234
+ # For each part in the first object, find the closest part in the second object based on the bounding box center
235
+ for req_part_idx, req_part in enumerate(obj1_dict["diffuse_tree"]):
236
+ for obj_part_idx, obj_part in enumerate(obj2_dict["diffuse_tree"]):
237
+ distance = np.linalg.norm(
238
+ np.array(req_part["aabb"]["center"])
239
+ - np.array(obj_part["aabb"]["center"])
240
+ )
241
+ if distance < mapping[req_part_idx, 1]:
242
+ mapping[req_part_idx, :] = [obj_part_idx, distance]
243
+
244
+ return mapping
245
+
246
+
247
+ def hungarian_matching(obj1_dict, obj2_dict):
248
+ """
249
+ Find the correspondences from the first object to the second object based on closest bbox centers using Hungarian algorithm\n
250
+
251
+ - obj1_dict: the first object dictionary\n
252
+ - obj2_dict: the second object dictionary\n
253
+
254
+ Return:\n
255
+ - mapping: the mapping from the first object to the second object in the form: [[obj_part_idx, distance], ...]
256
+ """
257
+ INF = 9999999
258
+
259
+ tree1 = obj1_dict["diffuse_tree"]
260
+ tree2 = obj2_dict["diffuse_tree"]
261
+
262
+ n_parts1 = len(tree1)
263
+ n_parts2 = len(tree2)
264
+ n_parts_max = max(n_parts1, n_parts2)
265
+
266
+ # Initialize the cost matrix
267
+ cost_matrix = np.ones((n_parts_max, n_parts_max), dtype=np.float32) * INF
268
+ for i in range(n_parts1):
269
+ for j in range(n_parts2):
270
+ cost_matrix[i, j] = np.linalg.norm(
271
+ np.array(tree1[i]["aabb"]["center"], dtype=np.float32)
272
+ - np.array(tree2[j]["aabb"]["center"], dtype=np.float32)
273
+ )
274
+
275
+ # Find the correspondences using the Hungarian algorithm
276
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
277
+
278
+ # Valid correspondences are those with all cost less than INF
279
+ valid_correspondences = np.where(cost_matrix[row_ind, col_ind] < INF)[0]
280
+ invalid_correspondences = np.where(np.logical_not(cost_matrix[row_ind, col_ind] < INF))[0]
281
+
282
+ row_i = row_ind[valid_correspondences]
283
+ col_i = col_ind[valid_correspondences]
284
+
285
+ # Construct the mapping
286
+ mapping = np.zeros(
287
+ (n_parts1, 2), dtype=np.float32
288
+ )
289
+ mapping[row_i, 0] = col_i
290
+ mapping[row_i, 1] = cost_matrix[row_i, col_i]
291
+
292
+ # assign the index of the most closely matched part
293
+ if n_parts1 > n_parts2:
294
+ row_j = row_ind[invalid_correspondences]
295
+ col_j = cost_matrix[row_j, :].argmin(axis=1)
296
+ mapping[row_j, 0] = col_j
297
+ mapping[row_j, 1] = cost_matrix[row_j, col_j]
298
+
299
+ return mapping
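
A minimal usage sketch of the helpers above (not part of the committed file): the toy object below is invented for illustration and follows the `diffuse_tree` layout these functions expect (`id`, `parent`, `children`, `aabb.center`, `aabb.size`, `joint`).

from objects.dict_utils import get_base_part_idx, compute_overall_bbox_size, find_part_mapping

toy_obj = {
    "diffuse_tree": [
        {"id": 0, "parent": -1, "children": [1], "name": "base",
         "aabb": {"center": [0.0, 0.0, 0.0], "size": [2.0, 2.0, 2.0]},
         "joint": {"type": "fixed"}},
        {"id": 1, "parent": 0, "children": [], "name": "door",
         "aabb": {"center": [1.0, 0.0, 0.0], "size": [0.1, 2.0, 2.0]},
         "joint": {"type": "revolute",
                   "axis": {"origin": [1.0, -1.0, 0.0], "direction": [0.0, 0.0, 1.0]},
                   "range": [0.0, 90.0]}},
    ]
}

print(get_base_part_idx(toy_obj))                               # 0 (the only part with parent == -1)
print(compute_overall_bbox_size(toy_obj))                       # overall [lx, ly, lz] covering base + door
print(find_part_mapping(toy_obj, toy_obj, use_hungarian=True))  # identity matching with zero cost
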
objects/motions.py ADDED
@@ -0,0 +1,99 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import quaternion
5
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+ from objects.dict_utils import get_base_part_idx
7
+
8
+ def transform_all_parts(part_vertices, obj_dict, joint_state,
9
+ rotation_fix_range=True, dry_run=True):
10
+ """
11
+ Transform all parts of the object according to the joint state\n
12
+
13
+ - part_vertices: vertices of the object in rest pose in the form:\n
14
+ - [K_parts, N_vertices, 3]\n
15
+ - obj_dict: the object dictionary\n
16
+ - joint_state: the joint state in the range of [0, 1]\n
17
+ - rotation_fix_range (optional): whether to fix the rotation range to 90 degrees for revolute joints\n
18
+ - dry_run (optional): if True, only return the transformation matrices without changing the vertices\n
19
+
20
+ Return:\n
21
+ - part_transformations: records of the transformations applied to the parts\n
22
+ """
23
+ part_transformations = [[] for _ in range(len(obj_dict["diffuse_tree"]))]
24
+ if joint_state == 0.0:
25
+ return part_transformations
26
+
27
+ # Get a visit order of the parts such that children parts are visited before parents
28
+ part_visit_order = []
29
+ base_idx = get_base_part_idx(obj_dict)
30
+ indices_to_visit = [base_idx]
31
+ while len(indices_to_visit) > 0: # Breadth-first traversal
32
+ current_idx = indices_to_visit.pop(0)
33
+ part_visit_order.append(current_idx)
34
+ # if current_idx == 9:
35
+ # import ipdb
36
+ # ipdb.set_trace()
37
+ indices_to_visit += obj_dict["diffuse_tree"][current_idx]["children"]
38
+ part_visit_order.reverse()
39
+
40
+ # Transform the parts in the visit order - children first, then parents
41
+ for i in part_visit_order:
42
+ part = obj_dict["diffuse_tree"][i]
43
+ joint = part["joint"]
44
+ children_idxs = part["children"]
45
+
46
+ # Store the transformation used to transform the part and its children
47
+ applied_tramsformation_matrix = np.eye(4, dtype=np.float32)
48
+ applied_rotation_axis_origin = np.array([np.nan, np.nan, np.nan], dtype=np.float32)
49
+ applied_transformation_type = "none"
50
+ if joint["type"] == "prismatic":
51
+ # Translate the part and its children
52
+ translation = np.array(joint["axis"]["direction"], dtype=np.float32) * joint["range"][1] * joint_state
53
+
54
+ if not dry_run:
55
+ part_vertices[[i] + children_idxs] += translation
56
+
57
+ # Store the transformation used
58
+ applied_tramsformation_matrix[:3, 3] = translation
59
+ applied_transformation_type = "translation"
60
+
61
+ elif joint["type"] == "revolute" or joint["type"] == "continuous":
62
+ if joint["type"] == "revolute":
63
+ if not rotation_fix_range:
64
+ # Use the full range as specified in the object file
65
+ rotation_radian = np.radians(joint["range"][1] * joint_state)
66
+ else:
67
+ # Fix the rotation range to 90 degrees
68
+ rotation_range_sign = np.sign(joint["range"][1])
69
+ rotation_radian = np.radians(rotation_range_sign * 90 * joint_state)
70
+
71
+ else:
72
+ rotation_radian = np.radians(360 * joint_state)
73
+
74
+ # Prepare the rotation matrix via axis-angle representation and quaternion
75
+ rotation_axis_origin = np.array(joint["axis"]["origin"], dtype=np.float32)
76
+ rotation_axis_direction = np.array(joint["axis"]["direction"], dtype=np.float32) / np.linalg.norm(joint["axis"]["direction"])
77
+ rotation_matrix = quaternion.as_rotation_matrix(quaternion.from_rotation_vector(rotation_radian * rotation_axis_direction))
78
+
79
+ if not dry_run:
80
+ # Rotate the part and its children
81
+ vertices_to_rotate = (part_vertices[[i] + children_idxs] - rotation_axis_origin)
82
+ part_vertices[[i] + children_idxs] = np.matmul(rotation_matrix, vertices_to_rotate.transpose([0, 2, 1])).transpose([0, 2, 1]) + rotation_axis_origin
83
+
84
+ # Store the transformation used
85
+ applied_tramsformation_matrix[:3, :3] = rotation_matrix
86
+ applied_rotation_axis_origin = rotation_axis_origin
87
+ applied_transformation_type = "rotation"
88
+
89
+ # Record the transformation used
90
+ if not applied_transformation_type == "none":
91
+ record = {
92
+ "type": applied_transformation_type,
93
+ "matrix": applied_tramsformation_matrix,
94
+ "rotation_axis_origin": applied_rotation_axis_origin
95
+ }
96
+ for idx in [i] + children_idxs:
97
+ part_transformations[idx].append(record)
98
+
99
+ return part_transformations
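
A quick sketch of driving an articulation with the function above (illustrative only; `toy_obj` is the invented two-part object from the dict_utils example, and the vertex array shape is arbitrary).

import numpy as np
from objects.motions import transform_all_parts

part_vertices = np.zeros((2, 8, 3), dtype=np.float32)   # [K_parts, N_vertices, 3], e.g. bbox corners per part
records = transform_all_parts(part_vertices, toy_obj, joint_state=0.5,
                              rotation_fix_range=True, dry_run=False)
# With dry_run=False the revolute door is rotated 45 degrees in place (half of the fixed 90-degree range);
# records[i] lists the 4x4 transforms applied to part i, with children inheriting their parent's motion.
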
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+
3
+ torch==2.2.2
4
+ torchvision==0.17.2
5
+ pytorch-lightning==2.4.0
6
+ lightning==2.3.3
7
+ matplotlib
8
+ numpy==1.26.4
9
+ gradio==5.34.2
10
+ wandb
11
+ omegaconf
12
+ imageio
13
+ diffusers
14
+ plotly
15
+ pybullet
16
+ pyrender
17
+ trimesh
18
+ numpy-quaternion
19
+ openai
20
+ spaces
21
+ json_repair
retrieval/__init__.py ADDED
File without changes
retrieval/obj_retrieval.py ADDED
@@ -0,0 +1,509 @@
1
+ import os
2
+ import sys
3
+ import random
4
+ import json
5
+ import numpy as np
6
+ from copy import deepcopy
7
+
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+ from metrics.iou_cdist import IoU_cDist
10
+ import networkx as nx
11
+
12
+ all_categories = [
13
+ "Table",
14
+ "StorageFurniture",
15
+ "WashingMachine",
16
+ "Microwave",
17
+ "Dishwasher",
18
+ "Refrigerator",
19
+ "Oven",
20
+ ]
21
+
22
+ all_categories_acd = [
23
+ 'armoire',
24
+ 'bookcase',
25
+ 'chestofdrawers',
26
+ 'hangingcabinet',
27
+ 'kitchencabinet'
28
+ ]
29
+
30
+
31
+ def get_hash(file, key="diffuse_tree", ignore_handles=True, dag=False):
32
+ tree = file[key]
33
+ if dag:
34
+ G = nx.DiGraph()
35
+ else:
36
+ G = nx.Graph()
37
+ for node in tree:
38
+ if ignore_handles and "handle" in node["name"].lower():
39
+ continue
40
+ G.add_node(node["id"])
41
+ if node["parent"] != -1:
42
+ G.add_edge(node["id"], node["parent"])
43
+ hashcode = nx.weisfeiler_lehman_graph_hash(G)
44
+ return hashcode
45
+
46
+
47
+ def _verify_mesh_exists(dir, ply_files, verbose=False):
48
+ """
49
+ Verify that the mesh files exist\n
50
+
51
+ - dir: the directory of the object\n
52
+ - ply_files: the list of mesh files\n
53
+ - verbose (optional): whether to print the progress\n
54
+
55
+ return:\n
56
+ - True if the mesh files exist, False otherwise
57
+ """
58
+
59
+ for ply_file in ply_files:
60
+ if not os.path.exists(os.path.join(dir, ply_file)):
61
+ if verbose:
62
+ print(f" - {os.path.join(dir, ply_file)} does not exist!!!")
63
+ return False
64
+ return True
65
+
66
+
67
+ def _generate_output_part_dicts(
68
+ candidate_dict,
69
+ part_idx,
70
+ candidate_dir,
71
+ requirement_part_bbox_sizes,
72
+ bbox_size_eps=1e-3,
73
+ verbose=False,
74
+ ):
75
+ """
76
+ Generate the output part dictionary for all parts that are fulfilled by the candidate part and computing the scale factor of the parts\n
77
+
78
+ - candidate_dict: the candidate object dictionary\n
79
+ - part_idx: the index of the part in the candidate object\n
80
+ - candidate_dir: the directory of the candidate object\n
81
+ - requirement_part_bbox_sizes: the bounding box sizes of the requirement part in the form: [[lx1, ly1, lz1], [lx2, ly2, lz2], ...]\n
82
+ - bbox_size_eps (optional): the epsilon to avoid zero volume parts\n
83
+ - verbose (optional): whether to print the progress\n
84
+
85
+ Return:\n
86
+ - part_dicts: the output part dictionaries in the form:
87
+ - [{name, dir, files, scale_factor=[sx, sy, sz], z_rotate_90}, ...]
88
+ - z_rotate_90 is True if the part needs to be rotated by 90 degrees around the z-axis
89
+ - [{}, ...] if any of the mesh files do not exist
90
+ """
91
+
92
+ part_dicts = [{} for _ in range(len(requirement_part_bbox_sizes))]
93
+ fixed_portion = {
94
+ "name": candidate_dict["diffuse_tree"][part_idx]["name"],
95
+ "dir": candidate_dir,
96
+ "files": candidate_dict["diffuse_tree"][part_idx]["plys"],
97
+ "z_rotate_90": False,
98
+ }
99
+
100
+ # Verify that the mesh files exist
101
+ if not _verify_mesh_exists(fixed_portion["dir"], fixed_portion["files"], verbose):
102
+ if verbose:
103
+ print(
104
+ f" - ! Found invalid mesh files in {fixed_portion['dir']}, skipping..."
105
+ )
106
+ return part_dicts # List of empty dicts
107
+
108
+ candidate_bbox_size = np.array(
109
+ candidate_dict["diffuse_tree"][part_idx]["aabb"]["size"]
110
+ )
111
+ candidate_bbox_size = np.maximum(
112
+ candidate_bbox_size, bbox_size_eps
113
+ ) # Avoid zero volume parts
114
+
115
+ for i, requirement_part_bbox_size in enumerate(requirement_part_bbox_sizes):
116
+ part_dicts[i] = deepcopy(fixed_portion)
117
+
118
+ # For non-handle parts, compute the scale factor normally
119
+ if fixed_portion["name"] != "handle":
120
+ part_dicts[i]["scale_factor"] = list(
121
+ np.array(requirement_part_bbox_size) / candidate_bbox_size
122
+ )
123
+
124
+ # For handles, need to consider the orientation of the selected handle and the orientation of the requirement handle
125
+ else:
126
+ requirement_handle_is_horizontal = (
127
+ requirement_part_bbox_size[0] > requirement_part_bbox_size[1]
128
+ )
129
+ candidate_handle_is_horizontal = (
130
+ candidate_bbox_size[0] > candidate_bbox_size[1]
131
+ )
132
+
133
+ # If the orientations are different, rotate the requirement handle by 90 degrees around the z-axis before computing the scale factor
134
+ if requirement_handle_is_horizontal != candidate_handle_is_horizontal:
135
+ rotated_requirement_part_bbox_size = [
136
+ requirement_part_bbox_size[1],
137
+ requirement_part_bbox_size[0],
138
+ requirement_part_bbox_size[2],
139
+ ]
140
+ part_dicts[i]["scale_factor"] = list(
141
+ np.array(rotated_requirement_part_bbox_size) / candidate_bbox_size
142
+ )
143
+ part_dicts[i]["z_rotate_90"] = True
144
+
145
+ # If the orientations are the same, compute the scale factor normally
146
+ else:
147
+ part_dicts[i]["scale_factor"] = list(
148
+ np.array(requirement_part_bbox_size) / candidate_bbox_size
149
+ )
150
+
151
+ return part_dicts
152
+
153
+
154
+ def find_obj_candidates(
155
+ requirement_dict,
156
+ dataset_dir,
157
+ hashbook_path,
158
+ num_states=5,
159
+ metric_compare_handles=False,
160
+ metric_iou_include_base=True,
161
+ metric_num_samples=10000,
162
+ keep_top=5,
163
+ gt_file_name="object.json",
164
+ verbose=False,
165
+ ):
166
+ """
167
+ Find the best object candidates for selecting the base part using AID\n
168
+
169
+ - requirement_dict: the object dictionary of the requirement\n
170
+ - dataset_dir: the directory of the dataset to search in\n
171
+ - hashbook_path: the path to the hashbook for filtering candidates\n
172
+ - num_states: the number of states to average the metric over\n
173
+ - metric_transform_plucker (optional): whether to use Plucker coordinates to move parts when computing the metric\n
174
+ - metric_compare_handles (optional): whether to compare handles when computing the metric\n
175
+ - metric_iou_include_base (optional): whether to include the base when computing the IoU\n
176
+ - metric_scale_factor (optional): the scale factor to scale the object before computing the metric\n
177
+ - Scaling up the object makes the sampling more well distributed\n
178
+ - metric_num_samples (optional): the number of samples to use when computing the metric\n
179
+ - keep_top (optional): the number of top candidates to keep\n
180
+ - gt_file_name (optional): the name of the ground truth json file, which describes a candidate object\n
181
+ - verbose (optional): whether to print the progress\n
182
+
183
+ return:\n
184
+ - a list of best object candidates of the form:
185
+ - {"category", "dir", "score"}
186
+ """
187
+ dataset_dir = os.path.abspath(dataset_dir)
188
+
189
+ # Load the hashbook
190
+ with open(hashbook_path, "r") as f:
191
+ hashbook = json.load(f)
192
+
193
+ if 'acd' in hashbook_path:
194
+ all_categories = all_categories_acd
195
+ else:
196
+ all_categories = [
197
+ "Table",
198
+ "StorageFurniture",
199
+ "WashingMachine",
200
+ "Microwave",
201
+ "Dishwasher",
202
+ "Refrigerator",
203
+ "Oven",
204
+ ]
205
+
206
+ # Resolve paths to directories
207
+ category_specified = False
208
+ requirement_category = ""
209
+
210
+ # if the category is specified, only search in that category, otherwise search in all categories
211
+ if "obj_cat" in requirement_dict["meta"]:
212
+ requirement_category = requirement_dict["meta"]["obj_cat"]
213
+ category_specified = True
214
+ if requirement_category == "StroageFurniture":
215
+ requirement_category = "StorageFurniture"
216
+ category_dirs = (
217
+ [os.path.join(dataset_dir, requirement_category)]
218
+ if category_specified
219
+ else [os.path.join(dataset_dir, category) for category in all_categories]
220
+ )
221
+
222
+ # Extract requirement data
223
+ requirement_part_names = []
224
+ requirement_part_bboxes = []
225
+ for part in requirement_dict["diffuse_tree"]:
226
+ requirement_part_names.append(part["name"])
227
+ requirement_part_bboxes.append(
228
+ np.concatenate([part["aabb"]["center"], part["aabb"]["size"]])
229
+ )
230
+
231
+ # Compute hash of the requirement graph
232
+ requirement_graph_hash = get_hash(requirement_dict)
233
+
234
+ # Prefetch list of ids of candidate objects with the same hash
235
+ # import ipdb
236
+ # ipdb.set_trace()
237
+ if category_specified and requirement_graph_hash in hashbook[requirement_category]:
238
+ same_hash_obj_ids = hashbook[requirement_category][requirement_graph_hash]
239
+ else:
240
+ # Use all categories if category is not specified
241
+ same_hash_obj_ids = []
242
+ for category in all_categories:
243
+ if requirement_graph_hash in hashbook[category]:
244
+ same_hash_obj_ids += hashbook[category][requirement_graph_hash]
245
+
246
+ # Iterate through all candidate objects and keep the top k candidates
247
+ best_obj_candidates = []
248
+ for category_dir in category_dirs:
249
+ obj_ids = os.listdir(category_dir)
250
+ for i, obj_id in enumerate(obj_ids):
251
+ if verbose:
252
+ print(
253
+ f"\r - Finding candidates from {category_dir.split('/')[-1]}: {i+1}/{len(obj_ids)}",
254
+ end="",
255
+ )
256
+
257
+ # Load the candidate object
258
+ obj_dir = os.path.join(category_dir, obj_id)
259
+ if os.path.exists(os.path.join(obj_dir, gt_file_name)):
260
+ with open(os.path.join(obj_dir, gt_file_name), "r") as f:
261
+ obj_dict = json.load(f)
262
+ if "diffuse_tree" not in obj_dict: # Rename for compatibility
263
+ obj_dict["diffuse_tree"] = obj_dict.pop("arti_tree")
264
+
265
+ # Compute metric for selecting the base if the hash matches or if there are no objects with the same hash
266
+ if obj_id in same_hash_obj_ids or len(same_hash_obj_ids) == 0:
267
+ scores = IoU_cDist(
268
+ requirement_dict,
269
+ obj_dict,
270
+ num_states=num_states,
271
+ compare_handles=metric_compare_handles,
272
+ iou_include_base=metric_iou_include_base,
273
+ num_samples=metric_num_samples,
274
+ )
275
+ base_score = scores["AS-cDist"]
276
+
277
+ # Add the candidate to the list of best candidates and keep the top k candidates
278
+ best_obj_candidates.append(
279
+ {
280
+ "category": category_dir.split("/")[-1],
281
+ "dir": obj_dir,
282
+ "score": base_score,
283
+ }
284
+ )
285
+ best_obj_candidates = sorted(
286
+ best_obj_candidates, key=lambda x: x["score"]
287
+ )[:keep_top]
288
+ if verbose:
289
+ print()
290
+
291
+ return best_obj_candidates
292
+
293
+
294
+ def pick_and_rescale_parts(
295
+ requirement_dict,
296
+ obj_candidates,
297
+ dataset_dir,
298
+ gt_file_name="object.json",
299
+ verbose=False,
300
+ ):
301
+ """
302
+ Pick and rescale parts from the object candidates
303
+
304
+ - requirement_dict: the object dictionary of the requirement\n
305
+ - obj_candidates: the list of best object candidates for selecting the base part\n
306
+ - dataset_dir: the directory of the dataset to search in\n
307
+ - gt_file_name (optional): the name of the ground truth file, which describes a candidate object\n
308
+ - verbose (optional): whether to print the progress\n
309
+
310
+ return:\n
311
+ - parts_to_render: a list of selected parts for the requirement parts in the form:
312
+ - [{name, dir, files, scale_factor=[sx, sy, sz], z_rotate_90}, ...]
313
+ - z_rotate_90 is True if the part needs to be rotated by 90 degrees around the z-axis
314
+ """
315
+
316
+ # Extract requirement data
317
+ if 'acd' in dataset_dir:
318
+ all_categories = all_categories_acd
319
+ else:
320
+ all_categories = [
321
+ "Table",
322
+ "StorageFurniture",
323
+ "WashingMachine",
324
+ "Microwave",
325
+ "Dishwasher",
326
+ "Refrigerator",
327
+ "Oven",
328
+ ]
329
+ requirement_part_names = []
330
+ requirement_part_bbox_sizes = []
331
+ for part in requirement_dict["diffuse_tree"]:
332
+ if part['name'] == 'wheel':
333
+ part['name'] = 'handle'
334
+ requirement_part_names.append(part["name"])
335
+ requirement_part_bbox_sizes.append(part["aabb"]["size"])
336
+
337
+ # Collect the unique part names and store the indices of the parts with the same name
338
+ unique_requirement_part_names = {}
339
+ for i, part_name in enumerate(requirement_part_names):
340
+ if part_name not in unique_requirement_part_names:
341
+ unique_requirement_part_names[part_name] = [i]
342
+ else:
343
+ unique_requirement_part_names[part_name].append(i)
344
+
345
+ parts_to_render = [{} for _ in range(len(requirement_part_names))]
346
+
347
+ # Iterate through the object candidates selected for the base part first
348
+ for candidate in obj_candidates:
349
+ if all(
350
+ [len(part) > 0 for part in parts_to_render]
351
+ ): # Break if all parts are fulfilled
352
+ break
353
+
354
+ if not os.path.exists(os.path.join(candidate["dir"], gt_file_name)):
355
+ continue
356
+ # Load the candidate object
357
+ with open(os.path.join(candidate["dir"], gt_file_name), "r") as f:
358
+ candidate_dict = json.load(f)
359
+
360
+ # Pick parts from the candidate if the part name matches and the part requirement is not yet fulfilled
361
+ for candidate_part_idx, part in enumerate(candidate_dict["diffuse_tree"]):
362
+ part_needed = part["name"] in unique_requirement_part_names
363
+ if not part_needed:
364
+ continue
365
+
366
+ part_not_fulfilled = any(
367
+ [
368
+ len(parts_to_render[i]) == 0
369
+ for i in unique_requirement_part_names[part["name"]]
370
+ ]
371
+ )
372
+ if not part_not_fulfilled:
373
+ continue
374
+
375
+ # Get the indices of the requirement parts that are fulfilled by this candidate part and their bounding box sizes
376
+ fullfill_part_idxs = unique_requirement_part_names[part["name"]]
377
+ fullfill_part_bbox_sizes = [
378
+ requirement_part_bbox_sizes[i] for i in fullfill_part_idxs
379
+ ]
380
+ # Generate all output part dictionaries at once
381
+ part_dicts = _generate_output_part_dicts(
382
+ candidate_dict,
383
+ candidate_part_idx,
384
+ candidate["dir"],
385
+ fullfill_part_bbox_sizes,
386
+ verbose=verbose,
387
+ )
388
+ # Update the output part dictionaries
389
+ [
390
+ parts_to_render[part_idx].update(part_dicts[part_dict_idx])
391
+ for part_dict_idx, part_idx in enumerate(fullfill_part_idxs)
392
+ ]
393
+
394
+ # If there are still parts that are not fulfilled
395
+ if any([len(part) == 0 for part in parts_to_render]):
396
+ # Collect the remaining part names
397
+ remaining_part_names = list(
398
+ set(
399
+ [
400
+ requirement_part_names[i]
401
+ for i in range(len(requirement_part_names))
402
+ if len(parts_to_render[i]) == 0
403
+ ]
404
+ )
405
+ )
406
+ if verbose:
407
+ print(
408
+ f" - Parts {remaining_part_names} are not fulfilled by the selected candidates, searching in the dataset..."
409
+ )
410
+
411
+ # If the category is specified, only search in that category, otherwise search in all categories
412
+ # requirement_dict["meta"]["obj_cat"] = ""
413
+ requirement_category = requirement_dict["meta"]["obj_cat"]
414
+ if requirement_category == "StroageFurniture":
415
+ requirement_category = "StorageFurniture"
416
+ category_specified = requirement_category != ""
417
+ if category_specified:
418
+ category_dirs = [os.path.join(dataset_dir, requirement_category)]
419
+ else:
420
+ category_dirs = [
421
+ os.path.join(dataset_dir, category) for category in all_categories
422
+ ]
423
+
424
+ # Iterate through all objects
425
+ retry = True # Retry if the category is specified, but some parts are still not fulfilled (See the end of the while loop)
426
+ retry_time = 0
427
+ while retry:
428
+ print(retry_time)
429
+ retry_time += 1
430
+ for category_dir in category_dirs:
431
+ obj_ids = os.listdir(category_dir)
432
+ random.shuffle(obj_ids) # Randomize the order of the objects
433
+ for i, obj_id in enumerate(obj_ids):
434
+ if True:
435
+ print(
436
+ f"- Finding missing parts from {category_dir.split('/')[-1]}: {i+1}/{len(obj_ids)} \n"
437
+ )
438
+
439
+ # Load the candidate object
440
+ obj_dir = os.path.join(category_dir, obj_id)
441
+ if not os.path.exists(os.path.join(obj_dir, gt_file_name)):
442
+ continue
443
+ with open(os.path.join(obj_dir, gt_file_name), "r") as f:
444
+ candidate_dict = json.load(f)
445
+
446
+ # Pick the part from the candidate if the part name matches and the parts that are not fulfilled
447
+ for candidate_part_idx, part in enumerate(
448
+ candidate_dict["diffuse_tree"]
449
+ ):
450
+ part_needed = part["name"] in remaining_part_names
451
+
452
+ if part_needed:
453
+ # Get the indices of the requirement parts that are fulfilled by this candidate part and their bounding box sizes
454
+ fullfill_part_idxs = unique_requirement_part_names[
455
+ part["name"]
456
+ ]
457
+ fullfill_part_bbox_sizes = [
458
+ requirement_part_bbox_sizes[i]
459
+ for i in fullfill_part_idxs
460
+ ]
461
+
462
+ # Generate all output part dictionaries at once
463
+ part_dicts = _generate_output_part_dicts(
464
+ candidate_dict,
465
+ candidate_part_idx,
466
+ obj_dir,
467
+ fullfill_part_bbox_sizes,
468
+ verbose=verbose,
469
+ )
470
+
471
+ # Update the output part dictionaries
472
+ [
473
+ parts_to_render[part_idx].update(
474
+ part_dicts[part_dict_idx]
475
+ )
476
+ for part_dict_idx, part_idx in enumerate(
477
+ fullfill_part_idxs
478
+ )
479
+ ]
480
+
481
+ if all([len(part) > 0 for part in parts_to_render]):
482
+ if verbose:
483
+ print(" -> Found all missing parts")
484
+ break
485
+ if all([len(part) > 0 for part in parts_to_render]):
486
+ retry = False
487
+ break
488
+
489
+ # If the category is specified, but some parts are still not fulfilled, search in all categories
490
+ if category_specified and any([len(part) == 0 for part in parts_to_render]):
491
+ if verbose:
492
+ print(
493
+ f" - Required category is {requirement_category}, but some parts are still not fulfilled, searching in all categories..."
494
+ )
495
+ category_specified = False
496
+ retry = True
497
+ category_dirs = [
498
+ os.path.join(dataset_dir, category)
499
+ for category in all_categories
500
+ if category != requirement_category
501
+ ]
502
+
503
+ # Raise error if there are still parts that are not fulfilled
504
+ if any([len(part) == 0 for part in parts_to_render]):
505
+ raise RuntimeError(
506
+ "Failed to fulfill all requirements, some parts may not exist in the dataset"
507
+ )
508
+
509
+ return parts_to_render
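
An end-to-end sketch of the retrieval entry points above (illustrative; the dataset path is a placeholder and `requirement` stands for a generated object dictionary whose `meta["obj_cat"]` names the category).

from retrieval.obj_retrieval import find_obj_candidates, pick_and_rescale_parts

candidates = find_obj_candidates(
    requirement,
    dataset_dir="./data/PartnetMobility",                       # placeholder dataset root
    hashbook_path="retrieval/retrieval_hash_no_handles.json",
    keep_top=5,
    verbose=True,
)
parts_to_render = pick_and_rescale_parts(
    requirement, candidates, dataset_dir="./data/PartnetMobility", verbose=True
)
# Each entry holds the source mesh directory/files plus a per-axis scale factor
# (and a z_rotate_90 flag for handles) that fits the retrieved part to the required bbox.
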
retrieval/retrieval_hash_acd.json ADDED
@@ -0,0 +1,329 @@
1
+ {
2
+ "armoire": {
3
+ "ff24f9310d003cd7b7894b2d6ec79a03": [
4
+ "B07H8V49M2"
5
+ ],
6
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
7
+ "B07GFDZVYY"
8
+ ],
9
+ "5144181ac27497fdfa9bdb5b8b799630": [
10
+ "B07GFSJ69T",
11
+ "8415f258d5e129e9bb63cab54c7c207e04c0dfed",
12
+ "06764d11dcec69878cc762c36482be5ef2865443",
13
+ "0760a3dd43bd9dd9c0ec1ea4e033c6e121d92df5",
14
+ "12001de4686bf2e4b9c721c93b35a424dd48249f"
15
+ ],
16
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
17
+ "B07GFW9GFX",
18
+ "16b93d86d1a466f5982e60c8d322ddd8f312056e"
19
+ ],
20
+ "c24dd733315066f7c7da3d578f954d8c": [
21
+ "B07H8PS4FZ",
22
+ "cd17a7b8d78ee79bc52015841577a2652f9e0625",
23
+ "3033be75f2ac15885e8f4813dc55e168af09e91f"
24
+ ],
25
+ "02746714ea1f7e20dd4ac8f18d8cdac3": [
26
+ "3bc24e0ea79ade13d4703ca23ece1e4019f9b70b"
27
+ ],
28
+ "dd3473c941b94dd6654e9f89bb51cac9": [
29
+ "128b5f2d072869004b7a218ab674f93f73a66670"
30
+ ],
31
+ "cde6b48ed870286595c1455af7aff8bd": [
32
+ "69340d8678701e72f39cc01890da1b8af3fd603d",
33
+ "24f8284e4bdef4397e5b12dc4f2b74a137d63dbb"
34
+ ],
35
+ "5a8eac0760a558d4174437be478ec0aa": [
36
+ "a88c710a0b90706398a0fd7a9d73123338d04354"
37
+ ],
38
+ "c502b67eb6d91d909ba398fa39bec60c": [
39
+ "28098e6540f2f2fc07f7fc6a00edd6ce371a2618"
40
+ ],
41
+ "00232256ef3ac441f59b36bfc7bd190c": [
42
+ "11f8b552a802c6233a1332713568f05f901b725c"
43
+ ]
44
+ },
45
+ "cabinet": {
46
+ "c502b67eb6d91d909ba398fa39bec60c": [
47
+ "B07D42T6CX"
48
+ ],
49
+ "d25563e624d9195ce94b1f768fdc503d": [
50
+ "B07MGL8651"
51
+ ],
52
+ "2fce5c033589d2dfe24fa67dc6885386": [
53
+ "B07QD6V13M"
54
+ ]
55
+ },
56
+ "nightstand": {
57
+ "5144181ac27497fdfa9bdb5b8b799630": [
58
+ "B072ZK8897",
59
+ "xxxx54c93c4fxabe2x463exafd8xcae01106dc7d",
60
+ "037f34132f162235d80ce46f67c4fa2238d94da0",
61
+ "155c182834f40ebb1d5666d3a72ee828e097097a"
62
+ ],
63
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
64
+ "B072ZMHBKQ",
65
+ "fed2a84682713eabe2b8d0e1e950d891c7442d5f"
66
+ ],
67
+ "5a8eac0760a558d4174437be478ec0aa": [
68
+ "1fa3dec03ab0afb0749373b2c8da8bf77c92e271",
69
+ "265e819d1f027cde8dac05468a70455f62cbf069",
70
+ "4bd1e4ec215403d3239f09517e44568f78da3b40"
71
+ ],
72
+ "cde6b48ed870286595c1455af7aff8bd": [
73
+ "b8bdd9cbc1ca695a206583afdc26f1a4c3987303",
74
+ "0dfabed4818c34cbaa5ef41a3bcaf89177744b2c"
75
+ ]
76
+ },
77
+ "table": {
78
+ "5144181ac27497fdfa9bdb5b8b799630": [
79
+ "B07K6RNQDH",
80
+ "B07JXXR83F"
81
+ ],
82
+ "c24dd733315066f7c7da3d578f954d8c": [
83
+ "B082YNHQLX"
84
+ ],
85
+ "ddc2ef1be48dc58fe68226818824b648": [
86
+ "B075Z93NHP"
87
+ ],
88
+ "ec04032eee6bc67c5fdd5ec6705c3137": [
89
+ "B075Z93NKX"
90
+ ],
91
+ "cde6b48ed870286595c1455af7aff8bd": [
92
+ "B075Z96KQL"
93
+ ]
94
+ },
95
+ "bookcase": {
96
+ "5a8eac0760a558d4174437be478ec0aa": [
97
+ "2c681d7e64e0410d76156f500cd2df798975a25d",
98
+ "36d7feaf2471b67aa638609fe2c2278fba4a15a0"
99
+ ],
100
+ "ff24f9310d003cd7b7894b2d6ec79a03": [
101
+ "a194a188dfcd8e3796529c0263448ba047ce632f"
102
+ ],
103
+ "c502b67eb6d91d909ba398fa39bec60c": [
104
+ "d2f7d40d8e1a56ede42103027842e000d9cacd3e"
105
+ ],
106
+ "ddc2ef1be48dc58fe68226818824b648": [
107
+ "36d90c43b90a526247d61e709f349b3ed54081e2"
108
+ ],
109
+ "1e8c6b47706f002757c3370366001f06": [
110
+ "d5ba163ba97f94c7aa4a4a625eb0547b8894e1a5"
111
+ ],
112
+ "61f645001e86ad8a32357cc828ae33cb": [
113
+ "3b558ce715c88cca63b307f8e0e9b665ca57ef43"
114
+ ],
115
+ "c24dd733315066f7c7da3d578f954d8c": [
116
+ "44641367282a6a9616c91439944d4160a0e6f66a"
117
+ ],
118
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
119
+ "17a8f7c040254439fe8771f66f1c526c71d24904"
120
+ ],
121
+ "0e3c4946251b437ca90e5fe70efdea5b": [
122
+ "9418cb5d4e1b7ff0ebd469e28dfdbaa99bc61f4d"
123
+ ]
124
+ },
125
+ "desk": {
126
+ "ddc2ef1be48dc58fe68226818824b648": [
127
+ "319c007ee7d07ca84c797a512ad4a98c9abc42da",
128
+ "1619e2e6a18d5d374963a3a4280f7a6d76356079"
129
+ ],
130
+ "cde6b48ed870286595c1455af7aff8bd": [
131
+ "7489bcf3f226d315b4627bc442422f44f9eff092",
132
+ "429de8a53ddf94a47ec553147dc8603b803be056"
133
+ ],
134
+ "5a8eac0760a558d4174437be478ec0aa": [
135
+ "3a12fb2ab85d734b4b11bc6f541be9961e7ecd23",
136
+ "03f226dd0e012925a8674564c4de19cb786c9a88"
137
+ ],
138
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
139
+ "b1c0fc607c063010c4b9300c955a0e3e5f7001fc"
140
+ ],
141
+ "02746714ea1f7e20dd4ac8f18d8cdac3": [
142
+ "01be253cbfd14b947e9dbe09d0b1959e97d72122"
143
+ ],
144
+ "69144809aea48cb46eae9c3950f24a15": [
145
+ "d7cec0e53dadbc4291064708c84e3614b79ac3c9"
146
+ ],
147
+ "f38a9419ca785a395579ce42491c830e": [
148
+ "561105c73bf76152a2b32e4f55f80db6a25ac0d4"
149
+ ],
150
+ "d25563e624d9195ce94b1f768fdc503d": [
151
+ "128ed2ced9a101aa0d131fd224012fd52198003f"
152
+ ],
153
+ "17a09dc7b6207f11cd18889788802b88": [
154
+ "6d264e3023b940b0b1d31b77e04d0c845853c1f0"
155
+ ]
156
+ },
157
+ "dishwasher": {
158
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
159
+ "cb3dd8f2c8de396606e0794f6effc921aff7235d",
160
+ "aad01c69d7a27a1740e422f5f64b781816bd86fa"
161
+ ]
162
+ },
163
+ "microwave": {
164
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
165
+ "a630fbac79cdc164c9344a588f72d207b6d25e33",
166
+ "bf840db863fc9c2646b2f8f372e4847b2fd42e34",
167
+ "c3b2adbd3b89bdcd01f1e813bc4d2e06975ec727",
168
+ "68407fbf5c7296351b2c26c2e59510effc87a637",
169
+ "81da9279eae235c3faced51d516e970acdac5e84"
170
+ ],
171
+ "c502b67eb6d91d909ba398fa39bec60c": [
172
+ "9849af0395972ff84e40f5c1a51db8a25e3ef6f7"
173
+ ]
174
+ },
175
+ "oven": {
176
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
177
+ "702d3ab650d34ee1dd236b0df5882573ec70c504"
178
+ ],
179
+ "5144181ac27497fdfa9bdb5b8b799630": [
180
+ "00b0d5e167ae6b42666de010025efad4506563f1",
181
+ "c1bdae17057dfa88d5e3894433642030fd66c7d6",
182
+ "197a447eb68ba32ab44c50948b3af1e63048e174",
183
+ "239c5c38a53badc24ca6950ee78d8f6c115c3074",
184
+ "4a326efb8ab35d8ce823575d4c3b7033e8c3e5e8"
185
+ ],
186
+ "cde6b48ed870286595c1455af7aff8bd": [
187
+ "ef32e2cd6dd99d883f627e238f55ad0766240d44",
188
+ "f8bff67d469223c8e8bf44553834bd5482a96ecc",
189
+ "41efc8ed9c8d433e9ff877f3b9ea1c0eda45479c"
190
+ ]
191
+ },
192
+ "refrigerator": {
193
+ "c502b67eb6d91d909ba398fa39bec60c": [
194
+ "1c7874f93ca418d7edebd150a7422095fd76897a"
195
+ ],
196
+ "5144181ac27497fdfa9bdb5b8b799630": [
197
+ "9449ef7831c43bc0db23aac79ea442aa71d0db11",
198
+ "0bb1cdb98fbdfda5b41abaa39aa7f82321e58b72",
199
+ "1c33bd447d70d4d22116c434d912f1fae78e02b7"
200
+ ],
201
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
202
+ "cae4c60830bba615ff533dc23ffee6e6e5c7d14e"
203
+ ]
204
+ },
205
+ "sink_cabinet": {
206
+ "cde6b48ed870286595c1455af7aff8bd": [
207
+ "56dc6fc7669736b5fd6a85d1b14a01d029beff59"
208
+ ],
209
+ "c502b67eb6d91d909ba398fa39bec60c": [
210
+ "f637f110fbed653b7983d9fc6a6d53795b384461"
211
+ ],
212
+ "5144181ac27497fdfa9bdb5b8b799630": [
213
+ "5f074f91cc2ce2a4d5a62e6cce77c435e5dbf457",
214
+ "6219ef05f4a7b56419749e45a45143df8af44495",
215
+ "027f20642dd34e7914fc4fc4efa70fbb54bcecbb",
216
+ "112dc87e26450400941c6eaff60866bc19badc64"
217
+ ],
218
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
219
+ "76102a94ca13d9cdaf7d5a77262cddd4df9a806c"
220
+ ],
221
+ "c24dd733315066f7c7da3d578f954d8c": [
222
+ "829bf87b541238bb7579a05d18d4f7f1d1f98af1"
223
+ ],
224
+ "ada4be0df4d7d600d6729eb4a621f99a": [
225
+ "f1de6498f43789e1c27150b3ae1f9b5bfc051775"
226
+ ]
227
+ },
228
+ "tv_stand": {
229
+ "0e3c4946251b437ca90e5fe70efdea5b": [
230
+ "a77e6006efcacc637e1c2a49e72232ee0f435e35",
231
+ "212778c0c265e0db358ad7c8c1fa9a4bcfe41bd7"
232
+ ],
233
+ "cde6b48ed870286595c1455af7aff8bd": [
234
+ "ba231dd136e3bc77fb04ade17235b923aa7b2f07"
235
+ ],
236
+ "5144181ac27497fdfa9bdb5b8b799630": [
237
+ "056be15536045e9a6a94b9b93cff62f72d43c326",
238
+ "1aebe7d2500bfbcb0c6ec787a98a2b3701099ed7"
239
+ ],
240
+ "94d192237d5fe1b065910cb51d8ee711": [
241
+ "63b7d75b724e079aab99030084f0eae1b43b7498"
242
+ ]
243
+ },
244
+ "washer": {
245
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
246
+ "6cd2dc2611c27f758c972b4874efad8c8cbd5d29",
247
+ "912447108f21083d877aab6653742fceccc6ce7d",
248
+ "00b285ed70673826ccb6941929a64abfaf5f9239",
249
+ "eccf8ba37c804b9067e675f09eec2c13e951f61c",
250
+ "031024e00569909d466d87df9fead90355ba29e5",
251
+ "faefe63ba896c06920aa6d23b05ab83f3d6d37ea",
252
+ "2c2b2914fc526f6e8bbe65511ecf58bba0027ec0",
253
+ "4bc6f883a374400355bcf95e611a4e8f8b950ed5"
254
+ ]
255
+ },
256
+ "chestofdrawers": {
257
+ "5144181ac27497fdfa9bdb5b8b799630": [
258
+ "762665bc7a958151874c15edfc2711b161376678",
259
+ "d96b2b9537c7c721d5a79b375aefbfacecd04f65"
260
+ ],
261
+ "47259e6c2fba7c74a9b725012e01ebba": [
262
+ "d621c3a39d9291f0943531204d68e705633986c9"
263
+ ],
264
+ "c24dd733315066f7c7da3d578f954d8c": [
265
+ "293c42d04758bffcc63070066170cbb1f09918cc",
266
+ "1c0bbc026e76c09885dc5c6f156a6c3dec605d10"
267
+ ],
268
+ "25bfc0f15836b69b830cf66b5217bed6": [
269
+ "807955fd4dcb59b67789b24f2e7bc167027c870a"
270
+ ],
271
+ "322c8717a498a3e832420518775f8ffc": [
272
+ "2cdb28938dff1f9ab13aee7630cea51f44f60952"
273
+ ],
274
+ "cde6b48ed870286595c1455af7aff8bd": [
275
+ "831e2cff446337af8791038161c7d32d96726b20",
276
+ "93767fc355a1afbfff79d54a25204069e0543d2b",
277
+ "11808e4bfc4534caf787fa15ba07bc2cbee95fdd"
278
+ ],
279
+ "3ab3e03b34bc406737d81bd5db0ee212": [
280
+ "4c067979055ae739365d340a1699036be5c136c7"
281
+ ],
282
+ "c502b67eb6d91d909ba398fa39bec60c": [
283
+ "848c2cc9428329445b3de7dc982469c583b16c1d"
284
+ ],
285
+ "7936dbd6c88dda2542d5509f6078a0a9": [
286
+ "09a0a4a031e33e738214b812f71dd838232df54c"
287
+ ],
288
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
289
+ "609530b858465d9795ac43ba435fdd5f12d95956",
290
+ "652403bb9a0b199ebdebe538a44bc897eab624f6"
291
+ ],
292
+ "ddc2ef1be48dc58fe68226818824b648": [
293
+ "bc59a49ebf99164d5ed88bd6eaff12ec4ed86d0a"
294
+ ]
295
+ },
296
+ "hangingcabinet": {
297
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
298
+ "B07B8NBQQP",
299
+ "dcef0d475b7fdea2530898215feafddac1fe9bcc",
300
+ "bb966a4f853df29b8b37b89e157aed4eb3936aec"
301
+ ],
302
+ "c502b67eb6d91d909ba398fa39bec60c": [
303
+ "be1b0ef31886b4b2b42cf0bf1b7df548917c9943"
304
+ ],
305
+ "5a8eac0760a558d4174437be478ec0aa": [
306
+ "3432a71596b6cd7e944b6f19cf6d713fe17fc8bd",
307
+ "4468c13bc98184bcc403027164ede52b178e5d20",
308
+ "644c2d1505103189b6ae49f5a58b97ed41202149"
309
+ ]
310
+ },
311
+ "kitchencabinet": {
312
+ "c502b67eb6d91d909ba398fa39bec60c": [
313
+ "3a1c604565d5fa1f411f3fb437a816094ae69122",
314
+ "88c87a19b5e883787b5707d90545e25360594822"
315
+ ],
316
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
317
+ "bc3c5c45d5a9126bff85470c7c48d4e2b7ebfd0d"
318
+ ],
319
+ "86f7cf811774c9dc1f8ac7ebefafd51c": [
320
+ "0c42ee5b635c8acdb6bf235ff9420d742a16fd30"
321
+ ],
322
+ "c24dd733315066f7c7da3d578f954d8c": [
323
+ "10444a1165c1745f960eb6183d24dd05b60e781f"
324
+ ],
325
+ "cde6b48ed870286595c1455af7aff8bd": [
326
+ "1a570da5ae0c4d1d51c94be81d7248f155a0dcef"
327
+ ]
328
+ }
329
+ }
retrieval/retrieval_hash_no_handles.json ADDED
@@ -0,0 +1,722 @@
1
+ {
2
+ "Table": {
3
+ "c502b67eb6d91d909ba398fa39bec60c": [
4
+ "25160",
5
+ "32761",
6
+ "28594",
7
+ "25959",
8
+ "34178",
9
+ "23724",
10
+ "21473",
11
+ "26545",
12
+ "26692",
13
+ "27478",
14
+ "19384",
15
+ "26800",
16
+ "25756",
17
+ "23807",
18
+ "26886",
19
+ "31249",
20
+ "21718",
21
+ "23372",
22
+ "32601"
23
+ ],
24
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
25
+ "32566",
26
+ "26525",
27
+ "29557",
28
+ "24644",
29
+ "20453",
30
+ "26806",
31
+ "26670",
32
+ "20279",
33
+ "20411",
34
+ "20555",
35
+ "27189",
36
+ "30869",
37
+ "19855",
38
+ "24931",
39
+ "22339",
40
+ "28668",
41
+ "26073",
42
+ "26652",
43
+ "21467",
44
+ "20985",
45
+ "22241",
46
+ "22508",
47
+ "29921",
48
+ "27044",
49
+ "23511"
50
+ ],
51
+ "5144181ac27497fdfa9bdb5b8b799630": [
52
+ "27619",
53
+ "26875",
54
+ "27267",
55
+ "19825",
56
+ "33914",
57
+ "23472",
58
+ "22692",
59
+ "32174",
60
+ "19179",
61
+ "29133",
62
+ "31601",
63
+ "32086",
64
+ "30238",
65
+ "25308",
66
+ "32746",
67
+ "28164",
68
+ "32932",
69
+ "34617",
70
+ "23782"
71
+ ],
72
+ "dd3473c941b94dd6654e9f89bb51cac9": [
73
+ "30666"
74
+ ],
75
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
76
+ "33930",
77
+ "32052",
78
+ "32213"
79
+ ],
80
+ "8999452509fdfa98335c5ba44ed05498": [
81
+ "26503",
82
+ "20043",
83
+ "29525",
84
+ "26657",
85
+ "30341"
86
+ ],
87
+ "cde6b48ed870286595c1455af7aff8bd": [
88
+ "19836",
89
+ "22301",
90
+ "30663",
91
+ "25144",
92
+ "25493",
93
+ "32354",
94
+ "22433",
95
+ "32324",
96
+ "30857",
97
+ "26387",
98
+ "32259",
99
+ "33457",
100
+ "20745",
101
+ "33116",
102
+ "25913"
103
+ ],
104
+ "0e3c4946251b437ca90e5fe70efdea5b": [
105
+ "19898",
106
+ "26608",
107
+ "22367",
108
+ "24152"
109
+ ],
110
+ "c24dd733315066f7c7da3d578f954d8c": [
111
+ "34610",
112
+ "33810"
113
+ ]
114
+ },
115
+ "Dishwasher": {
116
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
117
+ "12480",
118
+ "12579",
119
+ "12542",
120
+ "12085",
121
+ "12559",
122
+ "12092",
123
+ "12580",
124
+ "12565",
125
+ "12612",
126
+ "11622",
127
+ "12530",
128
+ "12071",
129
+ "12654",
130
+ "12259",
131
+ "12558",
132
+ "11700",
133
+ "12553",
134
+ "12414",
135
+ "12543",
136
+ "12561",
137
+ "12590",
138
+ "12540",
139
+ "12621",
140
+ "12531",
141
+ "12614",
142
+ "12560",
143
+ "12428",
144
+ "12606",
145
+ "12552",
146
+ "12592",
147
+ "12617",
148
+ "12583",
149
+ "12605",
150
+ "12596",
151
+ "11661"
152
+ ],
153
+ "cde6b48ed870286595c1455af7aff8bd": [
154
+ "12349",
155
+ "11826"
156
+ ],
157
+ "c502b67eb6d91d909ba398fa39bec60c": [
158
+ "12065",
159
+ "12484"
160
+ ],
161
+ "5144181ac27497fdfa9bdb5b8b799630": [
162
+ "12597"
163
+ ]
164
+ },
165
+ "WashingMachine": {
166
+ "5144181ac27497fdfa9bdb5b8b799630": [
167
+ "100283",
168
+ "103369",
169
+ "103361",
170
+ "100282",
171
+ "103425",
172
+ "103480",
173
+ "103778"
174
+ ],
175
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
176
+ "103490",
177
+ "103775"
178
+ ],
179
+ "cde6b48ed870286595c1455af7aff8bd": [
180
+ "103452",
181
+ "103518",
182
+ "103521"
183
+ ],
184
+ "c502b67eb6d91d909ba398fa39bec60c": [
185
+ "103508",
186
+ "103776"
187
+ ],
188
+ "c24dd733315066f7c7da3d578f954d8c": [
189
+ "103781",
190
+ "103528"
191
+ ]
192
+ },
193
+ "Microwave": {
194
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
195
+ "7304",
196
+ "7236",
197
+ "7221",
198
+ "7263",
199
+ "7292",
200
+ "7306",
201
+ "7320",
202
+ "7310"
203
+ ],
204
+ "cde6b48ed870286595c1455af7aff8bd": [
205
+ "7273",
206
+ "7167",
207
+ "7296",
208
+ "7349"
209
+ ],
210
+ "5144181ac27497fdfa9bdb5b8b799630": [
211
+ "7366",
212
+ "7119"
213
+ ],
214
+ "c502b67eb6d91d909ba398fa39bec60c": [
215
+ "7128"
216
+ ]
217
+ },
218
+ "Oven": {
219
+ "ddc2ef1be48dc58fe68226818824b648": [
220
+ "101946",
221
+ "101930"
222
+ ],
223
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
224
+ "7130",
225
+ "101971",
226
+ "7290",
227
+ "7138"
228
+ ],
229
+ "c502b67eb6d91d909ba398fa39bec60c": [
230
+ "101773",
231
+ "102055"
232
+ ],
233
+ "5144181ac27497fdfa9bdb5b8b799630": [
234
+ "7220",
235
+ "7187"
236
+ ],
237
+ "0e3c4946251b437ca90e5fe70efdea5b": [
238
+ "101921",
239
+ "101943",
240
+ "101917",
241
+ "101947"
242
+ ],
243
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
244
+ "102018",
245
+ "102044",
246
+ "102060",
247
+ "101931"
248
+ ],
249
+ "538827adb8c90adf1322121db4e66fef": [
250
+ "101924"
251
+ ],
252
+ "c24dd733315066f7c7da3d578f954d8c": [
253
+ "101940",
254
+ "7347",
255
+ "102019",
256
+ "7332",
257
+ "7120",
258
+ "7201",
259
+ "101909",
260
+ "7179"
261
+ ],
262
+ "dd3473c941b94dd6654e9f89bb51cac9": [
263
+ "101808",
264
+ "102001"
265
+ ],
266
+ "ebb9c0168d323bca8e92227bdaa7a788": [
267
+ "101908"
268
+ ]
269
+ },
270
+ "Refrigerator": {
271
+ "5144181ac27497fdfa9bdb5b8b799630": [
272
+ "12066",
273
+ "10036",
274
+ "11712",
275
+ "11178",
276
+ "10867",
277
+ "10900",
278
+ "12059",
279
+ "11299",
280
+ "10620",
281
+ "10685",
282
+ "11846",
283
+ "12043",
284
+ "12248",
285
+ "10347",
286
+ "10489",
287
+ "10751",
288
+ "12036",
289
+ "10143",
290
+ "10655",
291
+ "10612",
292
+ "10068",
293
+ "10638",
294
+ "12050",
295
+ "11231",
296
+ "10586",
297
+ "11550",
298
+ "11304",
299
+ "10627"
300
+ ],
301
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
302
+ "10849",
303
+ "12054",
304
+ "12038",
305
+ "10373",
306
+ "10944",
307
+ "12252",
308
+ "10905",
309
+ "11260",
310
+ "12055",
311
+ "10144",
312
+ "10797",
313
+ "11211",
314
+ "12250",
315
+ "12042",
316
+ "12249"
317
+ ],
318
+ "cde6b48ed870286595c1455af7aff8bd": [
319
+ "11709"
320
+ ]
321
+ },
322
+ "Safe": {
323
+ "5144181ac27497fdfa9bdb5b8b799630": [
324
+ "102423",
325
+ "101583",
326
+ "102318",
327
+ "101612",
328
+ "102301",
329
+ "101591",
330
+ "101594",
331
+ "102380",
332
+ "101613",
333
+ "101599",
334
+ "101619",
335
+ "101605",
336
+ "102389",
337
+ "101363",
338
+ "102387"
339
+ ],
340
+ "cde6b48ed870286595c1455af7aff8bd": [
341
+ "102311",
342
+ "101579",
343
+ "101611",
344
+ "102384",
345
+ "102309",
346
+ "101623",
347
+ "101593",
348
+ "101603",
349
+ "102316",
350
+ "101584",
351
+ "102381"
352
+ ],
353
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
354
+ "101604",
355
+ "101564",
356
+ "102418"
357
+ ]
358
+ },
359
+ "StorageFurniture": {
360
+ "cde6b48ed870286595c1455af7aff8bd": [
361
+ "47711",
362
+ "45841",
363
+ "45427",
364
+ "46653",
365
+ "48740",
366
+ "46544",
367
+ "44962",
368
+ "48517",
369
+ "47296",
370
+ "46230",
371
+ "45213",
372
+ "44853",
373
+ "40453",
374
+ "45632",
375
+ "45949",
376
+ "48855",
377
+ "44781",
378
+ "48797",
379
+ "45132",
380
+ "46981",
381
+ "46440",
382
+ "48491",
383
+ "45661",
384
+ "45503",
385
+ "47853",
386
+ "47252",
387
+ "48010",
388
+ "46879",
389
+ "48876",
390
+ "46537",
391
+ "45622",
392
+ "46641",
393
+ "46334",
394
+ "48253",
395
+ "48063",
396
+ "47233",
397
+ "45940",
398
+ "45290",
399
+ "46874",
400
+ "47438",
401
+ "45135",
402
+ "46130",
403
+ "46123",
404
+ "47088",
405
+ "45756",
406
+ "47944",
407
+ "47178"
408
+ ],
409
+ "3ba4ffe16dfe637510ed1c3676ec6cb0": [
410
+ "47168",
411
+ "45176",
412
+ "45908",
413
+ "45249",
414
+ "46127",
415
+ "45633",
416
+ "45936",
417
+ "47388",
418
+ "45091",
419
+ "45385",
420
+ "48721",
421
+ "45667",
422
+ "41004",
423
+ "41529",
424
+ "45671",
425
+ "45130",
426
+ "45504",
427
+ "45638",
428
+ "45087",
429
+ "45164",
430
+ "48036",
431
+ "47651",
432
+ "45950",
433
+ "45623",
434
+ "46132",
435
+ "45693",
436
+ "45783",
437
+ "45007",
438
+ "47180",
439
+ "45690",
440
+ "48271",
441
+ "45267",
442
+ "46092",
443
+ "45134",
444
+ "47021",
445
+ "46744",
446
+ "47187",
447
+ "48686",
448
+ "46556",
449
+ "46408",
450
+ "46616",
451
+ "46439",
452
+ "45448",
453
+ "47963",
454
+ "45212",
455
+ "45413",
456
+ "45645",
457
+ "38516",
458
+ "45822",
459
+ "47315",
460
+ "45916",
461
+ "45173",
462
+ "47686",
463
+ "45699",
464
+ "45244",
465
+ "46889",
466
+ "46906",
467
+ "46430",
468
+ "45717",
469
+ "45937",
470
+ "45203",
471
+ "45372",
472
+ "47133",
473
+ "46966",
474
+ "45419",
475
+ "45779",
476
+ "45297",
477
+ "46922",
478
+ "46944",
479
+ "46417",
480
+ "45850",
481
+ "45620",
482
+ "46044",
483
+ "45248",
484
+ "48452",
485
+ "47817",
486
+ "45964",
487
+ "46029",
488
+ "46179",
489
+ "46787",
490
+ "41452",
491
+ "45166",
492
+ "48167",
493
+ "48243",
494
+ "45177",
495
+ "45961",
496
+ "47632",
497
+ "46033",
498
+ "47391",
499
+ "47316",
500
+ "41086",
501
+ "45384",
502
+ "45526",
503
+ "45915",
504
+ "47281",
505
+ "45415",
506
+ "45606",
507
+ "47742",
508
+ "46401",
509
+ "45323",
510
+ "45178",
511
+ "47514",
512
+ "46117",
513
+ "46197",
514
+ "48467",
515
+ "46452",
516
+ "48519",
517
+ "45855",
518
+ "47729",
519
+ "46057",
520
+ "45600",
521
+ "35059",
522
+ "45691",
523
+ "45524",
524
+ "48490",
525
+ "47182",
526
+ "45910",
527
+ "48413",
528
+ "45247",
529
+ "48023",
530
+ "48746",
531
+ "47419",
532
+ "45621",
533
+ "45403",
534
+ "46427",
535
+ "45443",
536
+ "45922",
537
+ "47601",
538
+ "45853",
539
+ "45516",
540
+ "46107"
541
+ ],
542
+ "8fae315ff2d1ea185b4ed6d0d092ae0e": [
543
+ "45271",
544
+ "46199",
545
+ "47669",
546
+ "49188",
547
+ "46145",
548
+ "47926",
549
+ "45948",
550
+ "47290",
551
+ "40417",
552
+ "48623",
553
+ "47235",
554
+ "49062",
555
+ "47648",
556
+ "45612"
557
+ ],
558
+ "1e8c6b47706f002757c3370366001f06": [
559
+ "47466"
560
+ ],
561
+ "c502b67eb6d91d909ba398fa39bec60c": [
562
+ "45092",
563
+ "46443",
564
+ "45696",
565
+ "45189",
566
+ "46108",
567
+ "48356",
568
+ "45238",
569
+ "41003",
570
+ "45642",
571
+ "49025",
572
+ "46955",
573
+ "45749",
574
+ "45801",
575
+ "45374",
576
+ "46084",
577
+ "47227",
578
+ "47578",
579
+ "45243",
580
+ "45636",
581
+ "45387",
582
+ "41083",
583
+ "48258",
584
+ "41085",
585
+ "46060",
586
+ "44817",
587
+ "46598",
588
+ "46002",
589
+ "48169",
590
+ "46762",
591
+ "45710",
592
+ "48018",
593
+ "45262",
594
+ "47207",
595
+ "46856",
596
+ "46466",
597
+ "47595",
598
+ "47701",
599
+ "45332",
600
+ "49140",
601
+ "45219",
602
+ "48263",
603
+ "45677",
604
+ "47976",
605
+ "46549",
606
+ "45194",
607
+ "45759",
608
+ "46893",
609
+ "46120",
610
+ "47089"
611
+ ],
612
+ "5144181ac27497fdfa9bdb5b8b799630": [
613
+ "45575",
614
+ "45594",
615
+ "45235",
616
+ "46655",
617
+ "49182",
618
+ "45687",
619
+ "46403",
620
+ "46768",
621
+ "46896",
622
+ "45001",
623
+ "47613",
624
+ "47254",
625
+ "46180",
626
+ "47024",
627
+ "46825",
628
+ "48379",
629
+ "49132",
630
+ "48878",
631
+ "47577",
632
+ "47565",
633
+ "45573",
634
+ "44826",
635
+ "46847",
636
+ "46732",
637
+ "45168",
638
+ "46277",
639
+ "47238",
640
+ "45746",
641
+ "47808",
642
+ "45662",
643
+ "48381",
644
+ "45963",
645
+ "45354",
646
+ "45676",
647
+ "47278",
648
+ "47529",
649
+ "46437",
650
+ "45378",
651
+ "46563",
652
+ "47570",
653
+ "45444",
654
+ "48700",
655
+ "45780",
656
+ "47099",
657
+ "46490",
658
+ "45523",
659
+ "47747",
660
+ "46045",
661
+ "45305",
662
+ "40147",
663
+ "49133",
664
+ "46700",
665
+ "46236",
666
+ "45505",
667
+ "48859",
668
+ "46166",
669
+ "46456",
670
+ "45162",
671
+ "45776",
672
+ "45420",
673
+ "46481",
674
+ "45767",
675
+ "45423",
676
+ "45790",
677
+ "49038",
678
+ "45670",
679
+ "47954",
680
+ "48479",
681
+ "46019",
682
+ "46801",
683
+ "45984",
684
+ "45159",
685
+ "49042",
686
+ "45784",
687
+ "48177",
688
+ "46859",
689
+ "46741",
690
+ "46134",
691
+ "45694",
692
+ "46480",
693
+ "45689",
694
+ "46037",
695
+ "47185",
696
+ "45397",
697
+ "48013",
698
+ "46699",
699
+ "48513",
700
+ "45146",
701
+ "45463",
702
+ "41510",
703
+ "45747"
704
+ ],
705
+ "8999452509fdfa98335c5ba44ed05498": [
706
+ "46172",
707
+ "46014",
708
+ "45261",
709
+ "46109",
710
+ "46380"
711
+ ],
712
+ "0e3c4946251b437ca90e5fe70efdea5b": [
713
+ "48497"
714
+ ],
715
+ "c24dd733315066f7c7da3d578f954d8c": [
716
+ "48051"
717
+ ],
718
+ "87036528afd9c9dd03a6ab72efec0136": [
719
+ "45725"
720
+ ]
721
+ }
722
+ }
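The table above maps what look like Weisfeiler-Lehman hashes of part-connectivity graphs (the same construction used by `get_hash` in `scripts/graph_pred/eval.py` below) to lists of candidate object IDs, nested one level below the top of the file. A minimal lookup sketch under that assumption; the flattening over the outer keys and the toy tree are illustrative only and are not taken from `retrieval/obj_retrieval.py`:

```python
import json
import networkx as nx

def part_graph_hash(diffuse_tree):
    # Directed child -> parent edges; the WL hash depends only on topology, not on ids.
    G = nx.DiGraph()
    for node in diffuse_tree:
        G.add_node(node["id"])
        if node["parent"] != -1:
            G.add_edge(node["id"], node["parent"])
    return nx.weisfeiler_lehman_graph_hash(G)

with open("retrieval/retrieval_hash_no_handles.json") as f:
    hashbook = json.load(f)

# Assumption: every top-level value is a {hash: [object ids]} table; flatten for lookup.
flat = {h: ids for inner in hashbook.values() for h, ids in inner.items()}

toy_tree = [{"id": 0, "parent": -1}, {"id": 1, "parent": 0}]  # a base with one child part
print(flat.get(part_graph_hash(toy_tree), []))
```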
scripts/graph_pred/api.py ADDED
@@ -0,0 +1,210 @@
1
+ import os, sys
2
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
3
+ import re
4
+ import json
5
+ import base64
6
+ import argparse
7
+ from PIL import Image
8
+ from io import BytesIO
9
+ from openai import AzureOpenAI
10
+ from scripts.graph_pred.prompt_workflow_new import messages
11
+ import json_repair
12
+ # Initialize the OpenAI client
13
+
14
+ endpoint = os.environ.get("ENDPOINT")
15
+ api_key = os.environ.get("API_KEY")
16
+ api_version = os.environ.get("API_VERSION")
17
+ model_name = os.environ.get("MODEL_NAME")
18
+ client = AzureOpenAI(
19
+ azure_endpoint=endpoint,
20
+ api_key=api_key,
21
+ api_version=api_version,
22
+ )
23
+
24
+
25
+ def encode_image(image_path: str, center_crop=False):
26
+ """Resize and encode the image as base64"""
27
+ # load the image
28
+ image = Image.open(image_path)
29
+
30
+ # resize the image to 224x224
31
+ if center_crop: # (resize to 256x256 and then center crop to 224x224)
32
+ image = image.resize((256, 256))
33
+ width, height = image.size
34
+ left = (width - 224) / 2
35
+ top = (height - 224) / 2
36
+ right = (width + 224) / 2
37
+ bottom = (height + 224) / 2
38
+ image = image.crop((left, top, right, bottom))
39
+ else:
40
+ image = image.resize((224, 224))
41
+
42
+ # convert the image to bytes
43
+ buffer = BytesIO()
44
+ image.save(buffer, format="PNG")
45
+ buffer.seek(0)
46
+ # encode the image as base64
47
+ encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
48
+ return encoded_image
49
+
50
+ def display_image(image_data):
51
+ """Display the image from the base64 encoded image data"""
52
+ img = Image.open(BytesIO(base64.b64decode(image_data)))
53
+ img.show()
54
+ img.close()
55
+
56
+
57
+ def convert_format(src):
58
+ '''Convert the JSON format from the response to a tree format'''
59
+ def _sort_nodes(tree):
60
+ num_nodes = len(tree)
61
+ sorted_tree = [dict() for _ in range(num_nodes)]
62
+ for node in tree:
63
+ sorted_tree[node["id"]] = node
64
+ return sorted_tree
65
+
66
+ def _traverse(node, parent_id, current_id):
67
+ for key, value in node.items():
68
+ node_id = current_id[0]
69
+ current_id[0] += 1
70
+
71
+ # Create the node
72
+ tree_node = {
73
+ "id": node_id,
74
+ "parent": parent_id,
75
+ "name": key,
76
+ "children": [],
77
+ }
78
+
79
+ # Traverse children if they exist
80
+ if isinstance(value, list):
81
+ for child in value:
82
+ child_id = _traverse(child, node_id, current_id)
83
+ tree_node["children"].append(child_id)
84
+
85
+ # Add this node to the tree
86
+ tree.append(tree_node)
87
+ return node_id
88
+
89
+ tree = []
90
+ current_id = [0]
91
+ _traverse(src, -1, current_id)
92
+ diffuse_tree = _sort_nodes(tree)
93
+ return diffuse_tree
94
+
95
+ def predict_graph_twomode(image_path, first_img_data=None, second_img_data=None, debug=False, center_crop=False):
96
+ '''Predict the part connectivity graph from the image'''
97
+ # Encode the image
98
+ if first_img_data is None or second_img_data is None:
99
+ first_img_data = encode_image(image_path, center_crop)
100
+ second_img_data = encode_image(image_path.replace('close', 'open'), center_crop)
101
+ # if debug:
102
+ # display_image(image_data) # for double checking the image
103
+ # breakpoint()
104
+ new_message = messages.copy()
105
+ new_message.append(
106
+ {
107
+ "role": "user",
108
+ "content": [
109
+ {
110
+ "type": "image_url",
111
+ "image_url": {"url": f"data:image/png;base64,{first_img_data}"},
112
+ },
113
+ {
114
+ "type": "image_url",
115
+ "image_url": {"url": f"data:image/png;base64,{second_img_data}"},
116
+ }
117
+ ],
118
+ },
119
+ )
120
+ # Get the completion from the model
121
+ completion = client.chat.completions.create(
122
+ model=model_name,
123
+ messages=new_message,
124
+ response_format={"type": "text"},
125
+ temperature=1,
126
+ max_tokens=4096,
127
+ top_p=1,
128
+ frequency_penalty=0,
129
+ presence_penalty=0,
130
+ )
131
+ print('processing the response...')
132
+
133
+ # Extract the response
134
+ content = completion.choices[0].message.content
135
+
136
+ src = json.loads(re.search(r"```json\n(.*?)\n```", content, re.DOTALL).group(1))
137
+ print(src)
138
+ # Convert the JSON format to tree format
139
+ diffuse_tree = convert_format(src)
140
+
141
+ return {"diffuse_tree": diffuse_tree, "original_response": content}
142
+
143
+ def save_response(save_path, response):
144
+ '''Save the response to a json file'''
145
+ with open(save_path, "w") as file:
146
+ json.dump(response, file, indent=4)
147
+
148
+
149
+
150
+ def gpt_infer_image_category(image1, image2):
151
+ system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."
152
+
153
+ text_prompt = (
154
+ "Given two images of an object, determine its category. "
155
+ "The category must be one of the following: Table, Dishwasher, StorageFurniture, "
156
+ "Refrigerator, WashingMachine, Microwave, Oven. "
157
+ "Output only the category name and nothing else. Do not include any other text."
158
+ )
159
+
160
+ content_user = [
161
+ {
162
+ "type": "text",
163
+ "text": text_prompt,
164
+ },
165
+ {
166
+ "type": "image_url",
167
+ "image_url": {"url": f"data:image/png;base64,{image1}"},
168
+ },
169
+ {
170
+ "type": "image_url",
171
+ "image_url": {"url": f"data:image/png;base64,{image2}"},
172
+ },
173
+ ]
174
+ payload = {
175
+ "messages": [
176
+ {"role": "system", "content": system_role},
177
+ {"role": "user", "content": content_user},
178
+ ],
179
+ "temperature": 0.1,
180
+ "max_tokens": 500,
181
+ "top_p": 0.1,
182
+ "frequency_penalty": 0,
183
+ "presence_penalty": 0,
184
+ "stop": None,
185
+ "model": model_name,
186
+ }
187
+ completion = client.chat.completions.create(**payload)
188
+ response = completion.choices[0].message.content
189
+ json_repair.loads(response) # best-effort parse of the reply; the parsed value is unused and the raw category string is returned
190
+
191
+ return response
192
+
193
+
194
+ if __name__ == "__main__":
195
+ parser = argparse.ArgumentParser(description="Predict the part connectivity graph from an image")
196
+ parser.add_argument("--img_path", type=str, required=True, help="path to the image")
197
+ parser.add_argument("--save_path", type=str, required=True, help="path to save the response")
198
+ parser.add_argument("--center_crop", action="store_true", help="whether to center crop the image to 224x224, otherwise resize to 224x224")
199
+ args = parser.parse_args()
200
+
201
+ try:
202
+ response = predict_graph_twomode(args.img_path, center_crop=args.center_crop)
+ save_response(args.save_path, response)
206
+ except Exception as e:
207
+ with open('openai_err.log', 'a') as f:
208
+ f.write('---------------------------\n')
209
+ f.write(f'{args.img_path}\n')
210
+ f.write(f'{e}\n')
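For reference, this is the shape that `convert_format` produces for the example tree used in the system prompt, traced by hand from the code above (ids are assigned in pre-order and the result is sorted by id):

```python
# Input in the GPT response format:
src = {
    "base": [
        {"door": [{"handle": []}]},
        {"drawer": [{"handle": []}]},
    ]
}

# convert_format(src) returns the flat, id-sorted diffuse_tree:
expected = [
    {"id": 0, "parent": -1, "name": "base",   "children": [1, 3]},
    {"id": 1, "parent": 0,  "name": "door",   "children": [2]},
    {"id": 2, "parent": 1,  "name": "handle", "children": []},
    {"id": 3, "parent": 0,  "name": "drawer", "children": [4]},
    {"id": 4, "parent": 3,  "name": "handle", "children": []},
]
```

Note that the Azure client is constructed at module import time, so ENDPOINT, API_KEY, API_VERSION and MODEL_NAME must be set in the environment before `convert_format` can be imported from this module.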
scripts/graph_pred/eval.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import networkx as nx
5
+ from tqdm import tqdm
6
+
7
+ def get_hash(file, key='diffuse_tree'):
8
+ tree = file[key]
9
+ G = nx.DiGraph()
10
+ for node in tree:
11
+ G.add_node(node['id'])
12
+ if node['parent'] != -1:
13
+ G.add_edge(node['id'], node['parent'])
14
+ hashcode = nx.weisfeiler_lehman_graph_hash(G)
15
+ return hashcode
16
+
17
+ if __name__ == "__main__":
18
+ '''Script to evaluate the accuracy of the generated graphs'''
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--exp_dir', type=str, required=True, help='path to the experiment directory')
22
+ parser.add_argument('--gt_data_root', type=str, required=True, help='root directory of the ground-truth data')
23
+ parser.add_argument('--gt_json_name', type=str, default='object.json', help='Path to the ground truth data')
24
+ args = parser.parse_args()
25
+
26
+ assert os.path.exists(args.exp_dir), "The experiment directory does not exist"
27
+ assert os.path.exists(args.gt_data_root), "The ground-truth data root does not exist"
28
+
29
+ exp_dir = args.exp_dir
30
+ gt_data_dir = args.gt_data_root
31
+
32
+ acc = 0
33
+ files = os.listdir(exp_dir)
34
+ files = sorted(files)
35
+ total = len(files)
36
+ wrong_files = []
37
+ for file in tqdm(files):
38
+ tokens = file.split('@')
39
+ gt_dir = f'{gt_data_dir}'
40
+ for token in tokens[:-1]:
41
+ gt_dir = os.path.join(gt_dir, token)
42
+ with open(os.path.join(gt_dir, args.gt_json_name)) as f:
43
+ gt = json.load(f)
44
+ # load json files
45
+ with open(os.path.join(exp_dir, file)) as f:
46
+ pred = json.load(f)
47
+ # get hash for the graph
48
+ pred_hash = get_hash(pred)
49
+ gt_hash = get_hash(gt)
50
+ # compare hash
51
+ if pred_hash == gt_hash:
52
+ acc += 1
53
+ else:
54
+ wrong_files.append(file)
55
+
56
+
57
+ with open(os.path.join(os.path.dirname(exp_dir), f'acc_{os.path.basename(exp_dir)}.json'), 'w') as f:
58
+ json.dump({'acc': acc/total, 'wrong_files': wrong_files}, f, indent=4)
59
+
60
+
61
+
62
+
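The comparison above reduces each part tree to a topology-only fingerprint, so node ids and orderings do not matter. A small standalone illustration (restating `get_hash` so the snippet runs on its own):

```python
import networkx as nx

def get_hash(file, key="diffuse_tree"):
    G = nx.DiGraph()
    for node in file[key]:
        G.add_node(node["id"])
        if node["parent"] != -1:
            G.add_edge(node["id"], node["parent"])
    return nx.weisfeiler_lehman_graph_hash(G)

chain_a = {"diffuse_tree": [{"id": 0, "parent": -1}, {"id": 1, "parent": 0}, {"id": 2, "parent": 1}]}
chain_b = {"diffuse_tree": [{"id": 2, "parent": -1}, {"id": 0, "parent": 2}, {"id": 1, "parent": 0}]}
star    = {"diffuse_tree": [{"id": 0, "parent": -1}, {"id": 1, "parent": 0}, {"id": 2, "parent": 0}]}

assert get_hash(chain_a) == get_hash(chain_b)  # same topology, different ids -> counted as correct
assert get_hash(chain_a) != get_hash(star)     # different topology -> counted as wrong
```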
scripts/graph_pred/prompt_workflow_new.py ADDED
@@ -0,0 +1,363 @@
1
+ import os
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import base64
5
+
6
+ def encode_image(image_path: str, center_crop=False):
7
+ """Resize and encode the image as base64"""
8
+ # load the image
9
+ image = Image.open(image_path)
10
+
11
+ # resize the image (224x224 when center-cropping, otherwise 512x512)
12
+ if center_crop: # (resize to 256x256 and then center crop to 224x224)
13
+ image = image.resize((256, 256))
14
+ width, height = image.size
15
+ left = (width - 224) / 2
16
+ top = (height - 224) / 2
17
+ right = (width + 224) / 2
18
+ bottom = (height + 224) / 2
19
+ image = image.crop((left, top, right, bottom))
20
+ else:
21
+ image = image.resize((512, 512))
22
+
23
+ # convert the image to bytes
24
+ buffer = BytesIO()
25
+ image.save(buffer, format="PNG")
26
+ buffer.seek(0)
27
+ # encode the image as base64
28
+ encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
29
+ return encoded_image
30
+
31
+ system_prompt = """
32
+ You are an expert in the recognition, structural parsing, and physical‑feasibility validation of articulated objects from image inputs.
33
+
34
+ You will be provided with two rendered images of the same object:
35
+ 1. A closed‑state image (all movable parts in their fully closed positions)
36
+ 2. An open‑state image (all movable parts in their fully opened positions)
37
+
38
+ Your task is to analyze the object's articulated structure and generate a connectivity graph describing the part relationships.
39
+
40
+ You must follow this workflow:
41
+
42
+ 1. Part Detection:
43
+ - Detect candidate parts with their coarse position in the **closed-state image**, with optional assistance from the **open-state image** to resolve ambiguous or occluded parts.
44
+ - Allowed part types: ['base', 'door', 'drawer', 'handle', 'knob', 'tray']
45
+ - Ignore small decorative things directly attached to the base.
46
+ - There must be exactly one "base"; a "tray" is only allowed if the object is a microwave (a microwave does not have to contain a "tray").
47
+
48
+ 2. Step-by-Step Reasoning:
49
+ 1. Part Listing: List all detected parts and their counts (do not infer attachment yet)
50
+ 2. Validation: Enforce structural rules:
51
+ - Exactly one base
52
+ - Each door or drawer may have at most two handles or knobs
53
+ - Every handle/knob must attach to a door or drawer
54
+ - Trays only appear in microwaves
55
+ 3. Attachment Inference: For each non-base part, infer its parent (e.g., "drawer_1 (attached to base)").
56
+ - Use the open-state image to assist in verifying attachments if the closed-state image is ambiguous.
57
+ 4. Connectivity Graph Construction: Output a JSON tree of part relationships. The base is the root.
58
+
59
+ Example output:
60
+ ```json
61
+ {
62
+ "base": [
63
+ { "door": [ { "handle": [] } ] },
64
+ { "drawer": [ { "handle": [] } ] }
65
+ ]
66
+ }
67
+ ```
68
+
69
+ Final Output: You **MUST** output a single JSON tree representing the part connectivity of the object. Use the open-state image to improve the accuracy and completeness of your structural interpretation, but base your parsing primarily on the closed-state image.
70
+ Note: You **MUST** give the final output in JSON format.
71
+ """
72
+
73
+ answer_sys = """
74
+ Great — please upload the two images of the same object:
75
+
76
+ One in the closed state (all movable parts fully closed)
77
+
78
+ One in the open state (all movable parts fully open)
79
+
80
+ Once I receive both, I’ll proceed with:
81
+
82
+ Part detection on the closed image
83
+
84
+ Structured reasoning and validation
85
+
86
+ Attachment inference
87
+
88
+ Final part connectivity graph in JSON format
89
+
90
+ Let me know when you're ready to continue!
91
+ """
92
+
93
+ answer_1 = """
94
+ Great! Let’s go through the articulated object parsing workflow based on the provided **closed-state** and **open-state** images.
95
+
96
+ ---
97
+
98
+ ### **1. Part Detection**
99
+
100
+ From the **closed-state image** (first image), with verification from the **open-state image** (second image):
101
+
102
+ - **Base**: 1 base structure (entire wooden cabinet)
103
+ - **Doors**: 4 doors
104
+ - 2 on the upper section (left & right)
105
+ - 2 on the lower section (left & right)
106
+ - **Drawers**: 4 drawers
107
+ - 3 in a row at the middle
108
+ - 1 larger drawer centered below the middle row
109
+ - **Handles**: 8 handles
110
+ - 1 per drawer
111
+ - 1 per door
112
+
113
+ ---
114
+
115
+ ### **2. Step-by-Step Reasoning**
116
+
117
+ #### 2.1 **Part Listing**
118
+
119
+ - `base`: 1
120
+ - `door`: 4
121
+ - `drawer`: 4
122
+ - `handle`: 8
123
+
124
+ #### 2.2 **Validation**
125
+
126
+ - ✅ Exactly one base
127
+ - ✅ Each door has 1 handle → 4 handles for 4 doors
128
+ - ✅ Each drawer has 1 handle → 4 handles for 4 drawers
129
+ - ✅ No trays → ✓ (not a microwave)
130
+ - ✅ Every handle is attached to a door or drawer (verified using open-state image)
131
+
132
+ All rules are satisfied.
133
+
134
+ #### 2.3 **Attachment Inference**
135
+
136
+ From closed/open-state comparison:
137
+
138
+ - Each **door** is directly attached to the **base**
139
+ - Each **drawer** is mounted to slide out of the **base**
140
+ - Each **handle** is on its respective **door** or **drawer**
141
+
142
+ ---
143
+
144
+ ### **3. Connectivity Graph Construction**
145
+
146
+ Final hierarchical structure:
147
+
148
+ ```json
149
+ {
150
+ "base": [
151
+ { "door_1": [ { "handle_1": [] } ] },
152
+ { "door_2": [ { "handle_2": [] } ] },
153
+ { "door_3": [ { "handle_3": [] } ] },
154
+ { "door_4": [ { "handle_4": [] } ] },
155
+ { "drawer_1": [ { "handle_5": [] } ] },
156
+ { "drawer_2": [ { "handle_6": [] } ] },
157
+ { "drawer_3": [ { "handle_7": [] } ] },
158
+ { "drawer_4": [ { "handle_8": [] } ] }
159
+ ]
160
+ }
161
+ ```
162
+ """
163
+
164
+ answer_2 = """
165
+ Thanks for the new image pair! Let's analyze the articulated structure of the **washing machine** based on the provided closed-state and open-state images.
166
+
167
+ ---
168
+
169
+ ### **1. Part Detection**
170
+
171
+ From the **closed-state image**, validated using the **open-state image**:
172
+
173
+ - **Base**: 1 main body (the washing machine structure)
174
+ - **Door**: 1 circular front-loading door
175
+ - **Handle**: Notice that there is no handle attached to the door
176
+ - **Knob**: 1 control knob on the top of the machine
177
+
178
+ ---
179
+
180
+ ### **2. Step-by-Step Reasoning**
181
+
182
+ #### 2.1 **Part Listing**
183
+
184
+ - `base`: 1
185
+ - `door`: 1
186
+ - `knob`: 1
187
+
188
+ #### 2.2 **Validation**
189
+
190
+ - ✅ Exactly one base
191
+ - ✅ No drawers, trays → Valid
192
+ - ✅ Trays not present (correct: not a microwave)
193
+
194
+ → **All structural constraints satisfied**
195
+
196
+ #### 2.3 **Attachment Inference**
197
+
198
+ - The **door** is front-mounted and clearly opens via a **hinge** on the left side (from viewer's perspective)
199
+ - The **knob** is directly mounted on the **base** (top section)
200
+ ---
201
+
202
+ ### **3. Connectivity Graph Construction**
203
+
204
+ ```json
205
+ {
206
+ "base": [
207
+ { "door": []},
208
+ { "knob": []}
209
+ ]
210
+ }
211
+ ```
212
+ """
213
+
214
+ answer_3 = """
215
+ Thanks! Let’s analyze the articulated structure of this oven using your provided closed and open state images.
216
+
217
+ ---
218
+
219
+ ### ✅ Articulated Object Parsing (Oven)
220
+
221
+ #### **1. Part Detection**
222
+
223
+ From the **closed-state** image (with **open-state** verification):
224
+
225
+ - **base**: 1 (the main oven body)
226
+ - **door**: 1 (the front oven door)
227
+ - **handle**: 1 (long horizontal handle on the door)
228
+
229
+ > 🔍 *The 6 small control knobs on top are **ignored** as instructed.*
230
+
231
+ ---
232
+
233
+ #### **2. Step-by-Step Reasoning**
234
+
235
+ **2.1 Part Listing:**
236
+
237
+ - `base`: 1
238
+ - `door`: 1
239
+ - `handle`: 1
240
+
241
+ **2.2 Validation:**
242
+
243
+ - ✅ Exactly one base
244
+ - ✅ Door has 1 handle
245
+ - ✅ Handle is attached to the door
246
+ - ✅ No trays (✓ valid for an oven)
247
+ - ✅ No knobs or extra handles beyond what's allowed
248
+
249
+ ✔ All constraints are satisfied.
250
+
251
+ **2.3 Attachment Inference:**
252
+
253
+ - The **door** is bottom-hinged and connected to the **base**
254
+ - The **handle** is directly mounted on the door
255
+
256
+ ---
257
+
258
+ #### **3. Final Connectivity Graph**
259
+
260
+ ```json
261
+ {
262
+ "base": [
263
+ { "door": [ { "handle": [] } ] }
264
+ ]
265
+ }
266
+ ```
267
+ """
268
+
269
+ answer_4 = """
270
+ Thanks! Let's now analyze the **wardrobe** based on the closed and open-state images you've provided.
271
+
272
+ ---
273
+
274
+ ### ✅ Articulated Object Parsing – Wardrobe
275
+
276
+ ---
277
+
278
+ ### **1. Part Detection**
279
+
280
+ - **Base**: 1 (entire wardrobe structure)
281
+ - **Doors**: 4 vertical panels (two left, two right)
282
+ - **Drawers**: 4 horizontal drawers at the bottom
283
+ - **Handles**: 8 in total
284
+ - 4 for doors (1 each)
285
+ - 4 for drawers (1 each)
286
+
287
+ ---
288
+
289
+ ### **2. Step-by-Step Reasoning**
290
+
291
+ #### **2.1 Part Listing**
292
+
293
+ - `base`: 1
294
+ - `door`: 4
295
+ - `drawer`: 4
296
+ - `handle`: 8
297
+
298
+ #### **2.2 Validation**
299
+
300
+ - ✅ Exactly one base
301
+ - ✅ Each door has 1 handle → valid
302
+ - ✅ Each drawer has 1 handle → valid
303
+ - ✅ No trays → ✓ valid (not a microwave)
304
+ - ✅ Every handle is attached to a door or drawer (confirmed via open-state image)
305
+
306
+ ✔ All structural rules are satisfied.
307
+
308
+ ---
309
+
310
+ #### **2.3 Attachment Inference**
311
+
312
+ - Each **door** is hinged to the **base**
313
+ - Each **drawer** slides out of the **base**
314
+ - Each **handle** is mounted on one **door** or **drawer**
315
+
316
+ ---
317
+
318
+ ### **3. Final Connectivity Graph**
319
+
320
+ ```json
321
+ {
322
+ "base": [
323
+ { "door_1": [ { "handle_1": [] } ] },
324
+ { "door_2": [ { "handle_2": [] } ] },
325
+ { "door_3": [ { "handle_3": [] } ] },
326
+ { "door_4": [ { "handle_4": [] } ] },
327
+ { "drawer_1": [ { "handle_5": [] } ] },
328
+ { "drawer_2": [ { "handle_6": [] } ] },
329
+ { "drawer_3": [ { "handle_7": [] } ] },
330
+ { "drawer_4": [ { "handle_8": [] } ] }
331
+ ]
332
+ }
333
+ ```
334
+ """
335
+
336
+ answer_all = [answer_sys, answer_1, answer_2, answer_3, answer_4]
337
+ messages = [
338
+ {"role": "user", "content": system_prompt},
339
+ {"role": "assistant", "content": answer_sys}
340
+ ]
341
+
342
+ for i in range(4):
343
+ root_path = './scripts/imgs_reference/'
344
+ close_path = os.path.join(root_path, f'close{i + 1}.png')
345
+ open_path = os.path.join(root_path, f'open{i + 1}.png')
346
+ close_img = encode_image(close_path, center_crop=False)
347
+ open_img = encode_image(open_path, center_crop=False)
348
+ messages.append(
349
+ {
350
+ "role": "user",
351
+ "content": [
352
+ {
353
+ "type": "image_url",
354
+ "image_url": {"url": f"data:image/png;base64,{close_img}"},
355
+ },
356
+ {
357
+ "type": "image_url",
358
+ "image_url": {"url": f"data:image/png;base64,{open_img}"},
359
+ }
360
+ ],
361
+ }
362
+ )
363
+ messages.append({"role": "assistant", "content": answer_all[i + 1]})
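The structural rules spelled out in `system_prompt` (exactly one base, handles and knobs only on doors or drawers, at most two per door or drawer, trays only on microwaves) can also be re-checked programmatically on the converted tree. A hypothetical helper, not part of this commit, assuming the flat, id-sorted diffuse_tree produced by `convert_format`:

```python
def check_graph_rules(diffuse_tree, is_microwave=False):
    """Return a list of rule violations (an empty list means the graph passes)."""
    errors = []
    kinds = [node["name"].split("_")[0] for node in diffuse_tree]
    if kinds.count("base") != 1:
        errors.append("there must be exactly one base")
    for node, kind in zip(diffuse_tree, kinds):
        if kind in ("handle", "knob"):
            parent_kind = kinds[node["parent"]] if node["parent"] >= 0 else None
            if parent_kind not in ("door", "drawer"):
                errors.append(f"{node['name']} must attach to a door or drawer")
        if kind in ("door", "drawer"):
            n_grips = sum(kinds[c] in ("handle", "knob") for c in node["children"])
            if n_grips > 2:
                errors.append(f"{node['name']} has more than two handles/knobs")
        if kind == "tray" and not is_microwave:
            errors.append("trays are only allowed on microwaves")
    return errors
```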
scripts/json2urdf.py ADDED
@@ -0,0 +1,160 @@
1
+ import json
2
+ from xml.etree.ElementTree import Element, SubElement, tostring, ElementTree
3
+ from xml.dom.minidom import parseString
4
+ import pybullet as p
5
+ import pybullet_data
6
+ import os
7
+ from PIL import Image # use Pillow to save rendered images
8
+ import numpy as np
9
+ import trimesh # for loading and processing 3D meshes
10
+ import imageio
11
+ import math
12
+
13
+ def degrees_to_radians(degrees):
14
+ """Convert an angle from degrees to radians."""
15
+ return degrees * math.pi / 180.0
16
+
17
+ def ply_to_obj(ply_filename, obj_filename, urdf_filename, part, json_data):
18
+ """Convert a PLY file to OBJ format using trimesh."""
19
+ base_path = '/'.join(urdf_filename.split('/')[:-1])
20
+ mesh = trimesh.load(os.path.join(base_path, ply_filename), force='mesh')
21
+ print(ply_filename, mesh.bounding_box.centroid)
22
+ # mesh.vertices -= (mesh.bounding_box.centroid)
23
+ # find base (parent == -1)
24
+ base_part_id = next((p['id'] for p in json_data['diffuse_tree'] if p['parent'] == -1), None)
25
+ if 'joint' in part.keys():
26
+ mesh.vertices -= (mesh.bounding_box.centroid + part['joint']['axis']['origin'])
27
+ while part['parent'] != base_part_id:
28
+ parent_part = next((p for p in json_data['diffuse_tree'] if p['id'] == part['parent']), None)
29
+ if parent_part is None:
30
+ break
31
+ mesh.vertices -= (parent_part['joint']['axis']['origin'])
32
+ part = parent_part
33
+ else:
34
+ mesh.vertices -= (mesh.bounding_box.centroid)
35
+ mesh.export(os.path.join(base_path, obj_filename))
36
+
37
+ def create_urdf_from_json(json_data, urdf_filename, parent_part=None):
38
+ robot = Element('robot', name='articulate_object')
39
+
40
+ def add_link(parent, part, urdf_filename, json_data, base=False):
41
+ link = SubElement(parent, 'link', name=f"link_{part['id']}")
42
+
43
+ # convert the PLY file to an OBJ file
44
+ ply_path = part['objs'][0]
45
+ obj_path = os.path.splitext(ply_path)[0] + '.obj'
46
+ # if not os.path.exists(obj_path):
47
+ ply_to_obj(ply_path, obj_path, urdf_filename, part, json_data)
48
+
49
+ visual = SubElement(link, 'visual')
50
+ origin = SubElement(visual, 'origin', xyz=" ".join(map(str, part['aabb']['center'])), rpy="0 0 0")
51
+ geometry = SubElement(visual, 'geometry')
52
+ mesh = SubElement(geometry, 'mesh', filename=obj_path) # use the converted OBJ path
53
+
54
+
55
+ def add_joint(parent, child_part, parent_part):
56
+ joint = SubElement(parent, 'joint', name=f"{parent_part['id']}_{child_part['id']}_joint", type=child_part['joint']['type'])
57
+
58
+ # for i in range(3):
59
+ # child_part['joint']['axis']['origin'][i] -= (child_part['aabb']['size'][i])
60
+ origin = SubElement(joint, 'origin', xyz=" ".join(map(str, child_part['joint']['axis']['origin'])), rpy="0 0 0")
61
+ # origin = SubElement(joint, 'origin', xyz=" 0 0 0 ", rpy="0 0 0")
62
+ axis = SubElement(joint, 'axis', xyz=" ".join(map(str, child_part['joint']['axis']['direction'])))
63
+ if child_part['joint']['type'] == 'revolute':
64
+ child_part['joint']['range'][0] = degrees_to_radians(child_part['joint']['range'][0])
65
+ child_part['joint']['range'][1] = degrees_to_radians(child_part['joint']['range'][1])
66
+
67
+ lower, upper = child_part['joint']['range']
68
+ if upper < lower:
69
+ lower, upper = upper, lower
70
+
71
+ limit = SubElement(
72
+ joint, 'limit',
73
+ lower=str(lower),
74
+ upper=str(upper),
75
+ effort="10",
76
+ velocity="1"
77
+ )
78
+ parent_element = SubElement(joint, 'parent', link=f"link_{parent_part['id']}")
79
+ child_element = SubElement(joint, 'child', link=f"link_{child_part['id']}")
80
+
81
+ base_part = json_data['diffuse_tree'][0]
82
+ add_link(robot, base_part, urdf_filename, json_data, base=True)
83
+
84
+ for part in json_data['diffuse_tree'][1:]:
85
+ base_part = next((p for p in json_data['diffuse_tree'] if p['parent'] == -1), None)
86
+ parent_part = next((p for p in json_data['diffuse_tree'] if p['id'] == part['parent']), None)
87
+ add_link(robot, part, urdf_filename, json_data)
88
+ if parent_part:
89
+ add_joint(robot, part, parent_part)
90
+
91
+ xmlstr = parseString(tostring(robot)).toprettyxml(indent=" ")
92
+
93
+ with open(urdf_filename, "w") as f:
94
+ f.write(xmlstr)
95
+
96
+
97
+ def pybullet_render(urdf_path, target_dir, num_frames, distance=3, fov=60):
98
+ physicsClient = p.connect(p.DIRECT)
99
+ p.setAdditionalSearchPath(pybullet_data.getDataPath())
100
+ try:
101
+ robot = p.loadURDF(urdf_path, [0, 0, 0])
102
+ except Exception as e:
103
+ print(e)
104
+ return
105
+ for i in range(-1, p.getNumJoints(robot)):
106
+ rgba = [np.random.uniform(0.2, 1.0), np.random.uniform(0.2, 1.0), np.random.uniform(0.2, 1.0), 1]
107
+ p.changeVisualShape(robot, linkIndex=i, rgbaColor=rgba)
108
+ p.resetBasePositionAndOrientation(robot, [0, 0, 0], [0, 0.7071, 0.7071, 0])
109
+ joint_info = []
110
+ for i in range(p.getNumJoints(robot)):
111
+ info = p.getJointInfo(robot, i)
112
+ if info[2] != p.JOINT_FIXED:
113
+ joint_info.append({
114
+ 'index': info[0],
115
+ 'type': info[2],
116
+ 'name': info[1].decode('utf-8'),
117
+ 'lower_limit': info[8],
118
+ 'upper_limit': info[9],
119
+ 'initial_position': p.getJointState(robot, info[0])[0]
120
+ })
121
+
122
+ joint_positions = {}
123
+ for joint in joint_info:
124
+ start = joint['lower_limit']
125
+ end = joint['upper_limit']
126
+ joint_positions[joint['index']] = np.concatenate((np.linspace(start, end, num_frames), np.linspace(end, start, num_frames)))
127
+
128
+ gif_frames = []
129
+ for frame in range(num_frames*2):
130
+ # import pdb; pdb.set_trace()
131
+ for joint in joint_info:
132
+ p.resetJointState(robot, joint['index'], joint_positions[joint['index']][frame])
133
+ joint_state = p.getJointState(robot, joint['index'])
134
+ p.stepSimulation()
135
+ viewMatrix = p.computeViewMatrixFromYawPitchRoll(
136
+ cameraTargetPosition=[0, 0, 0],
137
+ distance=distance, # use the function argument rather than a hard-coded 3.0
138
+ yaw=-150,
139
+ pitch=-10,
140
+ roll=0,
141
+ upAxisIndex=2
142
+ )
143
+ projectionMatrix=p.computeProjectionMatrixFOV(
144
+ fov=fov, ##60
145
+ aspect=1.0,
146
+ nearVal=0.1, farVal=100)
147
+
148
+ width, height, rgbPixels, depthBuffer, segMask = p.getCameraImage(
149
+ width=1024, height=1024, viewMatrix=viewMatrix,
150
+ projectionMatrix=projectionMatrix,
151
+ renderer=p.ER_BULLET_HARDWARE_OPENGL)
152
+
153
+ #get rgba image
154
+ rgba_image = np.reshape(rgbPixels, (height, width, 4))
155
+ # rgba_image[np.all(rgba_image[:, :, :3] == 255, axis=-1)] = [0, 0, 0, 0]
156
+ gif_frames.append(rgba_image[:, :, :3])
157
+
158
+ p.disconnect()
159
+ imageio.mimsave(f'{target_dir}/animation.gif', gif_frames, fps=8, loop=0)
160
+
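For reference, a minimal example of the `object.json` structure that `create_urdf_from_json` and `ply_to_obj` read above; the field names are taken from the code, while the concrete numbers and mesh paths are placeholders:

```python
example = {
    "diffuse_tree": [
        {   # base part: parent == -1, no joint
            "id": 0, "parent": -1,
            "objs": ["plys/part_0.ply"],
            "aabb": {"center": [0.0, 0.0, 0.0]},
        },
        {   # a revolute door attached to the base
            "id": 1, "parent": 0,
            "objs": ["plys/part_1.ply"],
            "aabb": {"center": [0.0, 0.45, 0.0]},
            "joint": {
                "type": "revolute",
                "axis": {"origin": [0.45, 0.45, 0.0], "direction": [0.0, 0.0, 1.0]},
                "range": [0, 90],  # degrees; add_joint converts revolute limits to radians
            },
        },
    ]
}
# create_urdf_from_json(example, "out/object.urdf")  # expects the referenced .ply files to exist
# pybullet_render("out/object.urdf", "out", num_frames=20)
```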
scripts/mesh_retrieval/retrieve.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
4
+ import json
5
+ import argparse
6
+ import numpy as np
7
+ from retrieval.obj_retrieval import find_obj_candidates, pick_and_rescale_parts
8
+ import trimesh
9
+ import shutil
10
+
11
+ def _retrieve_part_meshes(info_dict, save_dir, gt_data_root):
12
+ mesh_save_dir = os.path.join(save_dir, "plys")
13
+ obj_save_dir = os.path.join(save_dir, "objs")
14
+ os.makedirs(mesh_save_dir, exist_ok=True)
15
+ os.makedirs(obj_save_dir, exist_ok=True)
16
+ print(save_dir)
17
+ if os.path.exists(os.path.join(save_dir, "object.ply")):
18
+ return
19
+ HASHBOOK_PATH = "retrieval/retrieval_hash_no_handles.json"
20
+
21
+ obj_candidates = find_obj_candidates(
22
+ info_dict,
23
+ gt_data_root,
24
+ HASHBOOK_PATH,
25
+ gt_file_name="object.json",
26
+ num_states=5,
27
+ metric_num_samples=4096,
28
+ keep_top=3,
29
+ )
30
+
31
+ retrieved_mesh_specs = pick_and_rescale_parts(
32
+ info_dict, obj_candidates, gt_data_root, gt_file_name="object.json"
33
+ )
34
+
35
+ scene = trimesh.Scene()
36
+ for i, mesh_spec in enumerate(retrieved_mesh_specs):
37
+ part_spec = info_dict["diffuse_tree"][i]
38
+ current_part_meshes = []
39
+ file_paths = []
40
+ for file in mesh_spec["files"]:
41
+ file = os.path.join(mesh_spec["dir"], file).replace("ply", "obj")
42
+ file_paths.append(file)
43
+ m = trimesh.load(file, force="mesh")
44
+ current_part_meshes.append(m)
45
+
46
+ if not current_part_meshes:
47
+ continue
48
+
49
+ bounds = np.array([m.bounds for m in current_part_meshes])
50
+ min_extents = bounds[:, 0, :].min(axis=0)
51
+ max_extents = bounds[:, 1, :].max(axis=0)
52
+ group_centroid = (min_extents + max_extents) / 2.0
53
+
54
+ transformation = trimesh.transformations.compose_matrix(
55
+ scale=mesh_spec["scale_factor"],
56
+ angles=[0, 0, np.radians(90) if mesh_spec["z_rotate_90"] else 0],
57
+ translate=part_spec["aabb"]["center"],
58
+ )
59
+
60
+ part_scene = trimesh.Scene()
61
+ for mesh in current_part_meshes:
62
+ mesh.vertices -= group_centroid
63
+ mesh.apply_transform(transformation)
64
+ part_scene.add_geometry(mesh)
65
+ scene.add_geometry(mesh)
66
+
67
+ obj_path = os.path.join(obj_save_dir, f"part_{i}/part_{i}.obj")
68
+ os.makedirs(os.path.dirname(obj_path), exist_ok=True)
69
+ part_scene.export(obj_path, include_texture=True)
70
+ info_dict["diffuse_tree"][i]["objs"] = [f"objs/part_{i}/part_{i}.obj"]
71
+
72
+ scene.export(os.path.join(save_dir, "object.ply"))
73
+ del mesh, scene
74
+ return info_dict
75
+
76
+ def main(args):
77
+ with open(os.path.join(args.src_dir, args.json_name), "r") as f:
78
+ info_dict = json.load(f)
79
+
80
+ if 'meta' not in info_dict.keys():
81
+ info_dict['meta'] = {
82
+ 'obj_cat': 'StroageFurniture'
83
+ }
84
+
85
+ updated_json = _retrieve_part_meshes(info_dict, args.src_dir, args.gt_data_root)
86
+
87
+ if updated_json is not None:
88
+ with open(os.path.join(args.src_dir, args.json_name), "w") as f:
89
+ json.dump(updated_json, f)
90
+
91
+ if __name__ == '__main__':
92
+ parser = argparse.ArgumentParser()
93
+ parser.add_argument('--src_dir', type=str, required=True, help='path to the directory containing object.json')
94
+ parser.add_argument('--json_name', type=str, default='object.json', help='name of the json file')
95
+ parser.add_argument('--gt_data_root', type=str, default='./', help='path to the ground truth data')
96
+ args = parser.parse_args()
97
+ main(args)
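A typical invocation, either from the command line or by reusing `_retrieve_part_meshes` directly; the paths below are placeholders for an actual result folder and ground-truth root:

```python
# python scripts/mesh_retrieval/retrieve.py --src_dir <result_dir> --json_name object.json --gt_data_root <gt_root>
import json
import os
from scripts.mesh_retrieval.retrieve import _retrieve_part_meshes  # assumes the repo root is on sys.path

src_dir = "results_demo/0"   # placeholder: folder containing object.json
gt_data_root = "./data"      # placeholder: root of the ground-truth meshes

with open(os.path.join(src_dir, "object.json")) as f:
    info_dict = json.load(f)
info_dict.setdefault("meta", {"obj_cat": "StorageFurniture"})  # category label is an assumption

updated = _retrieve_part_meshes(info_dict, src_dir, gt_data_root)
if updated is not None:
    with open(os.path.join(src_dir, "object.json"), "w") as f:
        json.dump(updated, f, indent=2)
```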
scripts/mesh_retrieval/retrieve_gpt.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+ import subprocess
3
+ import argparse
4
+ from tqdm.contrib.concurrent import process_map
5
+ from functools import partial
6
+
7
+ def run_retrieve(src_dir, json_name, data_root):
8
+ fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
9
+ try:
10
+ subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
11
+ except subprocess.CalledProcessError as e:
12
+ print(f'Error from run_retrieve: {src_dir}')
13
+ print(f'Error: {e}')
14
+ return ' '.join(fn_call)
15
+
16
+ if __name__ == '__main__':
17
+ root_path = '/home/users/ruiqi.wu/manipulate_3d_generate/data/gpt_blender/'
18
+ for class_name in os.listdir(root_path):
19
+ if class_name == 'StroageFurniture':
20
+ for model_id in os.listdir(os.path.join(root_path, class_name)):
21
+ json_path = os.path.join(root_path, class_name, model_id, 'object.json')
22
+ object_path = os.path.join(root_path, class_name, model_id, 'object.ply')
23
+ if os.path.exists(json_path):
24
+ if not os.path.exists(object_path):
25
+ print(json_path)
26
+ src_dir = os.path.join(root_path, class_name, model_id)
27
+ json_name = 'object.json'
28
+ data_root = '../singapo'
29
+ run_retrieve(src_dir, json_name, data_root)
scripts/mesh_retrieval/run_retrieve.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ import subprocess
3
+ import argparse
4
+ from tqdm.contrib.concurrent import process_map
5
+ from functools import partial
6
+
7
+ def run_retrieve(src_dir, json_name, data_root):
8
+ if 'StorageFurniture' not in src_dir and 'Table' not in src_dir:
9
+ data_root = '../acd_data/merged-data'
10
+ fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
11
+ try:
12
+ subprocess.run(fn_call, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
13
+ except subprocess.CalledProcessError as e:
14
+ print(f'Error from run_retrieve: {src_dir}')
15
+ print(f'Error: {e}')
16
+ return ' '.join(fn_call)
17
+
18
+ if __name__ == '__main__':
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
21
+ parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
22
+ parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
23
+ parser.add_argument("--max_workers", type=int, default=6, help="number of images to render for each object")
24
+ args = parser.parse_args()
25
+
26
+ assert os.path.exists(args.src), f"Src path does not exist: {args.src}"
27
+ assert os.path.exists(args.gt_data_root), f"GT data root does not exist: {args.gt_data_root}"
28
+
29
+ exp_path = args.src
30
+ # len_root = len(exp)
31
+ print('----------- Retrieve Part Meshes -----------')
32
+ src_dirs = []
33
+ # exps = os.listdir(root)
34
+ # for exp in exps:
35
+ # exp_path = os.path.join(root, exp)
36
+ for model_id in os.listdir(exp_path):
37
+ model_id_path = os.path.join(exp_path, model_id)
38
+ # print(model_id_path)
39
+ if '.' in model_id:
40
+ continue
41
+ for model_id_id in os.listdir(model_id_path):
42
+ if '.' not in model_id_id:
43
+ model_id_id_path = os.path.join(model_id_path, model_id_id)
44
+ for json_file in os.listdir(model_id_id_path):
45
+ if json_file.endswith(args.json_name):
46
+ if os.path.exists(os.path.join(model_id_id_path, 'object.ply')):
47
+ print(f"Found {model_id_id_path} with object.ply")
48
+ else:
49
+ # run_retrieve(model_id_id_path, json_name=args.json_name, data_root=args.gt_data_root)
50
+ src_dirs.append(model_id_id_path)
51
+ print(len(src_dirs), model_id_id_path)
52
+ # for dirpath, dirname, fnames in os.walk(root):
53
+ # for fname in fnames:
54
+ # if fname.endswith(args.json_name):
55
+ # src_dirs.append(dirpath) # save the relative directory path
56
+ # print(root)
57
+ print(f"Found {len(src_dirs)} jsons to retrieve part meshes")
58
+ # print(src_dirs)
59
+ # import ipdb
60
+ # ipdb.set_trace()
61
+
62
+ # for src_dir in src_dirs:
63
+ # print(src_dir)
64
+ # command = run_retrieve(src_dir, json_name=args.json_name, data_root=args.gt_data_root)
65
+ # command_file = open('retrieve_commands.sh', 'a')
66
+ # command_file.write(command + '\n')
67
+ # command_file.close()
68
+ process_map(partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root), src_dirs, max_workers=args.max_workers, chunksize=1)
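The directory walker above assumes results laid out as `<src>/<model_id>/<sample_id>/object.json` and skips folders that already contain `object.ply`. A small dry-run sketch that only lists what would be processed (no subprocesses are spawned; the folder name is a placeholder):

```python
import os

src = "results_demo"  # placeholder experiment folder
pending = []
for model_id in os.listdir(src):
    model_dir = os.path.join(src, model_id)
    if "." in model_id or not os.path.isdir(model_dir):
        continue
    for sample_id in os.listdir(model_dir):
        if "." in sample_id:
            continue
        sample_dir = os.path.join(model_dir, sample_id)
        has_json = os.path.exists(os.path.join(sample_dir, "object.json"))
        has_mesh = os.path.exists(os.path.join(sample_dir, "object.ply"))
        if has_json and not has_mesh:
            pending.append(sample_dir)
print(f"{len(pending)} folders still need mesh retrieval")
```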