import os, sys sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) import json import math import numpy as np import lightning.pytorch as pl from metrics.iou_cdist import IoU_cDist from my_utils.savermixins import SaverMixin from my_utils.refs import sem_ref, joint_ref from dataset.utils import convert_data_range, parse_tree from my_utils.plot import viz_graph, make_grid, add_text from my_utils.render import draw_boxes_axiss_anim, prepare_meshes from PIL import Image class BaseSystem(pl.LightningModule, SaverMixin): def __init__(self, hparams): super().__init__() self.hparams.update(hparams) def setup(self, stage: str): # config the logger dir for images self.hparams.save_dir = os.path.join(self.hparams.exp_dir, 'output', stage) os.makedirs(self.hparams.save_dir, exist_ok=True) # --------------------------------- visualization --------------------------------- def convert_json(self, x, c, idx, prefix=''): out = {"meta": {}, "diffuse_tree": []} n_nodes = c[f"{prefix}n_nodes"][idx].item() par = c[f"{prefix}parents"][idx].cpu().numpy().tolist() adj = c[f"{prefix}adj"][idx].cpu().numpy() np.fill_diagonal(adj, 0) # remove self-loop for the root node if f"{prefix}obj_cat" in c: out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][idx] # convert the data to original range data = convert_data_range(x.cpu().numpy()) # parse the tree out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj) return out # def save_val_img(self, pred, gt, cond): # B = pred.shape[0] # pred_imgs, gt_imgs, gt_graphs_view = [], [], [] # for b in range(B): # print(b) # # convert to humnan readable format json # pred_json = self.convert_json(pred[b], cond, b) # gt_json = self.convert_json(gt[b], cond, b) # # visualize bbox and axis # pred_meshes = prepare_meshes(pred_json) # bbox_0, bbox_1, axiss = ( # pred_meshes["bbox_0"], # pred_meshes["bbox_1"], # pred_meshes["axiss"], # ) # pred_img = draw_boxes_axiss_anim( # bbox_0, bbox_1, axiss, mode="graph", resolution=128 # ) # gt_meshes = prepare_meshes(gt_json) # bbox_0, bbox_1, axiss = ( # gt_meshes["bbox_0"], # gt_meshes["bbox_1"], # gt_meshes["axiss"], # ) # gt_img = draw_boxes_axiss_anim( # bbox_0, bbox_1, axiss, mode="graph", resolution=128 # ) # # visualize graph # # gt_graph = viz_graph(gt_json, res=128) # # gt_graph = add_text(cond["name"][b], gt_graph) # # GT views # rgb_view = cond["img"][b].cpu().numpy() # pred_imgs.append(pred_img) # gt_imgs.append(gt_img) # gt_graphs_view.append(rgb_view) # # gt_graphs_view.append(gt_graph) # # save images for generated results # epoch = str(self.current_epoch).zfill(5) # # pred_thumbnails = np.concatenate(pred_imgs, axis=1) # concat batch in width # import ipdb # ipdb.set_trace() # # save images for ground truth # for i in range(math.ceil(len(gt_graphs_view) / 8)): # start = i * 8 # end = min((i + 1) * 8, len(gt_graphs_view)) # pred_thumbnails = np.concatenate(pred_imgs[start:end], axis=1) # gt_graph_imgs = np.concatenate(gt_graphs_view[start:end], axis=1) # gt_thumbnails = np.concatenate(gt_imgs[start:end], axis=1) # concat batch in width # grid = np.concatenate([gt_graph_imgs, gt_thumbnails, pred_thumbnails], axis=0) # self.save_rgb_image(f"new_out_valid_{i}.png", grid) def save_test_step(self, pred, gt, cond, batch_idx, res=128): exp_name = self._get_exp_name() model_name = cond["name"][0].replace("/", '@') save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}" # input image input_img = cond["img"][0].cpu().numpy() # GT recordings if not self.hparams.get('test_no_GT', False): gt_json = self.convert_json(gt[0], cond, 0) # gt_graph = viz_graph(gt_json, res=256) gt_meshes = prepare_meshes(gt_json) bbox_0, bbox_1, axiss = ( gt_meshes["bbox_0"], gt_meshes["bbox_1"], gt_meshes["axiss"], ) gt_img = draw_boxes_axiss_anim(bbox_0, bbox_1, axiss, mode="graph", resolution=res) else: # gt_graph = 255 * np.ones((res, res, 3), dtype=np.uint8) gt_img = 255 * np.ones((res, 2 * res, 3), dtype=np.uint8) gt_block = np.concatenate([input_img, gt_img], axis=1) # recordings for generated results img_blocks = [] for b in range(pred.shape[0]): pred_json = self.convert_json(pred[b], cond, 0) # visualize bbox and axis pred_meshes = prepare_meshes(pred_json) bbox_0, bbox_1, axiss = ( pred_meshes["bbox_0"], pred_meshes["bbox_1"], pred_meshes["axiss"], ) pred_img = draw_boxes_axiss_anim( bbox_0, bbox_1, axiss, mode="graph", resolution=res ) img_blocks.append(pred_img) self.save_json(f"{save_dir}/{b}/object.json", pred_json) # save images for generated results img_grid = make_grid(img_blocks, cols=5) # visualize the input graph # input_graph = viz_graph(pred_json, res=256) # save images # self.save_rgb_image(f"{save_dir}/gt_graph.png", gt_graph) self.save_rgb_image(f"{save_dir}/output.png", img_grid) self.save_rgb_image(f"{save_dir}/gt.png", gt_block) # self.save_rgb_image(f"{save_dir}/input_graph.png", input_graph) def _save_html_end(self): exp_name = self._get_exp_name() save_dir = self.get_save_path(exp_name) cases = sorted(os.listdir(save_dir), key=lambda x: int(x.split("@")[0])) html_head = """ Test Image Results """ total = len(cases) each = 200 n_pages = total // each + 1 for p in range(n_pages): html_content = html_head for i in range(p * each, min((p + 1) * each, total)): case = cases[i] if self.hparams.get("test_no_GT", False): aid_iou = rid_iou = aid_cdist = rid_cdist = aid_cd = rid_cd = aor = "N/A" else: with open(os.path.join(save_dir, case, "metrics.json"), "r") as f: metrics = json.load(f)["avg"] aid_iou = round(metrics["AS-IoU"], 4) rid_iou = round(metrics["RS-IoU"], 4) aid_cdist = round(metrics["AS-cDist"], 4) rid_cdist = round(metrics["RS-cDist"], 4) aid_cd = round(metrics["AS-CD"], 4) rid_cd = round(metrics["RS-CD"], 4) aor = metrics["AOR"] if aor is not None: aor = round(aor, 4) html_content += f""" """ html_content += """
Object ID Metrics (avg) Input image + GT object + GT graph Input graph
{case} [AS-cDist] {aid_cdist}
[RS-cDist] {rid_cdist}
-----------------------
[AS-IoU] {aid_iou}
[RS-IoU] {rid_iou}
-----------------------
[RS-CD] {rid_cd}
[AS-CD] {aid_cd}
-----------------------
[AOR] {aor}
GT Image Graph Image Graph Image
Generated samples
Generated Image
""" outfile = self.get_save_path(f"{exp_name}_page_{p+1}.html") with open(outfile, "w") as file: file.write(html_content) def val_compute_metrics(self, pred, gt, cond): loss_dict = {} B = pred.shape[0] as_ious = 0.0 rs_ious = 0.0 as_cdists = 0.0 rs_cdists = 0.0 for b in range(B): gt_json = self.convert_json(gt[b], cond, b) pred_json = self.convert_json(pred[b], cond, b) scores = IoU_cDist( pred_json, gt_json, num_states=5, compare_handles=True, iou_include_base=True, ) as_ious += scores['AS-IoU'] rs_ious += scores['RS-IoU'] as_cdists += scores['AS-cDist'] rs_cdists += scores['RS-cDist'] as_ious /= B rs_ious /= B as_cdists /= B rs_cdists /= B loss_dict['val/AS-IoU'] = as_ious loss_dict['val/RS-IoU'] = rs_ious loss_dict['val/AS-cDist'] = as_cdists loss_dict['val/RS-cDist'] = rs_cdists return loss_dict def _get_exp_name(self): which_ds = self.hparams.get("test_which", 'pm') is_pred_G = self.hparams.get("test_pred_G", False) is_label_free = self.hparams.get("test_label_free", False) guidance_scaler = self.hparams.get("guidance_scaler", 0) # config saving directory exp_postfix = f"_w={guidance_scaler}_{which_ds}" if is_pred_G: exp_postfix += "_pred_G" if is_label_free: exp_postfix += "_label_free" exp_name = "epoch_" + str(self.current_epoch).zfill(3) + exp_postfix return exp_name