Spaces:
Running
on
Zero
Running
on
Zero
File size: 16,509 Bytes
ce34030 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 |
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import subprocess
import numpy as np
import models
import systems
import torch.nn.functional as F
from diffusers import DDPMScheduler
from systems.base import BaseSystem
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
from datetime import datetime
import logging
@systems.register("sys_origin")
class SingapoSystem(BaseSystem):
    """Trainer for the B9 model, incorporating classifier-free guidance for the image condition.

    Uses Lightning manual optimization with two optimizer/scheduler pairs:
    one for the newly added adapter modules (image-conditioning layers) and
    one for the pretrained CAGE backbone.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.model = models.make(hparams.model.name, hparams.model)
        # configure the scheduler of DDPM
        self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
        # load the dummy DINO features; used as the "unconditional" image input for CFG
        self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
        # use manual optimization: the two optimizers are stepped by hand in training_step
        self.automatic_optimization = False
        # save the hyperparameters
        self.save_hyperparameters()
        self.custom_logger = logging.getLogger(__name__)
        self.custom_logger.setLevel(logging.INFO)
        # only rank 0 gets a stream handler, to avoid duplicated console logs under DDP
        if self.global_rank == 0:
            self.custom_logger.addHandler(logging.StreamHandler())

    def load_cage_weights(self, pretrained_ckpt=None):
        """Load pretrained CAGE weights into the model.

        Args:
            pretrained_ckpt: path to a Lightning checkpoint containing "state_dict".

        Loading is non-strict, so the adapter modules added on top of CAGE
        keep their fresh initialization.
        """
        ckpt = torch.load(pretrained_ckpt)
        state_dict = ckpt["state_dict"]
        # remove the "model." prefix from the keys (checkpoint was saved from a LightningModule)
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        # strict=False: checkpoint does not contain the new adapter weights
        self.model.load_state_dict(state_dict, strict=False)
        print("[INFO] loaded model weights of the pretrained CAGE.")

    def fg_loss(self, all_attn_maps, loss_masks):
        """
        Excite the attention maps within the object regions, while weakening the attention outside them.

        Args:
            all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
            loss_masks: dict with 'valid_nodes' (valid-node mask) and 'fg'
                (object seg mask on the image patches, shape (B, 160, 256))
        Returns:
            loss: scalar loss on the attention maps (low when attention is
            concentrated on foreground patches of valid nodes)
        """
        valid_mask = loss_masks['valid_nodes']
        fg_mask = loss_masks['fg']
        # number of attention layers and heads
        L = self.hparams.model.n_layers
        H = all_attn_maps.shape[1]
        # broadcast the masks to the shape of the stacked attention maps (B*L, H, 160, 256)
        valid_node = valid_mask[:, :, 0].unsqueeze(1).expand(-1, H, -1).unsqueeze(-1).expand(-1, -1, -1, 256).repeat(L, 1, 1, 1)
        obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
        # masks for the object (foreground) and non-object (background) regions
        fg_region = torch.logical_and(valid_node, obj_region)
        bg_region = torch.logical_and(valid_node, ~obj_region)
        # reward attention mass on foreground, penalize it on background
        loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
        return loss

    def diffuse_process(self, inputs):
        """Run the forward diffusion process; mutates `inputs` in place.

        Adds "noise", "timesteps", and "noisy_x" entries derived from inputs["x"].
        """
        x = inputs["x"]
        # sample Gaussian noise
        noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
        # sample a random timestep for each item in the (repeated) batch
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=self.device,
            dtype=torch.long,
        )
        # add Gaussian noise to the input according to the DDPM schedule
        noisy_x = self.scheduler.add_noise(x, noise, timesteps)
        # update the inputs dict in place
        inputs["noise"] = noise
        inputs["timesteps"] = timesteps
        inputs["noisy_x"] = noisy_x

    def prepare_inputs(self, batch, mode='train', n_samples=1):
        """Unpack a batch into a flat inputs dict for the model.

        Args:
            batch: (x, c, f) tuple — data, condition dict, DINO image features.
            mode: 'train' repeats the batch `hparams.n_time_samples` times to
                sample multiple timesteps per iteration; 'val' adds starting
                noise and dummy features; 'test' repeats `n_samples` times to
                draw multiple outputs per input.
            n_samples: number of samples per input (test mode only).
        Returns:
            dict of model inputs.
        Raises:
            ValueError: on an unrecognized mode.
        """
        x, c, f = batch
        cat = c["cat"]  # object category
        attr_mask = c["attr_mask"]  # attention mask for local self-attention (follows CAGE)
        key_pad_mask = c["key_pad_mask"]  # key padding mask for global self-attention (follows CAGE)
        graph_mask = c["adj_mask"]  # attention mask for graph relation self-attention (follows CAGE)
        inputs = {}
        if mode == 'train':
            # the number of sampled timesteps per iteration
            n_repeat = self.hparams.n_time_samples
            # repeat the batch so each copy gets its own random timestep
            x = x.repeat(n_repeat, 1, 1)
            cat = cat.repeat(n_repeat)
            f = f.repeat(n_repeat, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
            graph_mask = graph_mask.repeat(n_repeat, 1, 1)
            attr_mask = attr_mask.repeat(n_repeat, 1, 1)
        elif mode == 'val':
            # start denoising from pure Gaussian noise
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f
        elif mode == 'test':
            # repeat the batch to sample multiple outputs per input
            x = x.repeat(n_samples, 1, 1)
            cat = cat.repeat(n_samples)
            f = f.repeat(n_samples, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
            graph_mask = graph_mask.repeat(n_samples, 1, 1)
            attr_mask = attr_mask.repeat(n_samples, 1, 1)
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            # NOTE(review): the extra .repeat(1, 2, 1) makes dummy_f's sequence
            # length twice that of 'f', unlike val mode above — confirm this is
            # intentional before relying on CFG (omega != 0) at test time.
            inputs["dummy_f"] = dummy_f.repeat(1, 2, 1)
        else:
            raise ValueError(f"Invalid mode: {mode}")
        inputs["x"] = x
        inputs["f"] = f
        inputs["cat"] = cat
        inputs["key_pad_mask"] = key_pad_mask
        inputs["graph_mask"] = graph_mask
        inputs["attr_mask"] = attr_mask
        return inputs

    def prepare_loss_mask(self, batch):
        """Build the masks used by the losses.

        Returns:
            dict with 'fg' (currently disabled, always None) and 'valid_nodes'
            (bool mask over x marking entries that belong to real graph nodes,
            repeated to match the timestep-repeated training batch).
        """
        x, c, _ = batch
        n_repeat = self.hparams.n_time_samples  # the number of sampled timesteps per iteration
        # the foreground image-patch mask (c["img_obj_mask"]) is currently unused
        # mask on the valid nodes: each node occupies 5 rows of x, so rows with
        # index < n_nodes * 5 are valid
        index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)  # (1, N)
        valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)
        mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
        mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)
        return {"fg": None, "valid_nodes": mask_valid_nodes}

    def manage_cfg(self, inputs):
        '''
        Manage the classifier-free training for the image and graph condition; mutates `inputs` in place.
        The CFG for object category is managed by the model itself
        (i.e. the CombinedTimestepLabelEmbeddings module in norm1 of each attention block).
        '''
        img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
        graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
        drop_img, drop_graph = False, False
        if img_drop_prob > 0.0:
            drop_img = torch.rand(1) < img_drop_prob
            if drop_img.item():
                # replace real image features with the dummy DINO features
                dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
                inputs['f'] = dummy_batch
        if graph_drop_prob > 0.0:
            # never drop both conditions at once
            if not drop_img:
                drop_graph = torch.rand(1) < graph_drop_prob
                if drop_graph.item():
                    inputs['graph_mask'] = None  # for varify the model only, replace with the below line later and retrain the model
                    # inputs['graph_mask'] = inputs['key_pad_mask'] # use the key padding mask

    def compute_loss(self, batch, inputs, outputs):
        """Compute the training loss (masked MSE on the predicted noise).

        Returns:
            (total loss, dict of loss components for logging)
        """
        loss_dict = {}
        # prepare the loss masks
        loss_masks = self.prepare_loss_mask(batch)
        # diffusion model loss: MSE on the residual noise, restricted to valid nodes
        loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
        # the attention-map foreground loss (fg_loss) is currently disabled
        loss = loss_mse
        # log the losses
        loss_dict["train/loss_mse"] = loss_mse
        loss_dict["train/loss_total"] = loss
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        """One manual-optimization training step: diffuse, denoise, backprop, step both optimizers."""
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='train')
        # manage the classifier-free training
        self.manage_cfg(inputs)
        # forward: diffusion process
        self.diffuse_process(inputs)
        # reverse: denoising process
        outputs = self.model(
            x=inputs['noisy_x'],
            cat=inputs['cat'],
            timesteps=inputs['timesteps'],
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
        )
        # compute the loss
        loss, loss_dict = self.compute_loss(batch, inputs, outputs)
        # manual backward and step for both optimizers (adapter + CAGE backbone)
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()
        # lightweight console logging every 20 steps on rank 0
        # (Lightning's self.log_dict is intentionally disabled here)
        if batch_idx % 20 == 0 and self.global_rank == 0:
            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
            for key, value in loss_dict.items():
                loss_str += f"{key}: {value.item():.4f} | "
            self.custom_logger.info(now_str + ' | ' + loss_str)

    def on_train_epoch_end(self):
        # step both lr schedulers once per epoch (manual optimization)
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def inference(self, inputs, is_label_free=False):
        """Run the reverse (denoising) process with optional classifier-free guidance.

        Args:
            inputs: dict from prepare_inputs (val/test mode).
            is_label_free: drop the category label condition in the model.
        Returns:
            the fully denoised sample tensor.
        """
        device = inputs['x'].device
        omega = self.hparams.get("guidance_scaler", 0)
        noisy_x = inputs['noisy_x']
        # use 100 inference timesteps for the scheduler
        self.scheduler.set_timesteps(100)
        # iterative denoising
        for t in self.scheduler.timesteps:
            timesteps = torch.tensor([t], device=device)
            outputs_cond = self.model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['f'],
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=is_label_free,
            )  # take the conditional image as input
            if omega != 0:
                outputs_free = self.model(
                    x=noisy_x,
                    cat=inputs['cat'],
                    timesteps=timesteps,
                    feat=inputs['dummy_f'],
                    key_pad_mask=inputs['key_pad_mask'],
                    graph_mask=inputs['graph_mask'],
                    attr_mask=inputs['attr_mask'],
                    label_free=is_label_free,
                )  # take the dummy DINO features for the condition-free mode
                # classifier-free guidance: extrapolate away from the unconditional prediction
                noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
            else:
                noise_pred = outputs_cond['noise_pred']
            noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample
        return noisy_x

    def validation_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='val')
        # denoising process for inference
        out = self.inference(inputs)
        # compute and log the validation metrics against the ground truth
        log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
        self.log_dict(log_dict, on_step=True)

    def test_step(self, batch, batch_idx):
        """Sample multiple outputs per test input and save the results to disk."""
        is_label_free = self.hparams.get("test_label_free", False)
        exp_name = self._get_exp_name()
        model_name = batch[1]["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
        print(save_dir)
        # skip inputs that were already processed in a previous run
        if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
            return
        # prepare the inputs and GT (5 samples per input)
        inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
        # denoising process for inference
        out = self.inference(inputs, is_label_free)
        # save the results
        self.save_test_step(out, inputs['x'], batch[1], batch_idx)

    def on_test_end(self):
        # post-processing (mesh retrieval, metric evaluation, HTML report via
        # external scripts) is currently disabled; intentionally a no-op
        pass

    def configure_optimizers(self):
        """Create separate AdamW optimizers and warmup-cosine schedulers for the adapter and CAGE params.

        Returns them in the order (adapter, cage), which training_step and
        on_train_epoch_end rely on when unpacking.
        """
        # BUGFIX: the original `self.cage_params = self.adapter_params = []`
        # aliased both names to the SAME list, so every parameter ended up in
        # both optimizers. Use two distinct lists so the split actually happens.
        self.cage_params = []
        self.adapter_params = []
        for name, param in self.model.named_parameters():
            # adapter modules: image-conditioning layers plus the extra norm5/norm6 layers
            if "img" in name or "norm5" in name or "norm6" in name:
                self.adapter_params.append(param)
            else:
                self.cage_params.append(param)
        optimizer_adapter = torch.optim.AdamW(
            self.adapter_params, **self.hparams.optimizer_adapter.args
        )
        lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
            optimizer_adapter,
            warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_adapter.eta_min,
        )
        optimizer_cage = torch.optim.AdamW(
            self.cage_params, **self.hparams.optimizer_cage.args
        )
        lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
            optimizer_cage,
            warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_cage.eta_min,
        )
        return (
            {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
            {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
        )
|