Ray-1026
update
a856109
import torch
import torch.nn.functional as F
from torch import nn
from scipy.optimize import linear_sum_assignment
class uncertainty_light_pos_loss(nn.Module):
def __init__(self):
super(uncertainty_light_pos_loss, self).__init__()
self.log_var_xyr = nn.Parameter(torch.tensor(1.0, requires_grad=True))
self.log_var_p = nn.Parameter(torch.tensor(1.0, requires_grad=True))
def forward(self, logits, targets):
B, N, P = logits.shape # (B, 4, 4)
position_loss = 0
confidence_loss = 0
w_xyr = 0.5 / (self.log_var_xyr**2) # uncertainty weight for position loss
w_p = 0.5 / (self.log_var_p**2) # uncertainty weight for confidence loss
weights = torch.tensor([1, 1, 2], device=logits.device) # weights for x, y, r
for b in range(B):
pred_xyr = logits[b, :, :3] # (N, 3)
pred_p = logits[b, :, 3] # (N,)
gt_xyr = targets[b, :, :3] # (N, 3)
gt_p = targets[b, :, 3] # (N,)
cost_matrix = torch.cdist(gt_xyr, pred_xyr, p=2) # (N, N)
with torch.no_grad():
row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy())
matched_pred_xyr = pred_xyr[col_ind]
matched_gt_xyr = gt_xyr[row_ind]
matched_pred_p = pred_p[col_ind]
matched_gt_p = gt_p[row_ind]
valid_mask = matched_gt_p > 0
valid_cnt = valid_mask.sum().clamp(min=1)
xyr_loss = (
F.smooth_l1_loss(
matched_pred_xyr[valid_mask],
matched_gt_xyr[valid_mask],
reduction="none",
)
* weights
).sum()
p_loss = F.binary_cross_entropy(
matched_pred_p, matched_gt_p, reduction="mean"
)
position_loss += xyr_loss / valid_cnt
confidence_loss += p_loss
position_loss = w_xyr * (position_loss / B) + torch.log(1 + self.log_var_xyr**2)
confidence_loss = w_p * (confidence_loss / B) + torch.log(1 + self.log_var_p**2)
return position_loss, confidence_loss
class unet_3maps_loss(nn.Module):
def __init__(self):
super(unet_3maps_loss, self).__init__()
def forward(self, pred_prob, pred_rad, prob_gt, rad_gt):
focal = nn.BCELoss()
L_prob = focal(pred_prob, prob_gt)
pos_mask = prob_gt > 0.5
L_rad = (
nn.functional.smooth_l1_loss(pred_rad[pos_mask], rad_gt[pos_mask])
if pos_mask.any()
else pred_rad.sum() * 0
)
return L_prob + 10.0 * L_rad, L_prob, L_rad