import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import json
import torch
import argparse
import numpy as np
from PIL import Image, ImageOps
import imageio
# from my_utils.plot import viz_graph
from my_utils.misc import load_config
import torchvision.transforms as T
from diffusers import DDPMScheduler
from models.denoiser import Denoiser
from scripts.json2urdf import create_urdf_from_json, pybullet_render
from dataset.utils import load_input_from, convert_data_range, parse_tree  # make_white_background is redefined locally below
import models
import torch.nn.functional as F
from io import BytesIO
import base64
from scripts.graph_pred.api import predict_graph_twomode, gpt_infer_image_category
import subprocess
import spaces
import time


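# Maps object-category names (as predicted by GPT) to integer class labels.
# The ordering is assumed to match the label encoding used at training time.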
cat_ref = {
    "Table": 0,
    "Dishwasher": 1,
    "StorageFurniture": 2,
    "Refrigerator": 3,
    "WashingMachine": 4,
    "Microwave": 5,
    "Oven": 6,
}

def run_retrieve(src_dir, json_name, data_root):
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
    try:
        subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')

def make_white_background(src_img):
    '''Make the white background for the input RGBA image.'''
    src_img.load() 
    background = Image.new("RGB", src_img.size, (255, 255, 255))
    background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel
    return background

def pad_to_square(img, fill=0):
    """Pad image to square with given fill value (default: 0 = black)."""
    width, height = img.size
    if width == height:
        return img
    max_side = max(width, height)
    delta_w = max_side - width
    delta_h = max_side - height
    padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
    return ImageOps.expand(img, padding, fill=fill)

def load_img(img_path):
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    with Image.open(img_path) as img:
        if img.mode == 'RGBA':
            img = make_white_background(img)
        img = img.convert('RGB')  # Ensure it's 3-channel for normalization
        img = pad_to_square(img, fill=0)
        img = transform(img)
    img_batch = img.unsqueeze(0).cuda()

    return img_batch


def load_frame_with_imageio(frame):
    """
    Preprocess a single video frame into the input format expected by the DINO model.
    """
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    img = Image.fromarray(frame)  # convert to a PIL image
    if img.mode == 'RGBA':
        img = make_white_background(img)
    img = transform(img)  # apply the preprocessing
    return img.unsqueeze(0).cuda()  # add the batch dimension

def read_video_as_batch_with_imageio(video_path):
    """
    Read a video with imageio and stack all frames into a batch of shape (B, C, H, W).
    """
    reader = imageio.get_reader(video_path)
    batch_frames = []

    try:
        for frame in reader:
            # preprocess each frame to (1, C, H, W)
            processed_frame = load_frame_with_imageio(frame)
            batch_frames.append(processed_frame)

        reader.close()
        if batch_frames:
            return torch.cat(batch_frames, dim=0).cuda()  # stack along the batch dimension and move to GPU
        else:
            print("The video has no valid frames")
            return None
    except Exception as e:
        print(f"Error while processing the video: {e}")
        return None

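# Extracts DINOv2 (ViT-B/14 with registers) patch-token features from the two
# input images and concatenates them into a single conditioning feature of
# shape (1, 2 * n_patches, C).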
def extract_dino_feature(img_path_1, img_path_2):
    print('Extracting DINO feature...')
    feat_1 = load_img(img_path_1)
    feat_2 = load_img(img_path_2)
    frames = torch.cat([feat_1, feat_2], dim=0)
    dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg', pretrained=True).cuda()
    print('Running DINOv2 forward pass...')
    with torch.no_grad():
        feat = dinov2_vitb14_reg.forward_features(frames)["x_norm_patchtokens"]
    feat_input = torch.cat([feat[0], feat[-1]], dim=0).unsqueeze(0)
    print('Extracting DINO feature done')
    # release the cached GPU memory used by the model
    torch.cuda.empty_cache()
    return feat_input

def set_scheduler(n_steps=100):
    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='linear', prediction_type='epsilon')
    scheduler.set_timesteps(n_steps)
    return scheduler

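# Assembles the batched denoiser inputs from the graph condition: attention
# masks, the image feature, fresh Gaussian noise, and a dummy DINO feature
# used for classifier-free guidance. The token sequence has 32 * 5 = 160
# entries (up to 32 nodes, 5 attribute tokens per node).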
def prepare_model_input(data, cond, feat, n_samples):
    # attention masks
    attr_mask = torch.from_numpy(cond['attr_mask']).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond['key_pad_mask'])
    graph_mask = torch.from_numpy(cond['adj_mask'])
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise
    B, C = data.shape
    noise = torch.randn([n_samples, B, C], dtype=torch.float32)
    # dummy image feature (used for guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        'cat': cat.cuda(),
        'f': f.cuda(),  
    }

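# Graph-free variant of prepare_model_input: the attribute mask only allows
# attention among the 5 tokens of the same node, and the padding/graph masks
# are left empty (all False).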
def prepare_model_input_nocond(feat, n_samples):
    # attention masks
    cond_example = np.zeros((32*5, 32*5), dtype=bool)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
    attr_mask = torch.from_numpy(attr_mask).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    graph_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise (drawn independently for each sample, matching prepare_model_input)
    data = np.zeros((32*5, 6), dtype=bool)
    noise = torch.randn([n_samples, *data.shape], dtype=torch.float32)
    # dummy image feature (used for guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        'cat': cat.cuda(),
        'f': f.cuda(),  
    }

def save_graph(pred_graph, save_dir):
    print(f'Saving the predicted graph to {save_dir}/pred_graph.json')
    # save the response
    with open(os.path.join(save_dir, "pred_graph.json"), "w") as f:
        json.dump(pred_graph, f, indent=4)
    # Visualize the graph
    # img_graph = Image.fromarray(viz_graph(pred_graph))
    # img_graph.save(os.path.join(save_dir, "pred_graph.png"))

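# Reverse-diffusion loop with classifier-free guidance: each step runs the
# model once with the image feature and, if omega != 0, once with the dummy
# feature, combining the predictions as
#   noise_pred = (1 + omega) * eps_cond - omega * eps_free.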
def forward(model, scheduler, inputs, omega=0.5):
    print('Running inference...')
    noisy_x = inputs['noise']
    for t in scheduler.timesteps:
        timesteps = torch.tensor([t], device=inputs['noise'].device)
        outputs_cond = model(
            x=noisy_x,
            cat=inputs['cat'],
            timesteps=timesteps,
            feat=inputs['f'], 
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
            label_free=True,
        ) # take the conditional image feature as input
        if omega != 0:
            outputs_free = model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['dummy_f'], 
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=True,
            ) # take the dummy DINO features for the condition-free mode
            noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
        else:
            noise_pred = outputs_cond['noise_pred']
        noisy_x = scheduler.step(noise_pred, t, noisy_x).prev_sample
    return noisy_x

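# Converts the raw model output back to its original value range and parses it
# into the object JSON format ({"meta": ..., "diffuse_tree": ...}).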
def _convert_json(x, c):
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c["n_nodes"]
    par = c["parents"].tolist()
    adj = c["adj"]
    np.fill_diagonal(adj, 0) # remove self-loop for the root node
    if "obj_cat" in c:
        out["meta"]["obj_cat"] = c["obj_cat"]

    # convert the data to original range
    data = convert_data_range(x)
    # parse the tree
    out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
    return out

def post_process(output, cond, save_root, gt_data_root, visualize=False):
    print('Post-processing...')
    # N = output.shape[0]
    N = 1  # only post-process the first generated sample
    for i in range(N):
        cond_n = {}
        cond_n['n_nodes'] = cond['n_nodes'][i] 
        cond_n['parents'] = cond['parents'][i]
        cond_n['adj'] = cond['adj'][i]
        cond_n['obj_cat'] = cond['cat']
        # convert the raw model output to the json format
        out_json = _convert_json(output, cond_n)
        save_dir = os.path.join(save_root, str(i))
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "object.json"), "w") as f:
            json.dump(out_json, f, indent=4)

        # retrieve part meshes (call python script)
        # print(f"Retrieving part meshes for the object {i}...")
        # os.system(f"python scripts/mesh_retrieval/retrieve.py --src_dir {save_dir} --json_name object.json --gt_data_root {gt_data_root}")

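# Instantiates the denoiser and loads the checkpoint weights, stripping the
# "model." prefix that the training wrapper prepends to parameter names.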
def load_model(ckpt_path, config):
    print('Loading model from checkpoint...')
    model = models.make(config.name, config)
    state_dict = torch.load(ckpt_path)
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    return model.cuda()

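# Converts the predicted parent array (shape [B, 32]) into adjacency/parent
# arrays and the attention masks the denoiser expects. A non-root node whose
# predicted parent equals its own index marks the end of the valid nodes.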
def convert_pred_graph(pred_graph):
    cond = {}
    B, K = pred_graph.shape[:2]
    adj = np.zeros((B, K, K), dtype=np.float32)
    padding = np.zeros((B, 5 * K, 5 * K), dtype=bool)
    parents = np.zeros((B, K), dtype=np.int32)
    n_nodes = np.zeros((B,), dtype=np.int32)
    for b in range(B):
        node_len = 0
        for k in range(K):
            if pred_graph[b, k] == k and k > 0:
                node_len = k
                break
            node = pred_graph[b, k]
            adj[b, k, node] = 1
            adj[b, node, k] = 1
            parents[b, k] = node
        adj[b, node_len:] = 1
        padding[b, :, :5 * node_len] = 1
        parents[b, 0] = -1
        n_nodes[b] = node_len
    adj_mask = adj.astype(bool).repeat(5, axis=1).repeat(5, axis=2)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)

    cond['adj_mask'] = adj_mask
    cond['attr_mask'] = attr_mask
    cond['key_pad_mask'] = padding

    cond['adj'] = adj
    cond['parents'] = parents
    cond['n_nodes'] = n_nodes
    cond['cat'] = 'StorageFurniture'

    data = np.zeros((32*5, 6), dtype=bool)

    return data, cond

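# Returns a breadth-first ordering of the indices in tree_list, starting from
# the root node (the one with parent == -1), so parents precede children.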
def bfs_tree_simple(tree_list):
    order = [0] * len(tree_list)
    queue = []
    current_node_idx = 0
    for node_idx, node in enumerate(tree_list):
        if node['parent'] == -1:
            queue.append(node['id'])
            order[current_node_idx] = node_idx
            current_node_idx += 1
            break
    while len(queue) > 0:
        current_node = queue.pop(0)
        for node_idx, node in enumerate(tree_list):
            if node['parent'] == current_node:
                queue.append(node['id'])
                order[current_node_idx] = node_idx
                current_node_idx += 1

    return order

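# Encodes the two input images as base64 PNGs and queries GPT to predict the
# part connectivity graph and the object category.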
def get_graph_from_gpt(img_path_1, img_path_2):
    first_img = Image.open(img_path_1)
    first_img_data = first_img.resize((1024, 1024))
    buffer = BytesIO()
    first_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    first_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")


    second_img = Image.open(img_path_2)
    second_img_data = second_img.resize((1024, 1024))
    buffer = BytesIO()
    second_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    second_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")

    pred_gpt = predict_graph_twomode('', first_img_data=first_encoded_image, second_img_data=second_encoded_image)
    print(pred_gpt)
    pred_graph = pred_gpt['diffuse_tree']
    # order = bfs_tree_simple(pred_graph)
    # pred_graph = [pred_graph[i] for i in order]
    
    
    # generate array [0, 1, 2, ..., 31] for init
    graph_array = np.arange(32)
    for node_idx, node in enumerate(pred_graph):
        if node['parent'] == -1:
            graph_array[node_idx] = node_idx
        else:
            graph_array[node_idx] = node['parent']

    # new axis for batch
    graph_array = np.expand_dims(graph_array, axis=0)

    cat_str = gpt_infer_image_category(first_encoded_image, second_encoded_image)
    
    return torch.from_numpy(graph_array).cuda().repeat(3, 1), cat_str  # repeated 3x to match the default n_samples

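# @spaces.GPU requests a GPU for this function when the demo runs on
# Hugging Face Spaces (ZeroGPU); it has no effect outside Spaces.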
@spaces.GPU
def run_demo(args):
    # extract DINOV2 feature from the input image
    t1 = time.time()
    feat = extract_dino_feature(args.img_path_1, args.img_path_2)
    t2 = time.time()
    print(f'Extracted DINO feature in {t2 - t1:.2f} seconds')
    scheduler = set_scheduler(args.n_denoise_steps)
    # load the checkpoint of the model
    model = load_model(args.ckpt_path, args.config.system.model)

    # inference
    with torch.no_grad():
        t3 = time.time()
        pred_graph, cat_str = get_graph_from_gpt(args.img_path_1, args.img_path_2)
        t4 = time.time()
        print(f'Got the predicted graph in {t4 - t3:.2f} seconds')
        print(pred_graph)
        data, cond = convert_pred_graph(pred_graph)
        inputs = prepare_model_input(data, cond, feat, n_samples=args.n_samples)

        # Update the object category
        cond['cat'] = cat_str
        inputs['cat'][:] = cat_ref[cat_str]
        print(f'Object category predicted by GPT: {cat_str}, {cat_ref[cat_str]}')

        output = forward(model, scheduler, inputs, omega=args.omega).cpu().numpy()
        t5 = time.time()
        print(f'Forwarded the model in {t5 - t4:.2f} seconds')

        # post-process
        post_process(output, cond, args.save_dir, args.gt_data_root, visualize=True)

    # retrieve
    for sample in os.listdir(args.save_dir):
        sample_dir = os.path.join(args.save_dir, sample)
        t6 = time.time()
        run_retrieve(sample_dir, 'object.json', args.gt_data_root)
        t7 = time.time()
        print(f'Retrieved part meshes for {sample} in {t7 - t6:.2f} seconds')
    
    save_json_path = os.path.join(args.save_dir, "0", "object.json")
    with open(save_json_path, 'r') as file:
        json_data = json.load(file)
    create_urdf_from_json(json_data, save_json_path.replace('.json', '.urdf'))
    pybullet_render(save_json_path.replace('.json', '.urdf'), os.path.join(args.save_dir, "0"), 8)
    

if __name__ == '__main__':
    '''
    Script for running the inference on an example image input.
    '''
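    # Example invocation (the script name below is illustrative):
    #   python demo.py --img_path_1 examples/1.png --img_path_2 examples/1_open_1.png \
    #       --ckpt_path exps/singapo/final/ckpts/last.ckpt --save_dir results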
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path_1", type=str, default='examples/1.png', help="path to the input image")
    parser.add_argument("--img_path_2", type=str, default='examples/1_open_1.png', help="path to the input image")
    parser.add_argument("--ckpt_path", type=str, default='exps/singapo/final/ckpts/last.ckpt', help="path to the checkpoint of the model")
    parser.add_argument("--config_path", type=str, default='exps/singapo/final/config/parsed.yaml', help="path to the config file")
    parser.add_argument("--use_example_graph", action="store_true", default=False, help="if you don't have the openai key yet, turn on to use the example graph for inference")
    parser.add_argument("--save_dir", type=str, default='results', help="path to save the output")
    parser.add_argument("--gt_data_root", type=str, default='./', help="the root directory of the original data, used for part mesh retrieval")
    parser.add_argument("--n_samples", type=int, default=3, help="number of samples to generate given the input")
    parser.add_argument("--omega", type=float, default=0.5, help="the weight of the condition-free mode in the inference")
    parser.add_argument("--n_denoise_steps", type=int, default=100, help="number of denoising steps")
    args = parser.parse_args()

    assert os.path.exists(args.img_path_1), "The first input image does not exist"
    assert os.path.exists(args.img_path_2), "The second input image does not exist"
    # assert os.path.exists(args.ckpt_path), "The checkpoint does not exist"
    assert os.path.exists(args.config_path), "The config file does not exist"
    os.makedirs(args.save_dir, exist_ok=True)

    config = load_config(args.config_path)
    args.config = config

    run_demo(args)