# DIPO/systems/system_origin.py
# (repository artifact: uploaded by xinjjj, "Upload 29 files", commit ce34030 verified)
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import subprocess
import numpy as np
import models
import systems
import torch.nn.functional as F
from diffusers import DDPMScheduler
from systems.base import BaseSystem
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
from datetime import datetime
import logging
@systems.register("sys_origin")
class SingapoSystem(BaseSystem):
    """Trainer for the B9 model, incorporating the classifier-free guidance for the image condition.

    Diffusion training on part-node tokens, conditioned on the object category,
    DINO image features, and a part-connectivity graph. Manual optimization is
    used so that the pretrained CAGE backbone and the newly added adapter
    modules can be trained with two separate optimizers / LR schedules.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        # the denoising network
        self.model = models.make(hparams.model.name, hparams.model)
        # configure the scheduler of DDPM
        self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
        # load the dummy DINO features (stand-in for the "image-free" condition in CFG)
        self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
        # use the manual optimization (two optimizers, stepped by hand in training_step)
        self.automatic_optimization = False
        # save the hyperparameters
        self.save_hyperparameters()
        # rank-0-only console logger for periodic loss reporting
        self.custom_logger = logging.getLogger(__name__)
        self.custom_logger.setLevel(logging.INFO)
        if self.global_rank == 0:
            self.custom_logger.addHandler(logging.StreamHandler())

    def load_cage_weights(self, pretrained_ckpt=None):
        """Load pretrained CAGE weights into the model (non-strict, so new modules keep their init).

        Args:
            pretrained_ckpt: path to a Lightning checkpoint of the pretrained CAGE model.
        """
        # map to CPU so loading works regardless of the checkpoint's original device
        ckpt = torch.load(pretrained_ckpt, map_location="cpu")
        state_dict = ckpt["state_dict"]
        # remove the leading Lightning "model." prefix from the keys
        # (removeprefix, unlike str.replace, does not mangle keys that happen to
        #  contain "model." somewhere in the middle)
        state_dict = {k.removeprefix("model."): v for k, v in state_dict.items()}
        # load the weights; strict=False because our new modules are absent from the checkpoint
        self.model.load_state_dict(state_dict, strict=False)
        print("[INFO] loaded model weights of the pretrained CAGE.")

    def fg_loss(self, all_attn_maps, loss_masks):
        """
        Excite the attention maps within the object regions, while weakening the attention outside the object regions.
        Args:
            all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
            loss_masks: dict with 'fg' — object seg mask on the image patches, shape (B, 160, 256) —
                and 'valid_nodes' — bool mask over the node tokens (same shape as x)
        Returns:
            loss: scalar loss on the attention maps
        """
        valid_mask = loss_masks['valid_nodes']
        fg_mask = loss_masks['fg']
        # number of transformer layers and attention heads
        L = self.hparams.model.n_layers
        H = all_attn_maps.shape[1]
        # Reshape all the masks to the shape of the attention maps (B*L, H, 160, 256)
        valid_node = valid_mask[:, :, 0].unsqueeze(1).expand(-1, H, -1).unsqueeze(-1).expand(-1, -1, -1, 256).repeat(L, 1, 1, 1)
        obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
        # construct masks for the object and non-object regions (restricted to valid nodes)
        fg_region = torch.logical_and(valid_node, obj_region)
        bg_region = torch.logical_and(valid_node, ~obj_region)
        # push foreground attention towards 1 and background attention towards 0
        loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
        return loss

    def diffuse_process(self, inputs):
        """Forward diffusion: sample noise and timesteps, write the noised input into `inputs` in place."""
        x = inputs["x"]
        # Sample Gaussian noise
        noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
        # Sample a random timestep for each sample in the batch
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=self.device,
            dtype=torch.long,
        )
        # Add Gaussian noise to the input according to the sampled timestep
        noisy_x = self.scheduler.add_noise(x, noise, timesteps)
        # update the inputs dict in place
        inputs["noise"] = noise
        inputs["timesteps"] = timesteps
        inputs["noisy_x"] = noisy_x

    def prepare_inputs(self, batch, mode='train', n_samples=1):
        """Assemble the model-input dict from a raw batch.

        Args:
            batch: (x, c, f) — target tokens, condition dict, DINO image features.
            mode: 'train' (tile by hparams.n_time_samples for multi-timestep sampling),
                  'val', or 'test' (tile by n_samples to draw multiple outputs).
            n_samples: number of samples per object in 'test' mode.
        Returns:
            dict with x / f / cat / attention masks, plus 'noisy_x' and 'dummy_f'
            for 'val' and 'test' modes.
        Raises:
            ValueError: on an unknown mode.
        """
        x, c, f = batch
        cat = c["cat"]  # object category
        attr_mask = c["attr_mask"]  # attention mask for local self-attention (follow the CAGE)
        key_pad_mask = c["key_pad_mask"]  # key padding mask for global self-attention (follow the CAGE)
        graph_mask = c["adj_mask"]  # attention mask for graph relation self-attention (follow the CAGE)
        inputs = {}
        if mode == 'train':
            # the number of sampled timesteps per iteration
            n_repeat = self.hparams.n_time_samples
            # tile the batch so each object is diffused at n_repeat different timesteps
            x = x.repeat(n_repeat, 1, 1)
            cat = cat.repeat(n_repeat)
            f = f.repeat(n_repeat, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
            graph_mask = graph_mask.repeat(n_repeat, 1, 1)
            attr_mask = attr_mask.repeat(n_repeat, 1, 1)
        elif mode == 'val':
            # start denoising from pure Gaussian noise
            noisy_x = torch.randn(x.shape, device=x.device)
            # dummy DINO features for the condition-free branch of CFG, expanded to f's shape
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f
        elif mode == 'test':
            # tile the batch to sample multiple outputs per object
            x = x.repeat(n_samples, 1, 1)
            cat = cat.repeat(n_samples)
            f = f.repeat(n_samples, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
            graph_mask = graph_mask.repeat(n_samples, 1, 1)
            attr_mask = attr_mask.repeat(n_samples, 1, 1)
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            # BUGFIX: store dummy_f as-is — it is already expanded to f's shape above.
            # The original applied a second .repeat(1, 2, 1) here, doubling dim 1 so
            # it no longer matched inputs['f'] (the 'val' branch stores it directly).
            inputs["dummy_f"] = dummy_f
        else:
            raise ValueError(f"Invalid mode: {mode}")
        inputs["x"] = x
        inputs["f"] = f
        inputs["cat"] = cat
        inputs["key_pad_mask"] = key_pad_mask
        inputs["graph_mask"] = graph_mask
        inputs["attr_mask"] = attr_mask
        return inputs

    def prepare_loss_mask(self, batch):
        """Build the masks used by the loss; only valid (non-padded) node tokens contribute.

        Returns:
            dict with 'fg' (foreground patch mask, currently disabled → None) and
            'valid_nodes' (bool mask with the same shape as the tiled x).
        """
        x, c, _ = batch
        n_repeat = self.hparams.n_time_samples  # the number of sampled timesteps per iteration
        # token index along the node dimension, shape (1, N)
        index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)
        # each node occupies 5 tokens, so tokens with index < n_nodes * 5 are valid
        valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)
        mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
        # tile to match the n_repeat-fold batch built in prepare_inputs
        mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)
        return {"fg": None, "valid_nodes": mask_valid_nodes}

    def manage_cfg(self, inputs):
        '''
        Manage the classifier-free training for the image and graph condition.
        The CFG for object category is managed by the model (i.e. the CombinedTimestepLabelEmbeddings
        module in norm1 for each attention block). Mutates `inputs` in place;
        the image and graph conditions are never dropped in the same iteration.
        '''
        img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
        graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
        drop_img, drop_graph = False, False
        if img_drop_prob > 0.0:
            drop_img = torch.rand(1) < img_drop_prob
            if drop_img.item():
                # replace the image condition with the dummy DINO features
                dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
                inputs['f'] = dummy_batch
        if graph_drop_prob > 0.0:
            if not drop_img:  # only consider dropping the graph when the image is kept
                drop_graph = torch.rand(1) < graph_drop_prob
                if drop_graph.item():
                    inputs['graph_mask'] = None  # to verify the model only; replace with the line below later and retrain
                    # inputs['graph_mask'] = inputs['key_pad_mask']  # use the key padding mask

    def compute_loss(self, batch, inputs, outputs):
        """MSE between predicted and true noise, restricted to the valid node tokens.

        Returns:
            (total_loss, dict of named loss terms for logging)
        """
        loss_dict = {}
        # prepare the loss masks
        loss_masks = self.prepare_loss_mask(batch)
        # diffusion model loss: MSE on the residual noise (padded tokens are zeroed on both sides)
        loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
        # total loss (the attention-map foreground loss fg_loss is currently disabled)
        loss = loss_mse
        # log the losses
        loss_dict["train/loss_mse"] = loss_mse
        loss_dict["train/loss_total"] = loss
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """One manual-optimization step: diffuse, denoise, then backprop through both optimizers."""
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='train')
        # manage the classifier-free training
        self.manage_cfg(inputs)
        # forward: diffusion process
        self.diffuse_process(inputs)
        # reverse: denoising process
        outputs = self.model(
            x=inputs['noisy_x'],
            cat=inputs['cat'],
            timesteps=inputs['timesteps'],
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
        )
        # compute the loss
        loss, loss_dict = self.compute_loss(batch, inputs, outputs)
        # manual backward: one shared backward pass, then step both optimizers
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()
        # periodic console logging on rank 0 only
        if batch_idx % 20 == 0 and self.global_rank == 0:
            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
            for key, value in loss_dict.items():
                loss_str += f"{key}: {value.item():.4f} | "
            self.custom_logger.info(now_str + ' | ' + loss_str)

    def on_train_epoch_end(self):
        """Step both LR schedulers once per epoch (manual optimization does not do this for us)."""
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def inference(self, inputs, is_label_free=False):
        """Run the reverse-diffusion (denoising) loop with optional classifier-free guidance.

        Args:
            inputs: dict from prepare_inputs; must contain 'noisy_x', and 'dummy_f'
                when the guidance scale is non-zero.
            is_label_free: drop the category-label condition inside the model.
        Returns:
            the denoised sample (same shape as inputs['noisy_x']).
        """
        device = inputs['x'].device
        # guidance scale omega; 0 disables the condition-free forward pass entirely
        omega = self.hparams.get("guidance_scaler", 0)
        noisy_x = inputs['noisy_x']
        # set scheduler to denoise in 100 steps
        self.scheduler.set_timesteps(100)
        # denoising process
        for t in self.scheduler.timesteps:
            timesteps = torch.tensor([t], device=device)
            outputs_cond = self.model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=is_label_free,
            )  # take the conditional image features as input
            if omega != 0:
                outputs_free = self.model(
                    x=noisy_x,
                    cat=inputs['cat'],
                    timesteps=timesteps,
                    feat=inputs['dummy_f'],
                    key_pad_mask=inputs['key_pad_mask'],
                    graph_mask=inputs['graph_mask'],
                    attr_mask=inputs['attr_mask'],
                    label_free=is_label_free,
                )  # take the dummy DINO features for the condition-free mode
                # classifier-free guidance: extrapolate away from the unconditional prediction
                noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
            else:
                noise_pred = outputs_cond['noise_pred']
            # one reverse-diffusion step
            noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample
        return noisy_x

    def validation_step(self, batch, batch_idx):
        """Denoise from pure noise and log validation metrics against the GT."""
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='val')
        # denoising process for inference
        out = self.inference(inputs)
        # compute and log the metrics
        # (dead commented-out experiments on GT node matching / re-centering removed)
        log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
        self.log_dict(log_dict, on_step=True)

    def test_step(self, batch, batch_idx):
        """Sample multiple outputs per object and save them; skips objects already rendered."""
        is_label_free = self.hparams.get("test_label_free", False)
        exp_name = self._get_exp_name()
        model_name = batch[1]["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
        print(save_dir)
        # resume support: skip objects whose output image already exists
        if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
            return
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
        # denoising process for inference
        out = self.inference(inputs, is_label_free)
        # save the results
        self.save_test_step(out, inputs['x'], batch[1], batch_idx)

    def on_test_end(self):
        # Post-processing is intentionally disabled here; run it manually on the
        # saved outputs (single GPU / rank 0 only):
        #   - part retrieval: scripts/mesh_retrieval/run_retrieve.py --src <exp_dir> --json_name object.json --gt_data_root ../singapo
        #   - metrics (when GT is available): scripts/eval_metrics.py --exp_dir <exp_dir> --gt_root ../acd_data/
        #   - html report: self._save_html_end()
        pass

    def configure_optimizers(self):
        """Two optimizer/scheduler pairs: one for the new adapter modules, one for the CAGE backbone.

        BUGFIX: the original `self.cage_params = self.adapter_params = []` bound both
        names to the SAME list object, so every parameter ended up in both optimizers.
        """
        self.cage_params = []
        self.adapter_params = []
        for name, param in self.model.named_parameters():
            # adapter parameters: image-conditioning modules and the extra norm layers
            if "img" in name or "norm5" in name or "norm6" in name:
                self.adapter_params.append(param)
            else:
                self.cage_params.append(param)
        optimizer_adapter = torch.optim.AdamW(
            self.adapter_params, **self.hparams.optimizer_adapter.args
        )
        lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
            optimizer_adapter,
            warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_adapter.eta_min,
        )
        optimizer_cage = torch.optim.AdamW(
            self.cage_params, **self.hparams.optimizer_cage.args
        )
        lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
            optimizer_cage,
            warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_cage.eta_min,
        )
        return (
            {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
            {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
        )