"""
This script computes the Chamfer Distance (CD) between two objects.
"""
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import numpy as np
import trimesh
from copy import deepcopy
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
from objects.motions import transform_all_parts
from objects.dict_utils import (
zero_center_object,
rescale_object,
compute_overall_bbox_size,
get_base_part_idx,
find_part_mapping
)
def _load_and_combine_plys(dir, ply_files, scale=None, z_rotate=None, translate=None):
"""
Load and combine the ply files into one PyTorch3D mesh
- dir: the directory of the object in which the ply files are from\n
- ply_files: the list of ply files\n
- scale: the scale factor to apply to the vertices\n
- z_rotate: whether to rotate the object around the z-axis by 90 degrees\n
- translate: the translation to apply to the vertices\n
Return:\n
- mesh: one PyTorch3D mesh of the combined ply files
"""
# Combine the ply files into one
meshes = []
for ply_file in ply_files:
meshes.append(trimesh.load(os.path.join(dir, ply_file), force="mesh"))
full_part_mesh = trimesh.util.concatenate(meshes)
    # Center the combined part at the origin, then apply the scale / z-rotation / translation
    full_part_mesh.vertices -= full_part_mesh.bounding_box.centroid
transformation = trimesh.transformations.compose_matrix(
scale=scale,
angles=[0, 0, np.radians(90) if z_rotate else 0],
translate=translate,
)
full_part_mesh.apply_transform(transformation)
# Create the PyTorch3D mesh
    mesh = Meshes(
        verts=torch.as_tensor(full_part_mesh.vertices, dtype=torch.float32, device='cuda').unsqueeze(0),
        faces=torch.as_tensor(full_part_mesh.faces, dtype=torch.int64, device='cuda').unsqueeze(0),
    )
return mesh
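
# Illustrative usage sketch (the directory and file names below are hypothetical): compose_matrix
# applies the transformations in the order scale -> z-rotation -> translation, after the part has
# been centered at the origin, e.g.
#   part_mesh = _load_and_combine_plys(
#       "dataset/cabinet_0001", ["door.ply", "handle.ply"],
#       scale=[1.0, 1.0, 1.0], z_rotate=False, translate=[0.0, 0.5, 0.0],
#   )
#   points = sample_points_from_meshes(part_mesh, num_samples=2048)  # (1, 2048, 3) tensor on the GPU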
def _compute_chamfer_distance(
obj1_part_points, obj2_part_points, part_mapping=None, exclude_id=-1
):
"""
Compute the chamfer distance between the two set of points representing the two objects
- obj1_part_points: the set of points representing the first object\n
- obj2_part_points: the set of points representing the second object\n
- part_mapping (optional): the part mapping from the first object to the second object, if provided, the chamfer distance will be computed between the corresponding parts\n
- exclude_id (optional): the part id to exclude from the chamfer distance computation, the default if provided is the base part id\n
Return:\n
- distance: the chamfer distance between the two objects
"""
if part_mapping is not None:
n_parts = part_mapping.shape[0]
distance = 0
for i in range(n_parts):
if i == exclude_id:
continue
obj1_part_points_i = obj1_part_points[i]
obj2_part_points_i = obj2_part_points[int(part_mapping[i, 0])]
with torch.no_grad():
obj1_part_points_i = obj1_part_points_i.cuda()
obj2_part_points_i = obj2_part_points_i.cuda()
# symmetric chamfer distance
forward_distance, _ = chamfer_distance(
obj1_part_points_i[None, :],
obj2_part_points_i[None, :],
batch_reduction=None,
)
backward_distance, _ = chamfer_distance(
obj2_part_points_i[None, :],
obj1_part_points_i[None, :],
batch_reduction=None,
)
distance += (forward_distance.item() + backward_distance.item()) * 0.5
distance /= n_parts
else:
# Merge the points of all parts into one tensor
obj1_part_points = obj1_part_points.reshape(-1, 3)
obj2_part_points = obj2_part_points.reshape(-1, 3)
# Compute the chamfer distance between the two objects
with torch.no_grad():
obj1_part_points = obj1_part_points.cuda()
obj2_part_points = obj2_part_points.cuda()
forward_distance, _ = chamfer_distance(
obj1_part_points[None, :],
obj2_part_points[None, :],
batch_reduction=None,
)
backward_distance, _ = chamfer_distance(
obj2_part_points[None, :],
obj1_part_points[None, :],
batch_reduction=None,
)
distance = (forward_distance.item() + backward_distance.item()) * 0.5
return distance
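
# Quick illustration (not part of the metric): for two nearly identical sets of per-part points the
# symmetric chamfer distance should be close to zero.
#   pts_a = torch.rand(3, 2048, 3)                            # 3 parts, 2048 points each
#   pts_b = pts_a + 1e-3 * torch.randn_like(pts_a)
#   mapping = np.stack([np.arange(3), np.zeros(3)], axis=1)   # column 0 = matched part index in pts_b
#   _compute_chamfer_distance(pts_a, pts_b, part_mapping=mapping)  # per-part average, close to 0
#   _compute_chamfer_distance(pts_a, pts_b)                        # parts merged into one cloud, close to 0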
def _get_scores(
src_dict,
tgt_dict,
original_src_part_points,
original_tgt_part_points,
part_mapping,
num_states,
include_base,
src_base_idx,
):
    """
    Sweep the joint state from 0 (resting state) to 1 and compute the chamfer distance between the
    two objects at each state, using the given part mapping.

    Return a dict with "AS-CD" (the mean over the sampled states) and "RS-CD" (the resting state only).
    """
    chamfer_distances = np.zeros(num_states, dtype=np.float32)
joint_states = np.linspace(0, 1, num_states)
for state_idx, state in enumerate(joint_states):
# Reset the part point clouds
src_part_points = deepcopy(original_src_part_points)
tgt_part_points = deepcopy(original_tgt_part_points)
        # Transform the part point clouds to the current joint state; transform_all_parts modifies
        # the numpy views of the tensors in place
        transform_all_parts(src_part_points.numpy(), src_dict, state, dry_run=False)
        transform_all_parts(tgt_part_points.numpy(), tgt_dict, state, dry_run=False)
# Compute the chamfer distance between the two objects
chamfer_distances[state_idx] = _compute_chamfer_distance(
src_part_points,
tgt_part_points,
part_mapping=part_mapping,
exclude_id=-1 if include_base else src_base_idx,
)
    # Aggregate the per-state distances: AS-CD averages over all sampled articulation states,
    # RS-CD is the distance at the resting state
    as_cd = np.mean(chamfer_distances)
    rs_cd = chamfer_distances[0]
    return {
        "AS-CD": float(as_cd),
        "RS-CD": float(rs_cd),
    }
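
# For example, with num_states=5 the sampled joint states are [0.0, 0.25, 0.5, 0.75, 1.0]:
# "RS-CD" is the chamfer distance at state 0.0 and "AS-CD" is the mean over all five states.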
def CD(
gen_obj_dict,
gen_obj_path,
gt_obj_dict,
gt_obj_path,
num_states=5,
num_samples=2048,
include_base=False,
):
"""
Compute the Chamfer Distance\n
This metric is the average of per-part chamfer distance between the two objects over a number of articulation states\n
- gen_obj_dict: the generated object dictionary\n
- gen_obj_path: the directory to the predicted object\n
- gt_obj_dict: the ground truth object dictionary\n
- gt_obj_path: the directory to the ground truth object\n
- num_states (optional): the number of articulation states to compute the metric\n
- num_samples (optional): the number of samples to use\n
- include_base (optional): whether to include the base part in the chamfer distance computation\n
Return:\n
- aid_score: the score over the sampled articulated states\n
- rid_score: the score at the resting state\n
- The score is in the range of [0, inf), lower is better
"""
# Make copies of the dictionaries to avoid modifying the original dictionaries
gen_dict = deepcopy(gen_obj_dict)
gt_dict = deepcopy(gt_obj_dict)
# Zero center the objects
zero_center_object(gen_dict)
zero_center_object(gt_dict)
    # Compute the scale factor by comparing the overall bbox sizes and rescale the generated (candidate)
    # object as a whole to match the size of the ground-truth object; the copy is rescaled, not the input
    gen_bbox_size = compute_overall_bbox_size(gen_dict)
    gt_bbox_size = compute_overall_bbox_size(gt_dict)
    scale_factor = gt_bbox_size / gen_bbox_size
    rescale_object(gen_dict, scale_factor)
# Record the indices of the base parts of the two objects
gen_base_idx = get_base_part_idx(gen_dict)
gt_base_idx = get_base_part_idx(gt_dict)
# Find mapping between the parts of the two objects based on closest bbox centers
mapping_gen2gt = find_part_mapping(gen_dict, gt_dict, use_hungarian=True)
mapping_gt2gen = find_part_mapping(gt_dict, gen_dict, use_hungarian=True)
# Get the number of parts of the two objects
gen_tree = gen_dict["diffuse_tree"]
gt_tree = gt_dict["diffuse_tree"]
gen_num_parts = len(gen_tree)
gt_num_parts = len(gt_tree)
# Get the paths of the ply files of the two objects
gen_part_ply_paths = [
{"dir": gen_obj_path, "files": gen_tree[i]["plys"]}
for i in range(gen_num_parts)
]
gt_part_ply_paths = [
{"dir": gt_obj_path, "files": gt_tree[i]["plys"]}
for i in range(gt_num_parts)
]
# Load the ply files of the two objects and sample points from them
gen_part_points = torch.zeros(
(gen_num_parts, num_samples, 3), dtype=torch.float32
)
for i in range(gen_num_parts):
part_mesh = _load_and_combine_plys(
gen_part_ply_paths[i]["dir"],
gen_part_ply_paths[i]["files"],
scale=scale_factor,
translate=gen_tree[i]["aabb"]["center"],
)
gen_part_points[i] = sample_points_from_meshes(
part_mesh, num_samples=num_samples
).squeeze(0).cpu()
gt_part_points = torch.zeros(
(gt_num_parts, num_samples, 3), dtype=torch.float32
)
for i in range(gt_num_parts):
part_mesh = _load_and_combine_plys(
gt_part_ply_paths[i]["dir"],
gt_part_ply_paths[i]["files"],
translate=gt_tree[i]["aabb"]["center"],
)
gt_part_points[i] = sample_points_from_meshes(
part_mesh, num_samples=num_samples
).squeeze(0).cpu()
cd_gen2gt = _get_scores(
gen_dict,
gt_dict,
gen_part_points,
gt_part_points,
mapping_gen2gt,
num_states,
include_base,
gen_base_idx,
)
cd_gt2gen = _get_scores(
gt_dict,
gen_dict,
gt_part_points,
gen_part_points,
mapping_gt2gen,
num_states,
include_base,
gt_base_idx,
)
return {
"AS-CD": (cd_gen2gt["AS-CD"] + cd_gt2gen["AS-CD"]) / 2,
"RS-CD": (cd_gen2gt["RS-CD"] + cd_gt2gen["RS-CD"]) / 2,
}
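

# Minimal usage sketch, assuming each object directory contains an "object.json" file holding the
# object dictionary (the paths and file name below are placeholders; adapt them to the actual data
# layout of the repository).
if __name__ == "__main__":
    import json

    gen_dir = sys.argv[1] if len(sys.argv) > 1 else "results/gen_object"  # hypothetical path
    gt_dir = sys.argv[2] if len(sys.argv) > 2 else "data/gt_object"       # hypothetical path

    with open(os.path.join(gen_dir, "object.json"), "r") as f:
        gen_obj = json.load(f)
    with open(os.path.join(gt_dir, "object.json"), "r") as f:
        gt_obj = json.load(f)

    scores = CD(gen_obj, gen_dir, gt_obj, gt_dir, num_states=5, num_samples=2048)
    print(f"AS-CD: {scores['AS-CD']:.6f}, RS-CD: {scores['RS-CD']:.6f}")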