import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
import subprocess
import numpy as np
import models
import systems
import torch.nn.functional as F
from diffusers import DDPMScheduler
from systems.base import BaseSystem
from my_utils.lr_schedulers import LinearWarmupCosineAnnealingLR
from datetime import datetime
import logging

@systems.register("sys_origin")
class SingapoSystem(BaseSystem):
    """Trainer for the B9 model, incorporating the classifier-free for image condition."""

    def __init__(self, hparams):
        super().__init__(hparams)
        self.model = models.make(hparams.model.name, hparams.model)
        # configure the scheduler of DDPM
        self.scheduler = DDPMScheduler(**self.hparams.scheduler.config)
        # load the dummy DINO features
        self.dummy_dino = np.load('systems/dino_dummy.npy').astype(np.float32)
        # use the manual optimization
        self.automatic_optimization = False
        # save the hyperparameters
        self.save_hyperparameters()

        self.custom_logger = logging.getLogger(__name__)
        self.custom_logger.setLevel(logging.INFO)
        if self.global_rank == 0:
            self.custom_logger.addHandler(logging.StreamHandler())

    def load_cage_weights(self, pretrained_ckpt=None):
        ckpt = torch.load(pretrained_ckpt, map_location="cpu")  # load on CPU; weights follow the module's device
        state_dict = ckpt["state_dict"]
        # remove the "model." prefix from the keys
        state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        # load the weights
        # strict=False: the new adapter modules are absent from the CAGE checkpoint
        # and keep their initialization
        self.model.load_state_dict(state_dict, strict=False)
        print("[INFO] loaded model weights of the pretrained CAGE.")
        

    def fg_loss(self, all_attn_maps, loss_masks):
        """
        Excite the cross-attention within the object regions while suppressing it outside them.
        
        Args:
            all_attn_maps: cross-attention maps from all layers, shape (B*L, H, 160, 256)
            loss_masks: dict with 'valid_nodes' (mask over the valid node tokens) and
                'fg' (object seg mask on the image patches, shape (B, 160, 256))

        Returns:
            loss: scalar loss on the attention maps
        """
        valid_mask = loss_masks['valid_nodes']
        fg_mask = loss_masks['fg']
        # get the number of layers and attention heads
        L = self.hparams.model.n_layers
        H = all_attn_maps.shape[1]
        # Reshape all the masks to the shape of the attention maps
        valid_node = (
            valid_mask[:, :, 0]
            .unsqueeze(1).expand(-1, H, -1)            # (B, H, 160)
            .unsqueeze(-1).expand(-1, -1, -1, 256)     # (B, H, 160, 256)
            .repeat(L, 1, 1, 1)                        # (B*L, H, 160, 256)
        )
        obj_region = fg_mask.unsqueeze(1).expand(-1, H, -1, -1).repeat(L, 1, 1, 1)
        # construct masks for the object and non-object regions
        fg_region = torch.logical_and(valid_node, obj_region)
        bg_region = torch.logical_and(valid_node, ~obj_region)
        # loss to excite the foreground regions
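        # minimized by pushing attention mass toward the foreground (fg mean -> 1)
        # and away from the background (bg mean -> 0); note that compute_loss
        # currently keeps the corresponding loss_fg term commented out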
        loss = 1. - all_attn_maps[fg_region].mean() + all_attn_maps[bg_region].mean()
        return loss
    
    def diffuse_process(self, inputs):
        x = inputs["x"]
        # Sample Gaussian noise
        noise = torch.randn(x.shape, device=self.device, dtype=x.dtype)
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            self.scheduler.config.num_train_timesteps,
            (x.shape[0],),
            device=self.device,
            dtype=torch.long,
        )
        # Add Gaussian noise to the input
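        # (DDPM closed form: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise)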
        noisy_x = self.scheduler.add_noise(x, noise, timesteps)
        # update the inputs
        inputs["noise"] = noise
        inputs["timesteps"] = timesteps
        inputs["noisy_x"] = noisy_x

    def prepare_inputs(self, batch, mode='train', n_samples=1):
        x, c, f = batch

        cat = c["cat"]                   # object category
        attr_mask = c["attr_mask"]       # attention mask for local self-attention (follow the CAGE)
        key_pad_mask = c["key_pad_mask"] # key padding mask for global self-attention (follow the CAGE)
        graph_mask = c["adj_mask"]       # attention mask for graph relation self-attention (follow the CAGE)

        inputs = {}
        if mode == 'train':
            # the number of sampled timesteps per iteration
            n_repeat = self.hparams.n_time_samples
            # for sampling multiple timesteps
            x = x.repeat(n_repeat, 1, 1)
            cat = cat.repeat(n_repeat)
            f = f.repeat(n_repeat, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_repeat, 1, 1)
            graph_mask = graph_mask.repeat(n_repeat, 1, 1)
            attr_mask = attr_mask.repeat(n_repeat, 1, 1)
        elif mode == 'val':
            noisy_x = torch.randn(x.shape, device=x.device)
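            # the dummy DINO features act as the null image condition, i.e. the
            # unconditional branch used for classifier-free guidance in inference()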
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f
        elif mode == 'test':
            # for sampling multiple outputs
            x = x.repeat(n_samples, 1, 1)
            cat = cat.repeat(n_samples)
            f = f.repeat(n_samples, 1, 1)
            key_pad_mask = key_pad_mask.repeat(n_samples, 1, 1)
            graph_mask = graph_mask.repeat(n_samples, 1, 1)
            attr_mask = attr_mask.repeat(n_samples, 1, 1)
            noisy_x = torch.randn(x.shape, device=x.device)
            dummy_f = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(f)
            inputs["noisy_x"] = noisy_x
            inputs["dummy_f"] = dummy_f  # already expanded to f's shape; a second repeat would break it
        else:
            raise ValueError(f"Invalid mode: {mode}")
        
        inputs["x"] = x
        inputs["f"] = f
        inputs["cat"] = cat
        inputs["key_pad_mask"] = key_pad_mask
        inputs["graph_mask"] = graph_mask
        inputs["attr_mask"] = attr_mask
        
        return inputs
    
    def prepare_loss_mask(self, batch):
        x, c, _ = batch
        n_repeat = self.hparams.n_time_samples # the number of sampled timesteps per iteration

        # mask on the image patches for the foreground regions
        # mask_fg = c["img_obj_mask"] 
        # if mask_fg is not None:
        #     mask_fg = mask_fg.repeat(n_repeat, 1, 1)
        
        # mask on the valid nodes
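        # each node is assumed to span 5 attribute tokens in the sequence
        # (inferred from the n_nodes * 5 indexing below); tokens past that are padding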
        index_tensor = torch.arange(x.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0)  # (1, N)
        valid_nodes = index_tensor < (c['n_nodes'] * 5).unsqueeze(-1)  
        mask_valid_nodes = valid_nodes.unsqueeze(-1).expand_as(x)
        mask_valid_nodes = mask_valid_nodes.repeat(n_repeat, 1, 1)

        return {"fg": None, "valid_nodes": mask_valid_nodes}
    
    def manage_cfg(self, inputs):
        '''
        Manage classifier-free training for the image and graph conditions.
        The CFG for the object category is handled inside the model (i.e. the
        CombinedTimestepLabelEmbeddings module in norm1 of each attention block).
        '''
        img_drop_prob = self.hparams.get("img_drop_prob", 0.0)
        graph_drop_prob = self.hparams.get("graph_drop_prob", 0.0)
        drop_img, drop_graph = False, False
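        # the two drops are mutually exclusive: the graph condition is only
        # considered for dropping when the image condition is kept, so at most
        # one condition is dropped per training step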

        if img_drop_prob > 0.0:
            drop_img = torch.rand(1) < img_drop_prob
            if drop_img.item():
                dummy_batch = torch.tensor(self.dummy_dino, device=self.device).unsqueeze(0).repeat(1, 2, 1).expand_as(inputs['f'])
                inputs['f'] = dummy_batch  # use the dummy DINO features

        if graph_drop_prob > 0.0:
            if not drop_img:
                drop_graph = torch.rand(1) < graph_drop_prob
                if drop_graph.item():
                    inputs['graph_mask'] = None # to verify the model only; replace with the line below and retrain later
                    # inputs['graph_mask'] = inputs['key_pad_mask'] # use the key padding mask

    def compute_loss(self, batch, inputs, outputs):
        loss_dict = {}
        # loss_weight = self.hparams.get("loss_fg_weight", 1.0)

        # prepare the loss masks
        loss_masks = self.prepare_loss_mask(batch)

        # diffusion model loss: MSE on the residual noise
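        # (standard epsilon-prediction objective; masking both terms zeroes the
        #  padded entries, though the mean still runs over them, so the loss is
        #  scaled by the fraction of valid tokens)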
        loss_mse = F.mse_loss(outputs['noise_pred'] * loss_masks['valid_nodes'], inputs['noise'] * loss_masks['valid_nodes'])
        # attention mask loss: BCE loss on the attention maps
        # loss_fg = loss_weight * self.fg_loss(outputs['attn_maps'], loss_masks)
        
        # total loss
        loss = loss_mse
        
        # log the losses
        loss_dict["train/loss_mse"] = loss_mse
        loss_dict["train/loss_total"] = loss

        return loss, loss_dict
 
    def training_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='train')
        
        # manage the classifier-free training
        self.manage_cfg(inputs)

        # forward: diffusion process
        self.diffuse_process(inputs)

        # reverse: denoising process
        outputs = self.model(
            x=inputs['noisy_x'],
            cat=inputs['cat'],
            timesteps=inputs['timesteps'],
            feat=inputs['f'],
            key_pad_mask=inputs['key_pad_mask'],
            graph_mask=inputs['graph_mask'],
            attr_mask=inputs['attr_mask'],
        )

        # compute the loss
        loss, loss_dict = self.compute_loss(batch, inputs, outputs)

        # manual backward
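        # (automatic_optimization is False, so zero_grad / backward / step are
        #  invoked explicitly; both parameter groups update every batch)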
        opt1, opt2 = self.optimizers()
        opt1.zero_grad()
        opt2.zero_grad()
        self.manual_backward(loss)
        opt1.step()
        opt2.step()

        if batch_idx % 20 == 0 and self.global_rank == 0:
            now = datetime.now()
            now_str = now.strftime("%Y-%m-%d %H:%M:%S")
            loss_str = f'Epoch:{self.current_epoch} | Step:{batch_idx:03d} | '
            for key, value in loss_dict.items():
                loss_str += f"{key}: {value.item():.4f} | "
            self.custom_logger.info(now_str + ' | ' + loss_str)
        # logging
        # self.log_dict(loss_dict, sync_dist=True, on_step=True, on_epoch=False)

    def on_train_epoch_end(self):
        # step the lr scheduler every epoch
        sch1, sch2 = self.lr_schedulers()
        sch1.step()
        sch2.step()

    def inference(self, inputs, is_label_free=False):
        device = inputs['x'].device
        omega = self.hparams.get("guidance_scaler", 0)
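        # omega = 0 disables guidance and skips the unconditional forward pass below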
        noisy_x = inputs['noisy_x']

        # run the reverse process with 100 denoising steps
        self.scheduler.set_timesteps(100)
        # denoising process
        for t in self.scheduler.timesteps:
            timesteps = torch.tensor([t], device=device)
            outputs_cond = self.model(
                x=noisy_x,
                cat=inputs['cat'],
                timesteps=timesteps,
                feat=inputs['f'], 
                key_pad_mask=inputs['key_pad_mask'],
                graph_mask=inputs['graph_mask'],
                attr_mask=inputs['attr_mask'],
                label_free=is_label_free,
            ) # take the conditional image features as input
            if omega != 0:
                outputs_free = self.model(
                    x=noisy_x,
                    cat=inputs['cat'],
                    timesteps=timesteps,
                    feat=inputs['dummy_f'], 
                    key_pad_mask=inputs['key_pad_mask'],
                    graph_mask=inputs['graph_mask'],
                    attr_mask=inputs['attr_mask'],
                    label_free=is_label_free,
                ) # take the dummy DINO features for the condition-free mode
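                # classifier-free guidance: eps = eps_free + (1 + omega) * (eps_cond - eps_free),
                # which is algebraically identical to the line below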
                noise_pred = (1 + omega) * outputs_cond['noise_pred'] - omega * outputs_free['noise_pred']
            else:
                noise_pred = outputs_cond['noise_pred']
            noisy_x = self.scheduler.step(noise_pred, t, noisy_x).prev_sample

        return noisy_x

    def validation_step(self, batch, batch_idx):
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='val')
        # denoising process for inference
        out = self.inference(inputs)
        # compute the metrics
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         if out[b][(k + 1) * 6 - 1].mean() > 0.5:
        #             new_out[b][k * 6: (k + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # zero center
        
        # rescale

        # ready
        # out = new_out
        # new_out = torch.zeros_like(out).type_as(out).to(out.device)
        # for b in range(out.shape[0]):
        #     for k in range(32):
        #         min_aabb_diff = 1e10
        #         min_index = k
        #         aabb_center = (out[b][k * 6][:3] + out[b][k * 6 ][3:]) / 2
        #         for k_gt in range(32):
        #             aabb_gt_center = (batch[1][b][k_gt * 6][:3] + batch[1][b][k_gt * 6][3:]) / 2
        #             aabb_diff = torch.norm(aabb_center - aabb_gt_center)
        #             if aabb_diff < min_aabb_diff:
        #                 min_aabb_diff = aabb_diff
        #                 min_index = k_gt
        #         new_out[b][min_index * 6: (min_index + 1) * 6] = out[b][k * 6: (k + 1) * 6]
        # out = new_out

        log_dict = self.val_compute_metrics(out, inputs['x'], batch[1])
        self.log_dict(log_dict, on_step=True)

        # visualize the first 10 results
        # self.save_val_img(out[:16], inputs['x'][:16], batch[1])

    def test_step(self, batch, batch_idx):
        # exp_name = self._get_exp_name()
        # print(self.get_save_path(exp_name))
        # if batch_idx > 2:
        #     return
        # return
        is_label_free = self.hparams.get("test_label_free", False)
        exp_name = self._get_exp_name()
        model_name = batch[1]["name"][0].replace("/", '@')
        save_dir = f"{exp_name}/{str(batch_idx)}@{model_name}"
        print(save_dir)
        if os.path.exists(self.get_save_path(f"{save_dir}/output.png")):
            return
        # prepare the inputs and GT
        inputs = self.prepare_inputs(batch, mode='test', n_samples=5)
        # denoising process for inference
        out = self.inference(inputs, is_label_free)
        # save the results
        self.save_test_step(out, inputs['x'], batch[1], batch_idx)

    def on_test_end(self):
        # only run the single GPU
        # if self.global_rank == 0:
        #     exp_name = self._get_exp_name()
        #     # retrieve parts
        #     subprocess.run(['python', 'scripts/mesh_retrieval/run_retrieve.py', '--src', self.get_save_path(exp_name), '--json_name', 'object.json', '--gt_data_root', '../singapo'])
        #     # save metrics
        #     if not self.hparams.get("test_no_GT", False):
        #         subprocess.run(['python', 'scripts/eval_metrics.py', '--exp_dir', self.get_save_path(exp_name), '--gt_root', '../acd_data/'])
        #     # save html
        #     self._save_html_end()
        pass

    def configure_optimizers(self):
        # two distinct lists: a chained assignment (a = b = []) would alias them
        self.cage_params, self.adapter_params = [], []
        for name, param in self.model.named_parameters():
            if "img" in name or "norm5" in name or "norm6" in name:
                self.adapter_params.append(param)
            else:
                self.cage_params.append(param)
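        # parameters named with "img", "norm5", or "norm6" form the image-adapter
        # group; everything else belongs to the pretrained CAGE backbone, and each
        # group gets its own optimizer and LR schedule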
        optimizer_adapter = torch.optim.AdamW(
            self.adapter_params, **self.hparams.optimizer_adapter.args
        )
        lr_scheduler_adapter = LinearWarmupCosineAnnealingLR(
            optimizer_adapter,
            warmup_epochs=self.hparams.lr_scheduler_adapter.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_adapter.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_adapter.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_adapter.eta_min,
        )

        optimizer_cage = torch.optim.AdamW(
            self.cage_params, **self.hparams.optimizer_cage.args
        )
        lr_scheduler_cage = LinearWarmupCosineAnnealingLR(
            optimizer_cage,
            warmup_epochs=self.hparams.lr_scheduler_cage.warmup_epochs,
            max_epochs=self.hparams.lr_scheduler_cage.max_epochs,
            warmup_start_lr=self.hparams.lr_scheduler_cage.warmup_start_lr,
            eta_min=self.hparams.lr_scheduler_cage.eta_min,
        )
        return (
            {"optimizer": optimizer_adapter, "lr_scheduler": lr_scheduler_adapter},
            {"optimizer": optimizer_cage, "lr_scheduler": lr_scheduler_cage},
        )