forked from dilligencer-zrj/code_zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbinary_segmentation_hardmining
37 lines (30 loc) · 1.28 KB
/
binary_segmentation_hardmining
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def hard_mining(neg_output, neg_labels, num_hard, num_skip=200, num_bins=5):
    """Select a rank-stratified subset of the negative samples.

    The negatives are ranked by descending score. The ``num_skip`` top-ranked
    entries are dropped, then ``num_hard // num_bins`` consecutive entries are
    taken at each of ``num_bins`` offsets into the remaining ranking. The
    offsets are multiples of ``len(neg_output) // num_bins``.

    Slices that run past the end of the ranking are clipped, so fewer than
    ``num_hard`` samples (possibly none) may be returned.

    Args:
        neg_output: Scores of the negative samples; expected to be a 1-D
            tensor (the ranking is indexed along dimension 0).
        neg_labels: Labels aligned element-wise with ``neg_output``.
        num_hard: Total number of negatives requested across all bins.
        num_skip: Number of top-ranked negatives excluded from selection.
        num_bins: Number of evenly spaced positions in the ranking from
            which samples are drawn.

    Returns:
        A tuple ``(selected_output, selected_labels)`` holding the chosen
        scores and their matching labels, in selection order.
    """
    total = len(neg_output)

    # Only the ranking is needed. The indices returned by topk refer to
    # positions in the *unsorted* input, so the sorted values are discarded
    # and both gathers below read from the original tensors.
    _, ranked = torch.topk(neg_output, total)
    ranked = ranked[num_skip:]

    # Floor division keeps the slice bounds integral on Python 3 as well.
    per_bin = num_hard // num_bins
    stride = total // num_bins

    picked = torch.cat(
        [ranked[b * stride: b * stride + per_bin] for b in range(num_bins)]
    )

    selected_output = torch.index_select(neg_output, 0, picked)
    selected_labels = torch.index_select(neg_labels, 0, picked)
    return selected_output, selected_labels
class LossWithHardmining(nn.Module):
    """Binary cross-entropy over all positives plus a mined negative subset.

    The loss is the sum of two ``nn.BCELoss`` terms: one over every positive
    element and one over the negatives chosen by ``hard_mining``. The number
    of negatives requested is four times the number of positives.
    """

    def __init__(self, weight=None):
        """Build the underlying BCE criterion.

        Args:
            weight: Passed unchanged as the first argument of ``nn.BCELoss``.
        """
        super(LossWithHardmining, self).__init__()
        self.bce = nn.BCELoss(weight)

    def forward(self, outputs, targets):
        """Compute the combined positive / mined-negative loss.

        Args:
            outputs: Raw network scores; index 1 along dimension 1 is used
                as the foreground logit.
            targets: Ground-truth tensor; elements ``> 0`` are treated as
                positive and elements ``< 1`` as negative. Assumed to have
                the same shape as ``outputs[:, 1]`` -- TODO confirm with
                the caller.

        Returns:
            The sum of the positive-sample loss and the negative-sample loss.
        """
        probs = torch.sigmoid(outputs[:, 1])

        pos_mask = targets > 0
        neg_mask = targets < 1

        pos_output = probs[pos_mask]
        neg_output = probs[neg_mask]

        # Put the targets on the same device and dtype as the predictions
        # rather than forcing CUDA, so the loss also runs on CPU and on
        # whichever GPU the model lives on.
        pos_target = targets[pos_mask].to(device=probs.device, dtype=probs.dtype)
        neg_target = targets[neg_mask].to(device=probs.device, dtype=probs.dtype)

        # NOTE(review): with no positives, num_hard is 0 and both subsets
        # are empty; confirm how the caller wants that batch handled.
        neg_output, neg_target = hard_mining(
            neg_output, neg_target, num_hard=4 * len(pos_target)
        )

        pos_loss = self.bce(pos_output, pos_target)
        neg_loss = self.bce(neg_output, neg_target)
        return pos_loss + neg_loss