import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import math
import numpy as np
import lightning.pytorch as pl
from metrics.iou_cdist import IoU_cDist
from my_utils.savermixins import SaverMixin
from my_utils.refs import sem_ref, joint_ref
from dataset.utils import convert_data_range, parse_tree
from my_utils.plot import viz_graph, make_grid, add_text
from my_utils.render import draw_boxes_axiss_anim, prepare_meshes
from PIL import Image


class BaseSystem(pl.LightningModule, SaverMixin):
    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)

    def setup(self, stage: str):
        # configure the output directory for saved images
        self.hparams.save_dir = os.path.join(self.hparams.exp_dir, 'output', stage)
        os.makedirs(self.hparams.save_dir, exist_ok=True)

    # --------------------------------- visualization ---------------------------------
    def convert_json(self, x, c, idx, prefix=''):
        out = {"meta": {}, "diffuse_tree": []}
        n_nodes = c[f"{prefix}n_nodes"][idx].item()
        par = c[f"{prefix}parents"][idx].cpu().numpy().tolist()
        adj = c[f"{prefix}adj"][idx].cpu().numpy()
        np.fill_diagonal(adj, 0)  # remove the self-loop on the root node
        if f"{prefix}obj_cat" in c:
            out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][idx]
        # convert the data back to its original range
        data = convert_data_range(x.cpu().numpy())
        # parse the tree
        out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
        return out

    # def save_val_img(self, pred, gt, cond):
    #     B = pred.shape[0]
    #     pred_imgs, gt_imgs, gt_graphs_view = [], [], []
    #     for b in range(B):
    #         print(b)
    #         # convert to human-readable json format
    #         pred_json = self.convert_json(pred[b], cond, b)
    #         gt_json = self.convert_json(gt[b], cond, b)
    #         # visualize bbox and axis
    #         pred_meshes = prepare_meshes(pred_json)
    #         bbox_0, bbox_1, axiss = (
    #             pred_meshes["bbox_0"],
    #             pred_meshes["bbox_1"],
    #             pred_meshes["axiss"],
    #         )
    #         pred_img = draw_boxes_axiss_anim(
    #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
    #         )
    #         gt_meshes = prepare_meshes(gt_json)
    #         bbox_0, bbox_1, axiss = (
    #             gt_meshes["bbox_0"],
    #             gt_meshes["bbox_1"],
    #             gt_meshes["axiss"],
    #         )
    #         gt_img = draw_boxes_axiss_anim(
    #             bbox_0, bbox_1, axiss, mode="graph", resolution=128
    #         )
    #         # visualize graph
    #         # gt_graph = viz_graph(gt_json, res=128)
    #         # gt_graph = add_text(cond["name"][b], gt_graph)
    #         # GT views
    #         rgb_view = cond["img"][b].cpu().numpy()
    #         pred_imgs.append(pred_img)
    #         gt_imgs.append(gt_img)
    #         gt_graphs_view.append(rgb_view)
    #         # gt_graphs_view.append(gt_graph)
    #     # save images for generated results
    #     epoch = str(self.current_epoch).zfill(5)
    #     # pred_thumbnails = np.concatenate(pred_imgs, axis=1)  # concat batch in width
    #     import ipdb
    #     ipdb.set_trace()
    #     # save images for ground truth
    #     for i in range(math.ceil(len(gt_graphs_view) / 8)):
    #         start = i * 8
    #         end = min((i + 1) * 8, len(gt_graphs_view))
    #         pred_thumbnails = np.concatenate(pred_imgs[start:end], axis=1)
    #         gt_graph_imgs = np.concatenate(gt_graphs_view[start:end], axis=1)
    #         gt_thumbnails = np.concatenate(gt_imgs[start:end], axis=1)  # concat batch in width
    #         grid = np.concatenate([gt_graph_imgs, gt_thumbnails, pred_thumbnails], axis=0)
    #         self.save_rgb_image(f"new_out_valid_{i}.png", grid)

    def save_test_step(self, pred, gt, cond, batch_idx, res=128):
        exp_name = self._get_exp_name()
        model_name = cond["name"][0].replace("/", "@")
        save_dir = f"{exp_name}/{batch_idx}@{model_name}"
        # input image
        input_img = cond["img"][0].cpu().numpy()
        # GT renderings
        if not self.hparams.get('test_no_GT', False):
            gt_json = self.convert_json(gt[0], cond, 0)
            # gt_graph = viz_graph(gt_json, res=256)
            gt_meshes = prepare_meshes(gt_json)
            bbox_0, bbox_1, axiss = (
                gt_meshes["bbox_0"],
                gt_meshes["bbox_1"],
                gt_meshes["axiss"],
            )
            gt_img = draw_boxes_axiss_anim(bbox_0, bbox_1, axiss, mode="graph", resolution=res)
        else:
            # gt_graph = 255 * np.ones((res, res, 3), dtype=np.uint8)
            gt_img = 255 * np.ones((res, 2 * res, 3), dtype=np.uint8)
        gt_block = np.concatenate([input_img, gt_img], axis=1)

        # renderings for the generated results
        img_blocks = []
        for b in range(pred.shape[0]):
            pred_json = self.convert_json(pred[b], cond, 0)
            # visualize bbox and axis
            pred_meshes = prepare_meshes(pred_json)
            bbox_0, bbox_1, axiss = (
                pred_meshes["bbox_0"],
                pred_meshes["bbox_1"],
                pred_meshes["axiss"],
            )
            pred_img = draw_boxes_axiss_anim(
                bbox_0, bbox_1, axiss, mode="graph", resolution=res
            )
            img_blocks.append(pred_img)
            self.save_json(f"{save_dir}/{b}/object.json", pred_json)
        # arrange the generated samples into a grid
        img_grid = make_grid(img_blocks, cols=5)
        # visualize the input graph
        # input_graph = viz_graph(pred_json, res=256)
        # save images
        # self.save_rgb_image(f"{save_dir}/gt_graph.png", gt_graph)
        self.save_rgb_image(f"{save_dir}/output.png", img_grid)
        self.save_rgb_image(f"{save_dir}/gt.png", gt_block)
        # self.save_rgb_image(f"{save_dir}/input_graph.png", input_graph)

    def _save_html_end(self):
        exp_name = self._get_exp_name()
        save_dir = self.get_save_path(exp_name)
        cases = sorted(os.listdir(save_dir), key=lambda x: int(x.split("@")[0]))
        html_head = """
        <table>
            <tr>
                <th>Object ID</th>
                <th>Metrics (avg)</th>
                <th>Input image + GT object + GT graph</th>
                <th>Input graph</th>
            </tr>
            <tr>
                <td>{case}</td>
                <td>
                    [AS-cDist] {aid_cdist}<br>
                    [RS-cDist] {rid_cdist}<br>
                    -----------------------<br>
                    [AS-IoU] {aid_iou}<br>
                    [RS-IoU] {rid_iou}<br>
                    -----------------------<br>
                    [RS-CD] {rid_cd}<br>
                    [AS-CD] {aid_cd}<br>
                    -----------------------<br>
                    [AOR] {aor}
                </td>
                <td></td>
                <td></td>
            </tr>
            <tr>
                <th colspan="4">Generated samples</th>
            </tr>
            <tr>
                <td colspan="4"><img src=""></td>
            </tr>