import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
import json
import torch
import argparse
import numpy as np
from PIL import Image, ImageOps
import imageio
# from my_utils.plot import viz_graph
from my_utils.misc import load_config
import torchvision.transforms as T
from diffusers import DDPMScheduler
from models.denoiser import Denoiser
from scripts.json2urdf import create_urdf_from_json, pybullet_render
from dataset.utils import make_white_background, load_input_from, convert_data_range, parse_tree
import models
import torch.nn.functional as F
from io import BytesIO
import base64
from scripts.graph_pred.api import predict_graph_twomode, gpt_infer_image_category
import subprocess
import spaces
import time
cat_ref = {
    "Table": 0,
    "Dishwasher": 1,
    "StorageFurniture": 2,
    "Refrigerator": 3,
    "WashingMachine": 4,
    "Microwave": 5,
    "Oven": 6,
}
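# Object categories supported by the model; the integer IDs are used as the
# category-conditioning label fed to the denoiser (see run_demo, where the
# GPT-predicted category string is mapped through this table).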
def run_retrieve(src_dir, json_name, data_root):
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py', '--src_dir', src_dir, '--json_name', json_name, '--gt_data_root', data_root]
    try:
        subprocess.run(fn_call, check=True, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
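# NOTE: the helper below shadows the make_white_background imported from
# dataset.utils at the top of the file; this local definition is the one used.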
def make_white_background(src_img):
    '''Composite the input RGBA image over a white background and return an RGB image.'''
    src_img.load()
    background = Image.new("RGB", src_img.size, (255, 255, 255))
    background.paste(src_img, mask=src_img.split()[3])  # index 3 is the alpha channel
    return background
def pad_to_square(img, fill=0):
    """Pad the image to a square with the given fill value (default: 0 = black)."""
    width, height = img.size
    if width == height:
        return img
    max_side = max(width, height)
    delta_w = max_side - width
    delta_h = max_side - height
    padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
    return ImageOps.expand(img, padding, fill=fill)
def load_img(img_path):
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    with Image.open(img_path) as img:
        if img.mode == 'RGBA':
            img = make_white_background(img)
        img = img.convert('RGB')  # ensure 3 channels for normalization
        img = pad_to_square(img, fill=0)
        img = transform(img)
    img_batch = img.unsqueeze(0).cuda()
    return img_batch
def load_frame_with_imageio(frame):
    """Preprocess a single frame into the input format expected by the DINO model."""
    transform = T.Compose([
        T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    img = Image.fromarray(frame)  # convert to a PIL image
    if img.mode == 'RGBA':
        img = make_white_background(img)
    img = transform(img)  # apply the preprocessing
    return img.unsqueeze(0).cuda()  # add a batch dimension
def read_video_as_batch_with_imageio(video_path):
    """Read a video with imageio and stack all frames into a batch of shape (B, C, H, W)."""
    reader = imageio.get_reader(video_path)
    batch_frames = []
    try:
        for frame in reader:
            # load the frame and process it into (1, C, H, W)
            processed_frame = load_frame_with_imageio(frame)
            batch_frames.append(processed_frame)
        reader.close()
        if batch_frames:
            return torch.cat(batch_frames, dim=0).cuda()  # stack along the batch dimension and move to the GPU
        else:
            print("The video contains no valid frames")
            return None
    except Exception as e:
        print(f"Error while processing the video: {e}")
        return None
def extract_dino_feature(img_path_1, img_path_2):
    print('Extracting DINO features...')
    feat_1 = load_img(img_path_1)
    feat_2 = load_img(img_path_2)
    frames = torch.cat([feat_1, feat_2], dim=0)
    dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg', pretrained=True).cuda()
    with torch.no_grad():
        feat = dinov2_vitb14_reg.forward_features(frames)["x_norm_patchtokens"]
    feat_input = torch.cat([feat[0], feat[-1]], dim=0).unsqueeze(0)
    print('Done extracting DINO features')
    # release the GPU memory held by the DINO model
    torch.cuda.empty_cache()
    return feat_input
def set_scheduler(n_steps=100):
    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='linear', prediction_type='epsilon')
    scheduler.set_timesteps(n_steps)
    return scheduler
def prepare_model_input(data, cond, feat, n_samples):
    # attention masks
    attr_mask = torch.from_numpy(cond['attr_mask']).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond['key_pad_mask'])
    graph_mask = torch.from_numpy(cond['adj_mask'])
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise
    B, C = data.shape
    noise = torch.randn([n_samples, B, C], dtype=torch.float32)
    # dummy image feature (used for the condition-free branch of guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        "cat": cat.cuda(),
        "f": f.cuda(),
    }
def prepare_model_input_nocond(feat, n_samples):
    # attention masks
    cond_example = np.zeros((32 * 5, 32 * 5), dtype=bool)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
    attr_mask = torch.from_numpy(attr_mask).unsqueeze(0).repeat(n_samples, 1, 1)
    key_pad_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    graph_mask = torch.from_numpy(cond_example).unsqueeze(0).repeat(n_samples, 1, 1)
    # input image feature
    f = feat.repeat(n_samples, 1, 1)
    # input noise
    data = np.zeros((32 * 5, 6), dtype=bool)
    noise = torch.randn(data.shape, dtype=torch.float32).repeat(n_samples, 1, 1)
    # dummy image feature (used for the condition-free branch of guided diffusion)
    dummy_feat = torch.from_numpy(np.load('systems/dino_dummy.npy').astype(np.float32))
    dummy_feat = dummy_feat.unsqueeze(0).repeat(n_samples, 1, 1)
    # dummy object category
    cat = torch.zeros(1, dtype=torch.long).repeat(n_samples)
    return {
        "noise": noise.cuda(),
        "attr_mask": attr_mask.cuda(),
        "key_pad_mask": key_pad_mask.cuda(),
        "graph_mask": graph_mask.cuda(),
        "dummy_f": dummy_feat.cuda(),
        "cat": cat.cuda(),
        "f": f.cuda(),
    }
def save_graph(pred_graph, save_dir):
    print(f'Saving the predicted graph to {save_dir}/pred_graph.json')
    # save the response
    with open(os.path.join(save_dir, "pred_graph.json"), "w") as f:
        json.dump(pred_graph, f, indent=4)
    # Visualize the graph
    # img_graph = Image.fromarray(viz_graph(pred_graph))
    # img_graph.save(os.path.join(save_dir, "pred_graph.png"))
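# Classifier-free-style guidance: the denoising loop below combines an
# image-conditioned noise prediction with a "condition-free" prediction made
# from the dummy DINO features, as
#   noise_pred = (1 + omega) * eps_cond - omega * eps_free,
# so omega = 0 reduces to purely conditional denoising.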
def forward(model, scheduler, inputs, omega=0.5):
    print('Running inference...')
    noisy_x = inputs['noise']
    for t in scheduler.timesteps:
        timesteps = torch.tensor([t], device=inputs['noise'].device)
        outputs_cond = model(
            x=noisy_x,
            cat=inputs['cat'],
            timesteps=timesteps,
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
            label_free=True,
        )  # conditioned on the input image features
        if omega != 0:
            outputs_free = model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['dummy_f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=True,
            )  # conditioned on the dummy DINO features (condition-free mode)
            noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
        else:
            noise_pred = outputs_cond['noise_pred']
        noisy_x = scheduler.step(noise_pred, t, noisy_x).prev_sample
    return noisy_x
def _convert_json(x, c):
    out = {"meta": {}, "diffuse_tree": []}
    n_nodes = c["n_nodes"]
    par = c["parents"].tolist()
    adj = c["adj"]
    np.fill_diagonal(adj, 0)  # remove the self-loop on the root node
    if "obj_cat" in c:
        out["meta"]["obj_cat"] = c["obj_cat"]
    # convert the data back to its original range
    data = convert_data_range(x)
    # parse the tree
    out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
    return out
def post_process(output, cond, save_root, gt_data_root, visualize=False):
    print('Post-processing...')
    # N = output.shape[0]
    N = 1
    for i in range(N):
        cond_n = {}
        cond_n['n_nodes'] = cond['n_nodes'][i]
        cond_n['parents'] = cond['parents'][i]
        cond_n['adj'] = cond['adj'][i]
        cond_n['obj_cat'] = cond['cat']
        # convert the raw model output to the json format
        out_json = _convert_json(output, cond_n)
        save_dir = os.path.join(save_root, str(i))
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "object.json"), "w") as f:
            json.dump(out_json, f, indent=4)
        # retrieve part meshes (call python script)
        # print(f"Retrieving part meshes for the object {i}...")
        # os.system(f"python scripts/mesh_retrieval/retrieve.py --src_dir {save_dir} --json_name object.json --gt_data_root {gt_data_root}")
def load_model(ckpt_path, config):
    print('Loading model from checkpoint...')
    model = models.make(config.name, config)
    state_dict = torch.load(ckpt_path)
    state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.eval()
    return model.cuda()
def convert_pred_graph(pred_graph):
    """Convert the predicted parent array of shape (B, K) into the conditioning dict
    (adjacency, parents, node counts, and token-level attention masks)."""
    cond = {}
    B, K = pred_graph.shape[:2]
    adj = np.zeros((B, K, K), dtype=np.float32)
    padding = np.zeros((B, 5 * K, 5 * K), dtype=bool)
    parents = np.zeros((B, K), dtype=np.int32)
    n_nodes = np.zeros((B,), dtype=np.int32)
    for b in range(B):
        node_len = 0
        for k in range(K):
            if pred_graph[b, k] == k and k > 0:
                node_len = k
                break
            node = pred_graph[b, k]
            adj[b, k, node] = 1
            adj[b, node, k] = 1
            parents[b, k] = node
        adj[b, node_len:] = 1
        padding[b, :, :5 * node_len] = 1
        parents[b, 0] = -1
        n_nodes[b] = node_len
    # each node is represented by 5 attribute tokens, so node-level masks are
    # expanded to token-level masks by repeating 5x along each axis
    adj_mask = adj.astype(bool).repeat(5, axis=1).repeat(5, axis=2)
    attr_mask = np.eye(32, 32, dtype=bool)
    attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1)
    cond['adj_mask'] = adj_mask
    cond['attr_mask'] = attr_mask
    cond['key_pad_mask'] = padding
    cond['adj'] = adj
    cond['parents'] = parents
    cond['n_nodes'] = n_nodes
    cond['cat'] = 'StorageFurniture'
    data = np.zeros((32 * 5, 6), dtype=bool)
    return data, cond
def bfs_tree_simple(tree_list):
    order = [0] * len(tree_list)
    queue = []
    current_node_idx = 0
    for node_idx, node in enumerate(tree_list):
        if node['parent'] == -1:
            queue.append(node['id'])
            order[current_node_idx] = node_idx
            current_node_idx += 1
            break
    while len(queue) > 0:
        current_node = queue.pop(0)
        for node_idx, node in enumerate(tree_list):
            if node['parent'] == current_node:
                queue.append(node['id'])
                order[current_node_idx] = node_idx
                current_node_idx += 1
    return order
def get_graph_from_gpt(img_path_1, img_path_2):
    first_img = Image.open(img_path_1)
    first_img_data = first_img.resize((1024, 1024))
    buffer = BytesIO()
    first_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    first_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
    second_img = Image.open(img_path_2)
    second_img_data = second_img.resize((1024, 1024))
    buffer = BytesIO()
    second_img_data.save(buffer, format="PNG")
    buffer.seek(0)
    # encode the image as base64
    second_encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
    pred_gpt = predict_graph_twomode('', first_img_data=first_encoded_image, second_img_data=second_encoded_image)
    print(pred_gpt)
    pred_graph = pred_gpt['diffuse_tree']
    # order = bfs_tree_simple(pred_graph)
    # pred_graph = [pred_graph[i] for i in order]
    # initialize the parent array as [0, 1, 2, ..., 31]
    graph_array = np.arange(32)
    for node_idx, node in enumerate(pred_graph):
        if node['parent'] == -1:
            graph_array[node_idx] = node_idx
        else:
            graph_array[node_idx] = node['parent']
    # add a batch axis
    graph_array = np.expand_dims(graph_array, axis=0)
    cat_str = gpt_infer_image_category(first_encoded_image, second_encoded_image)
    return torch.from_numpy(graph_array).cuda().repeat(3, 1), cat_str
def run_demo(args):
    # extract DINOv2 features from the two input images
    t1 = time.time()
    feat = extract_dino_feature(args.img_path_1, args.img_path_2)
    t2 = time.time()
    print(f'Extracted DINO features in {t2 - t1:.2f} seconds')
    scheduler = set_scheduler(args.n_denoise_steps)
    # load the model checkpoint
    model = load_model(args.ckpt_path, args.config.system.model)
    # inference
    with torch.no_grad():
        t3 = time.time()
        pred_graph, cat_str = get_graph_from_gpt(args.img_path_1, args.img_path_2)
        t4 = time.time()
        print(f'Got the predicted graph in {t4 - t3:.2f} seconds')
        print(pred_graph)
        data, cond = convert_pred_graph(pred_graph)
        inputs = prepare_model_input(data, cond, feat, n_samples=args.n_samples)
        # update the object category with the GPT prediction
        cond['cat'] = cat_str
        inputs['cat'][:] = cat_ref[cat_str]
        print(f'Object category predicted by GPT: {cat_str}, {cat_ref[cat_str]}')
        output = forward(model, scheduler, inputs, omega=args.omega).cpu().numpy()
        t5 = time.time()
        print(f'Ran the denoising loop in {t5 - t4:.2f} seconds')
    # post-process
    post_process(output, cond, args.save_dir, args.gt_data_root, visualize=True)
    # retrieve part meshes for each generated sample
    for sample in os.listdir(args.save_dir):
        sample_dir = os.path.join(args.save_dir, sample)
        t6 = time.time()
        run_retrieve(sample_dir, 'object.json', args.gt_data_root)
        t7 = time.time()
        print(f'Retrieved part meshes for {sample} in {t7 - t6:.2f} seconds')
    save_json_path = os.path.join(args.save_dir, "0", "object.json")
    with open(save_json_path, 'r') as file:
        json_data = json.load(file)
    create_urdf_from_json(json_data, save_json_path.replace('.json', '.urdf'))
    pybullet_render(save_json_path.replace('.json', '.urdf'), os.path.join(args.save_dir, "0"), 8)
if __name__ == '__main__':
    '''
    Script for running inference on an example image input.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_path_1", type=str, default='examples/1.png', help="path to the first input image")
    parser.add_argument("--img_path_2", type=str, default='examples/1_open_1.png', help="path to the second input image")
    parser.add_argument("--ckpt_path", type=str, default='exps/singapo/final/ckpts/last.ckpt', help="path to the model checkpoint")
    parser.add_argument("--config_path", type=str, default='exps/singapo/final/config/parsed.yaml', help="path to the config file")
    parser.add_argument("--use_example_graph", action="store_true", default=False, help="use the example graph for inference if you don't have an OpenAI key")
    parser.add_argument("--save_dir", type=str, default='results', help="path to save the output")
    parser.add_argument("--gt_data_root", type=str, default='./', help="root directory of the original data, used for part mesh retrieval")
    parser.add_argument("--n_samples", type=int, default=3, help="number of samples to generate for the given input")
    parser.add_argument("--omega", type=float, default=0.5, help="weight of the condition-free mode during inference")
    parser.add_argument("--n_denoise_steps", type=int, default=100, help="number of denoising steps")
    args = parser.parse_args()

    assert os.path.exists(args.img_path_1), "The input image does not exist"
    # assert os.path.exists(args.ckpt_path), "The checkpoint does not exist"
    assert os.path.exists(args.config_path), "The config file does not exist"

    os.makedirs(args.save_dir, exist_ok=True)
    config = load_config(args.config_path)
    args.config = config
    run_demo(args)
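# Example invocation (a sketch based on the argparse defaults above; the script
# filename is assumed, and the example images, checkpoint, config, and an
# OpenAI key for the graph/category prediction must be available locally):
#
#   python demo.py --img_path_1 examples/1.png --img_path_2 examples/1_open_1.png \
#       --ckpt_path exps/singapo/final/ckpts/last.ckpt \
#       --config_path exps/singapo/final/config/parsed.yaml \
#       --save_dir results --n_samples 3 --omega 0.5 --n_denoise_steps 100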