DIPO / systems /base.py
xinjjj's picture
Upload 29 files
ce34030 verified
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import json
import math
import numpy as np
import lightning.pytorch as pl
from metrics.iou_cdist import IoU_cDist
from my_utils.savermixins import SaverMixin
from my_utils.refs import sem_ref, joint_ref
from dataset.utils import convert_data_range, parse_tree
from my_utils.plot import viz_graph, make_grid, add_text
from my_utils.render import draw_boxes_axiss_anim, prepare_meshes
from PIL import Image
class BaseSystem(pl.LightningModule, SaverMixin):
def __init__(self, hparams):
super().__init__()
self.hparams.update(hparams)
def setup(self, stage: str):
# config the logger dir for images
self.hparams.save_dir = os.path.join(self.hparams.exp_dir, 'output', stage)
os.makedirs(self.hparams.save_dir, exist_ok=True)
# --------------------------------- visualization ---------------------------------
def convert_json(self, x, c, idx, prefix=''):
out = {"meta": {}, "diffuse_tree": []}
n_nodes = c[f"{prefix}n_nodes"][idx].item()
par = c[f"{prefix}parents"][idx].cpu().numpy().tolist()
adj = c[f"{prefix}adj"][idx].cpu().numpy()
np.fill_diagonal(adj, 0) # remove self-loop for the root node
if f"{prefix}obj_cat" in c:
out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][idx]
# convert the data to original range
data = convert_data_range(x.cpu().numpy())
# parse the tree
out["diffuse_tree"] = parse_tree(data, n_nodes, par, adj)
return out
# def save_val_img(self, pred, gt, cond):
# B = pred.shape[0]
# pred_imgs, gt_imgs, gt_graphs_view = [], [], []
# for b in range(B):
# print(b)
# # convert to humnan readable format json
# pred_json = self.convert_json(pred[b], cond, b)
# gt_json = self.convert_json(gt[b], cond, b)
# # visualize bbox and axis
# pred_meshes = prepare_meshes(pred_json)
# bbox_0, bbox_1, axiss = (
# pred_meshes["bbox_0"],
# pred_meshes["bbox_1"],
# pred_meshes["axiss"],
# )
# pred_img = draw_boxes_axiss_anim(
# bbox_0, bbox_1, axiss, mode="graph", resolution=128
# )
# gt_meshes = prepare_meshes(gt_json)
# bbox_0, bbox_1, axiss = (
# gt_meshes["bbox_0"],
# gt_meshes["bbox_1"],
# gt_meshes["axiss"],
# )
# gt_img = draw_boxes_axiss_anim(
# bbox_0, bbox_1, axiss, mode="graph", resolution=128
# )
# # visualize graph
# # gt_graph = viz_graph(gt_json, res=128)
# # gt_graph = add_text(cond["name"][b], gt_graph)
# # GT views
# rgb_view = cond["img"][b].cpu().numpy()
# pred_imgs.append(pred_img)
# gt_imgs.append(gt_img)
# gt_graphs_view.append(rgb_view)
# # gt_graphs_view.append(gt_graph)
# # save images for generated results
# epoch = str(self.current_epoch).zfill(5)
# # pred_thumbnails = np.concatenate(pred_imgs, axis=1) # concat batch in width
# import ipdb
# ipdb.set_trace()
# # save images for ground truth
# for i in range(math.ceil(len(gt_graphs_view) / 8)):
# start = i * 8
# end = min((i + 1) * 8, len(gt_graphs_view))
# pred_thumbnails = np.concatenate(pred_imgs[start:end], axis=1)
# gt_graph_imgs = np.concatenate(gt_graphs_view[start:end], axis=1)
# gt_thumbnails = np.concatenate(gt_imgs[start:end], axis=1) # concat batch in width
# grid = np.concatenate([gt_graph_imgs, gt_thumbnails, pred_thumbnails], axis=0)
# self.save_rgb_image(f"new_out_valid_{i}.png", grid)
def save_test_step(self, pred, gt, cond, batch_idx, res=128):
exp_name = self._get_exp_name()
model_name = cond["name"][0].replace("/", '@')
save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
# input image
input_img = cond["img"][0].cpu().numpy()
# GT recordings
if not self.hparams.get('test_no_GT', False):
gt_json = self.convert_json(gt[0], cond, 0)
# gt_graph = viz_graph(gt_json, res=256)
gt_meshes = prepare_meshes(gt_json)
bbox_0, bbox_1, axiss = (
gt_meshes["bbox_0"],
gt_meshes["bbox_1"],
gt_meshes["axiss"],
)
gt_img = draw_boxes_axiss_anim(bbox_0, bbox_1, axiss, mode="graph", resolution=res)
else:
# gt_graph = 255 * np.ones((res, res, 3), dtype=np.uint8)
gt_img = 255 * np.ones((res, 2 * res, 3), dtype=np.uint8)
gt_block = np.concatenate([input_img, gt_img], axis=1)
# recordings for generated results
img_blocks = []
for b in range(pred.shape[0]):
pred_json = self.convert_json(pred[b], cond, 0)
# visualize bbox and axis
pred_meshes = prepare_meshes(pred_json)
bbox_0, bbox_1, axiss = (
pred_meshes["bbox_0"],
pred_meshes["bbox_1"],
pred_meshes["axiss"],
)
pred_img = draw_boxes_axiss_anim(
bbox_0, bbox_1, axiss, mode="graph", resolution=res
)
img_blocks.append(pred_img)
self.save_json(f"{save_dir}/{b}/object.json", pred_json)
# save images for generated results
img_grid = make_grid(img_blocks, cols=5)
# visualize the input graph
# input_graph = viz_graph(pred_json, res=256)
# save images
# self.save_rgb_image(f"{save_dir}/gt_graph.png", gt_graph)
self.save_rgb_image(f"{save_dir}/output.png", img_grid)
self.save_rgb_image(f"{save_dir}/gt.png", gt_block)
# self.save_rgb_image(f"{save_dir}/input_graph.png", input_graph)
def _save_html_end(self):
exp_name = self._get_exp_name()
save_dir = self.get_save_path(exp_name)
cases = sorted(os.listdir(save_dir), key=lambda x: int(x.split("@")[0]))
html_head = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Test Image Results</title>
<style>
table {
width: 100%;
border-collapse: collapse;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: left;
}
.separator {
border-top: 2px solid black;
}
</style>
</head>
<body>
<table>
"""
total = len(cases)
each = 200
n_pages = total // each + 1
for p in range(n_pages):
html_content = html_head
for i in range(p * each, min((p + 1) * each, total)):
case = cases[i]
if self.hparams.get("test_no_GT", False):
aid_iou = rid_iou = aid_cdist = rid_cdist = aid_cd = rid_cd = aor = "N/A"
else:
with open(os.path.join(save_dir, case, "metrics.json"), "r") as f:
metrics = json.load(f)["avg"]
aid_iou = round(metrics["AS-IoU"], 4)
rid_iou = round(metrics["RS-IoU"], 4)
aid_cdist = round(metrics["AS-cDist"], 4)
rid_cdist = round(metrics["RS-cDist"], 4)
aid_cd = round(metrics["AS-CD"], 4)
rid_cd = round(metrics["RS-CD"], 4)
aor = metrics["AOR"]
if aor is not None:
aor = round(aor, 4)
html_content += f"""
<tr>
<th>Object ID</th>
<th>Metrics (avg) </th>
<th>Input image + GT object + GT graph</th>
<th>Input graph </th>
</tr>
<tr>
<td rowspan="3">{case}</td>
<td>
[AS-cDist] {aid_cdist}<br>
[RS-cDist] {rid_cdist}<br>
-----------------------<br>
[AS-IoU] {aid_iou}<br>
[RS-IoU] {rid_iou}<br>
-----------------------<br>
[RS-CD] {rid_cd}<br>
[AS-CD] {aid_cd}<br>
-----------------------<br>
[AOR] {aor}<br>
</td>
<td>
<img src="{exp_name}/{case}/gt.png" alt="GT Image" style="height: 128px; width: 3*128px;">
<img src="{exp_name}/{case}/gt_graph.png" alt="Graph Image" style="height: 128px; width: 3*128px;">
</td>
<td>
<img src="{exp_name}/{case}/input_graph.png" alt="Graph Image" style="height: 128px; width: 3*128px;">
</td>
</tr>
<tr><th colspan="3">Generated samples</th></tr>
<tr>
<td colspan="3"><img src="{exp_name}/{case}/output.png" alt="Generated Image" style="height: 3*128px; width: 10*128px;"></td>
</tr>
<tr class="separator"><td colspan="4"></td></tr>
"""
html_content += """</table></body></html>"""
outfile = self.get_save_path(f"{exp_name}_page_{p+1}.html")
with open(outfile, "w") as file:
file.write(html_content)
def val_compute_metrics(self, pred, gt, cond):
loss_dict = {}
B = pred.shape[0]
as_ious = 0.0
rs_ious = 0.0
as_cdists = 0.0
rs_cdists = 0.0
for b in range(B):
gt_json = self.convert_json(gt[b], cond, b)
pred_json = self.convert_json(pred[b], cond, b)
scores = IoU_cDist(
pred_json,
gt_json,
num_states=5,
compare_handles=True,
iou_include_base=True,
)
as_ious += scores['AS-IoU']
rs_ious += scores['RS-IoU']
as_cdists += scores['AS-cDist']
rs_cdists += scores['RS-cDist']
as_ious /= B
rs_ious /= B
as_cdists /= B
rs_cdists /= B
loss_dict['val/AS-IoU'] = as_ious
loss_dict['val/RS-IoU'] = rs_ious
loss_dict['val/AS-cDist'] = as_cdists
loss_dict['val/RS-cDist'] = rs_cdists
return loss_dict
def _get_exp_name(self):
which_ds = self.hparams.get("test_which", 'pm')
is_pred_G = self.hparams.get("test_pred_G", False)
is_label_free = self.hparams.get("test_label_free", False)
guidance_scaler = self.hparams.get("guidance_scaler", 0)
# config saving directory
exp_postfix = f"_w={guidance_scaler}_{which_ds}"
if is_pred_G:
exp_postfix += "_pred_G"
if is_label_free:
exp_postfix += "_label_free"
exp_name = "epoch_" + str(self.current_epoch).zfill(3) + exp_postfix
return exp_name