import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import subprocess
import numpy as np
import models
import systems
import torch.nn.functional as F
from diffusers import DDPMScheduler
from systems.base import BaseSystem
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
from datetime import datetime
import logging


class SingapoSystem(BaseSystem):
    """Trainer for the B9 model, incorporating classifier-free guidance for the image condition."""

    def __init__(self, hparams):
        super().__init__(hparams)
        self.model = models.make(hparams.model.name, hparams.model)
        # configure the DDPM noise scheduler
        self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
        # load the dummy DINO features (used as the image-free condition)
        self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
        # use manual optimization
        self.automatic_optimization = False
        # save the hyperparameters
        self.save_hyperparameters()
        self.custom_logger = logging.getLogger(__name__)
        self.custom_logger.setLevel(logging.INFO)
        if self.global_rank == 0:
            self.custom_logger.addHandler(logging.StreamHandler())

    def load_cage_weights(self, pretrained_ckpt=None):
        ckpt = torch.load(pretrained_ckpt)
        state_dict = ckpt["state_dict"]
        # remove the "model." prefix from the keys
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        # load the weights (the CAGE and adapter weights are separated later in configure_optimizers)
        self.model.load_state_dict(state_dict, strict=False)
        print("[INFO] loaded model weights of the pretrained CAGE.")

    def fg_loss(self, all_attn_maps, loss_masks):
        """
        Excite the attention maps within the object regions, while weakening the attention outside the object regions.
        Args:
            all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
            loss_masks: object seg mask on the image patches, shape (B, 160, 256)
        Returns:
            loss: loss on the attention maps
        """
        valid_mask = loss_masks['valid_nodes']
        fg_mask = loss_masks['fg']
        # get the number of layers and attention heads
        L = self.hparams.model.n_layers
        H = all_attn_maps.shape[1]
        # reshape all the masks to the shape of the attention maps
        valid_node = valid_mask[:, :, 0].unsqueeze(1).expand(-1, H, -1).unsqueeze(-1).expand(-1, -1, -1, 256).repeat(L, 1, 1, 1)
        obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
        # construct masks for the object and non-object regions
        fg_region = torch.logical_and(valid_node, obj_region)
        bg_region = torch.logical_and(valid_node, ~obj_region)
        # loss to excite the foreground regions and suppress the background regions
        loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
        return loss
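
    # NOTE: fg_loss is minimized by maximizing the mean attention inside the object mask
    # and minimizing it outside the mask; this term is currently disabled in compute_loss below.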

    def diffuse_process(self, inputs):
        x = inputs["x"]
        # sample Gaussian noise
        noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
        # sample a random timestep for each sample in the batch
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=self.device,
            dtype=torch.long,
        )
        # add Gaussian noise to the input
        noisy_x = self.scheduler.add_noise(x, noise, timesteps)
        # update the inputs
        inputs["noise"] = noise
        inputs["timesteps"] = timesteps
        inputs["noisy_x"] = noisy_x

    def prepare_inputs(self, batch, mode='train', n_samples=1):
        x, c, f = batch
        cat = c["cat"]  # object category
        attr_mask = c["attr_mask"]  # attention mask for local self-attention (following CAGE)
        key_pad_mask = c["key_pad_mask"]  # key padding mask for global self-attention (following CAGE)
        graph_mask = c["adj_mask"]  # attention mask for graph relation self-attention (following CAGE)
        inputs = {}
        if mode == 'train':
            # the number of sampled timesteps per iteration
            n_repeat = self.hparams.n_time_samples
            # repeat the batch for sampling multiple timesteps
            x = x.repeat(n_repeat, 1, 1)
            cat = cat.repeat(n_repeat)
            f = f.repeat(n_repeat, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
            graph_mask = graph_mask.repeat(n_repeat, 1, 1)
            attr_mask = attr_mask.repeat(n_repeat, 1, 1)
        elif mode == 'val':
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f
        elif mode == 'test':
            # repeat the batch for sampling multiple outputs
            x = x.repeat(n_samples, 1, 1)
            cat = cat.repeat(n_samples)
            f = f.repeat(n_samples, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
            graph_mask = graph_mask.repeat(n_samples, 1, 1)
            attr_mask = attr_mask.repeat(n_samples, 1, 1)
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f.repeat(1, 2, 1)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        inputs["x"] = x
        inputs["f"] = f
        inputs["cat"] = cat
        inputs["key_pad_mask"] = key_pad_mask
        inputs["graph_mask"] = graph_mask
        inputs["attr_mask"] = attr_mask
        return inputs
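
    # NOTE: in 'train' mode the batch is tiled n_time_samples times so several diffusion
    # timesteps are sampled per object in each iteration; in 'val'/'test' mode sampling
    # starts from pure Gaussian noise, and dummy DINO features are prepared for the
    # unconditional branch of classifier-free guidance (see inference()).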

    def prepare_loss_mask(self, batch):
        x, c, _ = batch
        n_repeat = self.hparams.n_time_samples  # the number of sampled timesteps per iteration
        # mask on the image patches for the foreground regions
        # mask_fg = c["img_obj_mask"]
        # if mask_fg is not None:
        #     mask_fg = mask_fg.repeat(n_repeat, 1, 1)
        # mask on the valid nodes
        index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)  # (1, N)
        valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)
        mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
        mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)
        return {"fg": None, "valid_nodes": mask_valid_nodes}

    def manage_cfg(self, inputs):
        '''
        Manage the classifier-free training for the image and graph conditions.
        The CFG for the object category is handled by the model itself
        (i.e. the CombinedTimestepLabelEmbeddings module in norm1 of each attention block).
        '''
        img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
        graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
        drop_img, drop_graph = False, False
        if img_drop_prob > 0.0:
            drop_img = torch.rand(1) < img_drop_prob
            if drop_img.item():
                dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
                inputs['f'] = dummy_batch  # use the dummy DINO features
        if graph_drop_prob > 0.0:
            if not drop_img:
                drop_graph = torch.rand(1) < graph_drop_prob
                if drop_graph.item():
                    inputs['graph_mask'] = None  # for verifying the model only; replace with the line below later and retrain the model
                    # inputs['graph_mask'] = inputs['key_pad_mask']  # use the key padding mask
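
    # NOTE: randomly replacing the image features (and graph mask) with unconditional
    # stand-ins during training is what enables classifier-free guidance at inference
    # time: the same network learns both a conditional and an unconditional prediction,
    # which are combined with the guidance weight omega in inference().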

    def compute_loss(self, batch, inputs, outputs):
        loss_dict = {}
        # loss_weight = self.hparams.get("loss_fg_weight", 1.0)
        # prepare the loss masks
        loss_masks = self.prepare_loss_mask(batch)
        # diffusion model loss: MSE on the residual noise
        loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
        # attention map loss (disabled): excite foreground / suppress background attention
        # loss_fg = loss_weight * self.fg_loss(outputs['attn_maps'], loss_masks)
        # total loss
        loss = loss_mse
        # log the losses
        loss_dict["train/loss_mse"] = loss_mse
        loss_dict["train/loss_total"] = loss
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='train')
        # manage the classifier-free training
        self.manage_cfg(inputs)
        # forward: diffusion process
        self.diffuse_process(inputs)
        # reverse: denoising process
        outputs = self.model(
            x=inputs['noisy_x'],
            cat=inputs['cat'],
            timesteps=inputs['timesteps'],
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
        )
        # compute the loss
        loss, loss_dict = self.compute_loss(batch, inputs, outputs)
        # manual backward
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()
        if batch_idx % 20 == 0 and self.global_rank == 0:
            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
            for key, value in loss_dict.items():
                loss_str += f"{key}: {value.item():.4f} | "
            self.custom_logger.info(now_str + ' | ' + loss_str)
        # logging
        # self.log_dict(loss_dict, sync_dist=True, on_step=True, on_epoch=False)

    def on_train_epoch_end(self):
        # step the lr schedulers every epoch
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()
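
    # NOTE: because self.automatic_optimization is False, Lightning does not call
    # zero_grad/backward/step automatically; training_step drives both optimizers
    # explicitly, and the two LR schedulers are likewise stepped manually here once per epoch.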

    def inference(self, inputs, is_label_free=False):
        device = inputs['x'].device
        omega = self.hparams.get("guidance_scaler", 0)
        noisy_x = inputs['noisy_x']
        # use 100 denoising steps for inference
        self.scheduler.set_timesteps(100)
        # denoising process
        for t in self.scheduler.timesteps:
            timesteps = torch.tensor([t], device=device)
            outputs_cond = self.model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=is_label_free,
            )  # take the conditional image features as input
            if omega != 0:
                outputs_free = self.model(
                    x=noisy_x,
                    cat=inputs['cat'],
                    timesteps=timesteps,
                    feat=inputs['dummy_f'],
                    key_pad_mask=inputs['key_pad_mask'],
                    graph_mask=inputs['graph_mask'],
                    attr_mask=inputs['attr_mask'],
                    label_free=is_label_free,
                )  # take the dummy DINO features for the condition-free mode
                noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
            else:
                noise_pred = outputs_cond['noise_pred']
            noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample
        return noisy_x
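
    # NOTE: the guidance combination (1 + omega) * cond - omega * uncond is equivalent to
    # cond + omega * (cond - uncond), i.e. standard classifier-free guidance on the
    # predicted noise; omega = 0 falls back to a single conditional forward pass.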

    def validation_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='val')
        # denoising process for inference
        out = self.inference(inputs)
        # compute the metrics
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         if out[b][(k + 1) * 6 - 1].mean() > 0.5:
        #             new_out[b][k * 6: (k + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # zero center
        # rescale
        # ready
        # out = new_out
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         min_aabb_diff = 1e10
        #         min_index = k
        #         aabb_center = (out[b][k * 6][:3] + out[b][k * 6][3:]) / 2
        #         for k_gt in range(32):
        #             aabb_gt_center = (batch[1][b][k_gt * 6][:3] + batch[1][b][k_gt * 6][3:]) / 2
        #             aabb_diff = torch.norm(aabb_center - aabb_gt_center)
        #             if aabb_diff < min_aabb_diff:
        #                 min_aabb_diff = aabb_diff
        #                 min_index = k_gt
        #         new_out[b][min_index * 6: (min_index + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # out = new_out
        log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
        self.log_dict(log_dict, on_step=True)
        # visualize the first 16 results
        # self.save_val_img(out[:16], inputs['x'][:16], batch[1])

    def test_step(self, batch, batch_idx):
        # exp_name = self._get_exp_name()
        # print(self.get_save_path(exp_name))
        # if batch_idx > 2:
        #     return
        # return
        is_label_free = self.hparams.get("test_label_free", False)
        exp_name = self._get_exp_name()
        model_name = batch[1]["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
        print(save_dir)
        if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
            return
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
        # denoising process for inference
        out = self.inference(inputs, is_label_free)
        # save the results
        self.save_test_step(out, inputs['x'], batch[1], batch_idx)

    def on_test_end(self):
        # only run on a single GPU
        # if self.global_rank == 0:
        #     exp_name = self._get_exp_name()
        #     # retrieve parts
        #     subprocess.run(['python', 'scripts/mesh_retrieval/run_retrieve.py', '--src', self.get_save_path(exp_name), '--json_name', 'object.json', '--gt_data_root', '../singapo'])
        #     # save metrics
        #     if not self.hparams.get("test_no_GT", False):
        #         subprocess.run(['python', 'scripts/eval_metrics.py', '--exp_dir', self.get_save_path(exp_name), '--gt_root', '../acd_data/'])
        #     # save html
        #     self._save_html_end()
        pass

    def configure_optimizers(self):
        # separate the pretrained CAGE weights from the newly added adapter modules
        self.cage_params, self.adapter_params = [], []
        for name, param in self.model.named_parameters():
            if "img" in name or "norm5" in name or "norm6" in name:
                self.adapter_params.append(param)
            else:
                self.cage_params.append(param)
        optimizer_adapter = torch.optim.AdamW(
            self.adapter_params, **self.hparams.optimizer_adapter.args
        )
        lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
            optimizer_adapter,
            warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_adapter.eta_min,
        )
        optimizer_cage = torch.optim.AdamW(
            self.cage_params, **self.hparams.optimizer_cage.args
        )
        lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
            optimizer_cage,
            warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_cage.eta_min,
        )
        return (
            {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
            {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
        )
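
    # NOTE: parameters whose names contain "img", "norm5", or "norm6" are treated as the
    # newly added (adapter) modules and get their own AdamW + warmup-cosine schedule,
    # while the remaining pretrained CAGE weights are fine-tuned with a separate
    # optimizer; training_step and on_train_epoch_end step both pairs manually.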