File size: 2,681 Bytes
a856109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment


class uncertainty_light_pos_loss(nn.Module):
    """Set-based position/confidence loss with learned uncertainty weighting.

    Inputs are (B, N, 4) tensors whose last dimension is (x, y, r, p): a 2-D
    position, a radius, and a confidence/presence score.  Predictions are
    matched to ground truth per sample via the Hungarian algorithm on the
    pairwise L2 distance over (x, y, r).  Ground-truth entries with p == 0
    are treated as padding and excluded from the position term (they still
    participate in the confidence BCE term).

    The two terms are balanced by learnable parameters in the spirit of
    Kendall-style homoscedastic multi-task weighting, here using
    0.5 / v**2 task weights with a log(1 + v**2) regularizer.
    """

    def __init__(self):
        super(uncertainty_light_pos_loss, self).__init__()
        # nn.Parameter already implies requires_grad=True.
        self.log_var_xyr = nn.Parameter(torch.tensor(1.0))
        self.log_var_p = nn.Parameter(torch.tensor(1.0))

    def forward(self, logits, targets):
        """Return (position_loss, confidence_loss), each a scalar tensor.

        logits:  (B, N, 4) predictions; logits[..., 3] must lie in [0, 1]
                 because it is fed to F.binary_cross_entropy.
        targets: (B, N, 4) ground truth; targets[..., 3] > 0 marks a real
                 (non-padding) object.
        """
        B, N, P = logits.shape  # (B, 4, 4)

        position_loss = 0
        confidence_loss = 0

        w_xyr = 0.5 / (self.log_var_xyr**2)  # uncertainty weight for position loss
        w_p = 0.5 / (self.log_var_p**2)  # uncertainty weight for confidence loss
        # Per-channel weights for (x, y, r): the radius error counts double.
        weights = torch.tensor(
            [1.0, 1.0, 2.0], device=logits.device, dtype=logits.dtype
        )

        for b in range(B):
            pred_xyr = logits[b, :, :3]  # (N, 3)
            pred_p = logits[b, :, 3]  # (N,)

            gt_xyr = targets[b, :, :3]  # (N, 3)
            gt_p = targets[b, :, 3]  # (N,)

            cost_matrix = torch.cdist(gt_xyr, pred_xyr, p=2)  # (N, N)

            # Hungarian matching is non-differentiable; detach() is required
            # because .numpy() rejects tensors that require grad.  The old
            # no_grad + .cpu() combination crashed on CPU inputs: .cpu() is a
            # no-op (returns self) when the tensor already lives on the CPU,
            # so the grad-requiring tensor reached .numpy() unchanged.
            row_ind, col_ind = linear_sum_assignment(
                cost_matrix.detach().cpu().numpy()
            )

            matched_pred_xyr = pred_xyr[col_ind]
            matched_gt_xyr = gt_xyr[row_ind]
            matched_pred_p = pred_p[col_ind]
            matched_gt_p = gt_p[row_ind]

            # Only real (non-padding) ground-truth objects contribute to the
            # position term; clamp avoids 0/0 when a sample has none.
            valid_mask = matched_gt_p > 0
            valid_cnt = valid_mask.sum().clamp(min=1)

            xyr_loss = (
                F.smooth_l1_loss(
                    matched_pred_xyr[valid_mask],
                    matched_gt_xyr[valid_mask],
                    reduction="none",
                )
                * weights
            ).sum()

            p_loss = F.binary_cross_entropy(
                matched_pred_p, matched_gt_p, reduction="mean"
            )

            position_loss += xyr_loss / valid_cnt
            confidence_loss += p_loss

        # Scaled task loss plus a regularizer that keeps the learned balance
        # parameters from growing without bound (log(1 + v**2) >= 0).
        position_loss = w_xyr * (position_loss / B) + torch.log(1 + self.log_var_xyr**2)
        confidence_loss = w_p * (confidence_loss / B) + torch.log(1 + self.log_var_p**2)

        return position_loss, confidence_loss


class unet_3maps_loss(nn.Module):
    """Combined probability + radius loss for a 3-map UNet head.

    The total loss is BCE(pred_prob, prob_gt) plus 10x a smooth-L1 term on
    the radius map, where the radius term is evaluated only at positive
    locations (prob_gt > 0.5).  Returns (total, L_prob, L_rad).
    """

    def __init__(self):
        super(unet_3maps_loss, self).__init__()

    def forward(self, pred_prob, pred_rad, prob_gt, rad_gt):
        # Probability term over every element.
        L_prob = nn.BCELoss()(pred_prob, prob_gt)

        # Radius term only where ground truth marks a positive location;
        # with no positives, produce a zero that stays in the autograd graph.
        positives = prob_gt > 0.5
        if positives.any():
            L_rad = nn.functional.smooth_l1_loss(
                pred_rad[positives], rad_gt[positives]
            )
        else:
            L_rad = pred_rad.sum() * 0

        return L_prob + 10.0 * L_rad, L_prob, L_rad