How to use in-batch negative and gold when training? #110
Comments
Hi @hongyuntw,
@vlad-karpukhin
Yes, this code uses the in-batch negative training trick.
@vlad-karpukhin
The training pipeline code with the distributed loss calculation is in train_dense_encoder.py.
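For reference, the gold positives and hard negatives that feed this loss come from the training JSON files; below is a rough sketch of one training example, assuming the field names from the repo's README and written as a Python literal with placeholder values.

```python
# Rough shape of one DPR training example (assumed field names from the README;
# values are placeholders, and real files carry additional metadata per context).
train_example = {
    "question": "...",
    "answers": ["..."],
    # "gold" passages that serve as the positives
    "positive_ctxs": [{"title": "...", "text": "..."}],
    # optional random negatives
    "negative_ctxs": [],
    # e.g. high-ranked BM25 passages that do not contain the answer
    "hard_negative_ctxs": [{"title": "...", "text": "..."}],
}
```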
Hi @vlad-karpukhin and @hongyuntw, from my understanding, the implementation of in-batch negative sampling and its corresponding loss is computed as follows:
```python
scores = self.get_scores(q_vectors, ctx_vectors)  # (q_num x ctx_num) similarity matrix

if len(q_vectors.size()) > 1:
    q_num = q_vectors.size(0)
    scores = scores.view(q_num, -1)

softmax_scores = F.log_softmax(scores, dim=1)  # log-softmax over all in-batch contexts

loss = F.nll_loss(
    softmax_scores,
    torch.tensor(positive_idx_per_question).to(softmax_scores.device),
    reduction="mean",
)
```
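To make the shapes concrete, here is a small self-contained sketch (not the repo's code) with hypothetical sizes: 4 questions, each contributing 1 gold passage and 1 hard negative, so ctx_vectors has 8 rows, scores is 4 x 8, and, assuming contexts are laid out gold-then-hard-negative per question, positive_idx_per_question is [0, 2, 4, 6].

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: 4 questions, each with 1 gold passage + 1 hard negative.
q_num, dim = 4, 16
q_vectors = torch.randn(q_num, dim)
ctx_vectors = torch.randn(q_num * 2, dim)  # [gold_0, hard_0, gold_1, hard_1, ...]

# Dot-product similarity between every question and every context in the batch.
scores = torch.matmul(q_vectors, ctx_vectors.t())  # shape (4, 8)

# Column index of each question's own gold passage.
positive_idx_per_question = [0, 2, 4, 6]

softmax_scores = F.log_softmax(scores, dim=1)
loss = F.nll_loss(
    softmax_scores,
    torch.tensor(positive_idx_per_question),
    reduction="mean",
)
# Every other context in the batch (other questions' golds and all hard
# negatives) acts as a negative for each question.
```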
Please do correct me if I'm wrong:
```python
scores = self.get_scores(q_vectors, ctx_vectors)

if len(q_vectors.size()) > 1:
    q_num = q_vectors.size(0)
    scores = scores.view(q_num, -1)

softmax_scores = F.log_softmax(scores, dim=1)

# self.slice: assumed to keep, for each question, only its own gold passage and
# hard negatives, with the gold passage in column 0 (not a function in the repo).
softmax_scores_sliced = self.slice(softmax_scores)

loss = F.nll_loss(
    softmax_scores_sliced,
    torch.tensor([0, 0, 0, 0]).to(softmax_scores_sliced.device),
    reduction="mean",
)
```
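A side note on the two snippets above (my own reasoning, not from the thread): F.nll_loss only reads the target column of each row, so slicing an already-computed log_softmax and using target 0 gives exactly the same value as indexing positive_idx_per_question on the full matrix, provided each slice puts the gold column at index 0. The in-batch negatives would only be dropped if the softmax itself were recomputed over the sliced scores. A minimal check with hypothetical shapes (1 gold + 1 hard negative per question):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 8)      # 4 questions x 8 in-batch contexts
positive_idx = [0, 2, 4, 6]     # gold column per question

log_probs = F.log_softmax(scores, dim=1)

# Variant 1: full matrix, index the gold column directly.
loss_full = F.nll_loss(log_probs, torch.tensor(positive_idx))

# Variant 2: slice each row down to [gold, own hard negative], gold at index 0.
sliced = torch.stack([log_probs[i, p:p + 2] for i, p in enumerate(positive_idx)])
loss_sliced = F.nll_loss(sliced, torch.zeros(4, dtype=torch.long))

print(torch.allclose(loss_full, loss_sliced))  # True: the log-softmax was taken
# over the full row, so slicing afterwards does not change the loss value
```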
@robinsongh381
Hi @robinsongh381, your understanding is almost correct.
Hi @vlad-karpukhin, you are absolutely right, the blue lines should have been named "hard" negatives. I am not quite sure
Regards
Hi @vlad-karpukhin, thank you for providing the code! And @robinsongh381, I appreciate the clear illustration. I have a question regarding in-batch negatives across devices, as demonstrated here: Line 618 in a31212d
From my understanding, the gather operation aims to supply more negative samples for a single question (please correct me if I'm mistaken). That said, could you please explain why it is necessary to gather the questions across GPUs as well?
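For anyone following along, here is a rough sketch (my paraphrase with assumed shapes, not the repo's exact code) of the gather-then-loss pattern behind that line: each GPU all-gathers both question and context vectors, swaps its own gradient-carrying tensors back into their slot, and then computes the same NLL loss over the resulting global score matrix, so every GPU's contexts act as extra negatives for every question.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def global_in_batch_loss(local_q, local_ctx, local_positive_idx, ctx_per_gpu):
    """Sketch of cross-GPU in-batch negatives (hypothetical helper, not DPR's API).

    Assumes every GPU holds the same number of questions/contexts and the same
    per-question context layout, so positive indices can be shifted by rank.
    """
    world_size, rank = dist.get_world_size(), dist.get_rank()

    # all_gather collects detached copies of every rank's vectors.
    q_list = [torch.zeros_like(local_q) for _ in range(world_size)]
    ctx_list = [torch.zeros_like(local_ctx) for _ in range(world_size)]
    dist.all_gather(q_list, local_q)
    dist.all_gather(ctx_list, local_ctx)

    # Re-insert the local, gradient-carrying tensors so backprop reaches this
    # GPU's encoders; the copies from other ranks remain detached.
    q_list[rank] = local_q
    ctx_list[rank] = local_ctx
    global_q = torch.cat(q_list, dim=0)
    global_ctx = torch.cat(ctx_list, dim=0)

    # Same in-batch NLL loss as before, but over the global (all-GPU) batch.
    scores = torch.matmul(global_q, global_ctx.t())
    positive_idx = torch.cat(
        [torch.tensor(local_positive_idx) + r * ctx_per_gpu for r in range(world_size)]
    ).to(scores.device)
    return F.nll_loss(F.log_softmax(scores, dim=1), positive_idx, reduction="mean")
```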
As the title says,
the paper used in-batch negatives and gold passages during training,
but how can I set these up?
Can somebody help me? Thanks a lot!