import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment


class uncertainty_light_pos_loss(nn.Module):
    def __init__(self):
        super(uncertainty_light_pos_loss, self).__init__()
        # Learnable per-task uncertainty terms used to weight the two loss components
        self.log_var_xyr = nn.Parameter(torch.tensor(1.0))
        self.log_var_p = nn.Parameter(torch.tensor(1.0))
    def forward(self, logits, targets):
        B, N, P = logits.shape  # (B, 4, 4): N candidates of (x, y, r, p) per sample
        position_loss = 0.0
        confidence_loss = 0.0
        w_xyr = 0.5 / (self.log_var_xyr ** 2)  # uncertainty weight for position loss
        w_p = 0.5 / (self.log_var_p ** 2)  # uncertainty weight for confidence loss
        weights = torch.tensor([1.0, 1.0, 2.0], device=logits.device)  # weights for x, y, r

        for b in range(B):
            pred_xyr = logits[b, :, :3]  # (N, 3)
            pred_p = logits[b, :, 3]  # (N,)
            gt_xyr = targets[b, :, :3]  # (N, 3)
            gt_p = targets[b, :, 3]  # (N,)

            # Hungarian matching between ground-truth and predicted positions
            cost_matrix = torch.cdist(gt_xyr, pred_xyr, p=2)  # (N, N)
            with torch.no_grad():
                row_ind, col_ind = linear_sum_assignment(
                    cost_matrix.detach().cpu().numpy()
                )

            matched_pred_xyr = pred_xyr[col_ind]
            matched_gt_xyr = gt_xyr[row_ind]
            matched_pred_p = pred_p[col_ind]
            matched_gt_p = gt_p[row_ind]

            # Regress position only for matched targets that actually contain an object
            valid_mask = matched_gt_p > 0
            valid_cnt = valid_mask.sum().clamp(min=1)
            xyr_loss = (
                F.smooth_l1_loss(
                    matched_pred_xyr[valid_mask],
                    matched_gt_xyr[valid_mask],
                    reduction="none",
                )
                * weights
            ).sum()
            p_loss = F.binary_cross_entropy(
                matched_pred_p, matched_gt_p, reduction="mean"
            )

            position_loss += xyr_loss / valid_cnt
            confidence_loss += p_loss

        # Uncertainty-weighted task losses with a log(1 + sigma^2) regularizer
        position_loss = w_xyr * (position_loss / B) + torch.log(1 + self.log_var_xyr ** 2)
        confidence_loss = w_p * (confidence_loss / B) + torch.log(1 + self.log_var_p ** 2)
        return position_loss, confidence_loss
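# --- Usage sketch (illustration only, not part of the original Space) ---
# Assumes a hypothetical detection head that outputs, for each image, N = 4
# candidates of (x, y, r, p) with p already sigmoid-activated, and targets in
# the same (B, N, 4) layout where the last channel is a binary presence label.
# The shapes and the sigmoid assumption are inferred from the loss code above,
# not confirmed by the original repository. The two returned terms are summed
# for backprop; the log-variance parameters train jointly with the model.
def _demo_uncertainty_light_pos_loss():
    criterion = uncertainty_light_pos_loss()
    logits = torch.rand(8, 4, 4, requires_grad=True)  # fake model output in [0, 1)
    targets = torch.rand(8, 4, 4)
    targets[..., 3] = (targets[..., 3] > 0.5).float()  # binary presence labels
    position_loss, confidence_loss = criterion(logits, targets)
    total = position_loss + confidence_loss
    total.backward()
    return position_loss, confidence_loss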
| class unet_3maps_loss(nn.Module): | |
| def __init__(self): | |
| super(unet_3maps_loss, self).__init__() | |
| def forward(self, pred_prob, pred_rad, prob_gt, rad_gt): | |
| focal = nn.BCELoss() | |
| L_prob = focal(pred_prob, prob_gt) | |
| pos_mask = prob_gt > 0.5 | |
| L_rad = ( | |
| nn.functional.smooth_l1_loss(pred_rad[pos_mask], rad_gt[pos_mask]) | |
| if pos_mask.any() | |
| else pred_rad.sum() * 0 | |
| ) | |
| return L_prob + 10.0 * L_rad, L_prob, L_rad | |
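# --- Usage sketch (illustration only, not part of the original Space) ---
# Assumes a hypothetical UNet-style network producing a per-pixel presence
# probability map (already sigmoid-activated) and a radius map of the same
# spatial size; prob_gt is a binary mask and rad_gt holds the target radii.
# The (B, 1, H, W) shapes are an assumption for the demo, not taken from the
# original code.
def _demo_unet_3maps_loss():
    criterion = unet_3maps_loss()
    pred_prob = torch.rand(2, 1, 64, 64)  # fake sigmoid output
    pred_rad = torch.rand(2, 1, 64, 64, requires_grad=True)
    prob_gt = (torch.rand(2, 1, 64, 64) > 0.9).float()  # sparse positive mask
    rad_gt = torch.rand(2, 1, 64, 64)
    total, L_prob, L_rad = criterion(pred_prob, pred_rad, prob_gt, rad_gt)
    total.backward()
    return total, L_prob, L_rad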