This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 306
/
Copy pathbiencoder.py
452 lines (384 loc) · 15.5 KB
/
biencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
BiEncoder component + loss function for 'all-in-batch' training
"""
import collections
import logging
import random
from typing import Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor as T
from torch import nn
from dpr.data.biencoder_data import BiEncoderSample
from dpr.utils.data_utils import Tensorizer
from dpr.utils.model_utils import CheckpointState
logger = logging.getLogger(__name__)

# Model-input batch returned by BiEncoder.create_biencoder_input / create_biencoder_input2.
# The namedtuple typename matches the module attribute it is bound to: pickle locates
# classes by their qualified name inside the module, so a mismatched typename makes
# batch instances unpicklable.
BiEncoderBatch = collections.namedtuple(
    "BiEncoderBatch",
    [
        "question_ids",
        "question_segments",
        "context_ids",
        "ctx_segments",
        "is_positive",
        "hard_negatives",
        "encoder_type",
    ],
)

# TODO: it is only used by _select_span_with_token. Move them to utils
# Seeded generator so the span-selection jitter is reproducible across runs.
rnd = random.Random(0)
def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T:
    """
    Score every question vector against every context vector by inner product.

    :param q_vectors: question vectors, n1 x D (a single D-dim vector also works)
    :param ctx_vectors: context vectors, n2 x D
    :return: n1 x n2 score matrix (n2 scores for a single question vector)
    """
    ctx_by_column = ctx_vectors.transpose(0, 1)  # D x n2
    return q_vectors.matmul(ctx_by_column)
def cosine_scores(q_vector: T, ctx_vectors: T) -> T:
    """
    Pairwise cosine similarity between question vectors and context vectors.

    Mirrors ``dot_product_scores`` on L2-normalised inputs, so it yields the full
    n1 x n2 matrix that the score consumers expect. ``F.cosine_similarity(q, ctx, dim=1)``
    only compares rows pairwise after broadcasting; it fails for n1 != n2 and does
    not produce a matrix.

    :param q_vector: question vectors, n1 x D, or a single D-dim vector
    :param ctx_vectors: context vectors, n2 x D
    :return: n1 x n2 cosine scores (n2 scores for a single question vector)
    """
    # eps=1e-8 matches F.cosine_similarity's default; zero vectors score 0.
    q_unit = F.normalize(q_vector, p=2, dim=-1, eps=1e-8)
    ctx_unit = F.normalize(ctx_vectors, p=2, dim=-1, eps=1e-8)
    return torch.matmul(q_unit, torch.transpose(ctx_unit, 0, 1))
class BiEncoder(nn.Module):
    """Bi-Encoder model component. Encapsulates query/question and context/passage encoders."""

    def __init__(
        self,
        question_model: nn.Module,
        ctx_model: nn.Module,
        fix_q_encoder: bool = False,
        fix_ctx_encoder: bool = False,
    ):
        """
        :param question_model: encoder module for questions/queries
        :param ctx_model: encoder module for contexts/passages
        :param fix_q_encoder: run the question-side encoder under torch.no_grad()
        :param fix_ctx_encoder: run the context-side encoder under torch.no_grad()
        """
        super().__init__()
        self.question_model = question_model
        self.ctx_model = ctx_model
        self.fix_q_encoder = fix_q_encoder
        self.fix_ctx_encoder = fix_ctx_encoder

    @staticmethod
    def get_representation(
        sub_model: nn.Module,
        ids: T,
        segments: T,
        attn_mask: T,
        fix_encoder: bool = False,
        representation_token_pos=0,
    ) -> Tuple[T, T, T]:
        """
        Run ``sub_model`` on one side of the input.

        :param sub_model: encoder called as ``sub_model(ids, segments, attn_mask,
            representation_token_pos=...)``; must return a 3-tuple
        :param ids: token ids; when None the encoder is not called
        :param segments: segment ids passed through to the encoder
        :param attn_mask: attention mask passed through to the encoder
        :param fix_encoder: if True, the forward pass runs under torch.no_grad()
        :param representation_token_pos: forwarded to the encoder unchanged
        :return: (sequence_output, pooled_output, hidden_states); all None when ids is None
        """
        sequence_output = None
        pooled_output = None
        hidden_states = None
        if ids is not None:
            if fix_encoder:
                with torch.no_grad():
                    sequence_output, pooled_output, hidden_states = sub_model(
                        ids,
                        segments,
                        attn_mask,
                        representation_token_pos=representation_token_pos,
                    )
                if sub_model.training:
                    # The encoder itself receives no gradients (no_grad above). Its outputs
                    # are re-flagged, presumably so a downstream graph can still be built
                    # on top of them during training.
                    sequence_output.requires_grad_(requires_grad=True)
                    pooled_output.requires_grad_(requires_grad=True)
            else:
                sequence_output, pooled_output, hidden_states = sub_model(
                    ids,
                    segments,
                    attn_mask,
                    representation_token_pos=representation_token_pos,
                )
        return sequence_output, pooled_output, hidden_states

    def forward(
        self,
        question_ids: T,
        question_segments: T,
        question_attn_mask: T,
        context_ids: T,
        ctx_segments: T,
        ctx_attn_mask: T,
        encoder_type: str = None,
        representation_token_pos=0,
    ) -> Tuple[T, T]:
        """
        Encode questions and contexts and return their pooled outputs.

        Encoder choice per ``encoder_type``:
        - None: questions use question_model, contexts use ctx_model;
        - "question": both sides use question_model;
        - "ctx": both sides use ctx_model.
        ``representation_token_pos`` is only forwarded for the question side.

        :return: (question pooled output, context pooled output); either may be None
            if the corresponding ids are None
        """
        q_encoder = (
            self.question_model
            if encoder_type is None or encoder_type == "question"
            else self.ctx_model
        )
        _q_seq, q_pooled_out, _q_hidden = self.get_representation(
            q_encoder,
            question_ids,
            question_segments,
            question_attn_mask,
            self.fix_q_encoder,
            representation_token_pos=representation_token_pos,
        )

        ctx_encoder = (
            self.ctx_model
            if encoder_type is None or encoder_type == "ctx"
            else self.question_model
        )
        _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation(
            ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, self.fix_ctx_encoder
        )

        return q_pooled_out, ctx_pooled_out

    @staticmethod
    def _compose_ctxs(
        positive_ctx,
        neg_ctxs: List,
        hard_neg_ctxs: List,
        num_hard_negatives: int,
        num_other_negatives: int,
        shuffle: bool,
        hard_neg_fallback: bool,
    ) -> Tuple[List, int, int]:
        """
        Build ``[positive] + other negatives + hard negatives`` for one sample.

        :return: (contexts, start, end) where [start, end) is the position range of
            the hard negatives inside the returned list
        """
        if shuffle:
            # Shuffles the caller's (i.e. the sample's own) lists in place.
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        # Hard negatives come after the positive AND the other negatives.
        hard_negatives_start_idx = 1 + len(neg_ctxs)
        hard_negatives_end_idx = hard_negatives_start_idx + len(hard_neg_ctxs)
        return all_ctxs, hard_negatives_start_idx, hard_negatives_end_idx

    @staticmethod
    def _build_batch(
        question_tensors: List[T],
        ctx_tensors: List[T],
        positive_ctx_indices: List[int],
        hard_neg_ctx_indices: List[List[int]],
    ) -> BiEncoderBatch:
        """Stack per-item id tensors into row matrices and wrap them with zero segment ids."""
        ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
        questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)
        return BiEncoderBatch(
            questions_tensor,
            torch.zeros_like(questions_tensor),
            ctxs_tensor,
            torch.zeros_like(ctxs_tensor),
            positive_ctx_indices,
            hard_neg_ctx_indices,
            "question",
        )

    # TODO delete once moved to the new method
    @classmethod
    def create_biencoder_input(
        cls,
        samples: List,
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of data items (from json) to create the batch for; each is a dict with
            "positive_ctxs", "negative_ctxs" and "hard_negative_ctxs" lists of {"text", optional "title"} dicts
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (in place, on the samples' own lists)
        :param shuffle_positives: shuffles positive passages pools
        :param hard_neg_fallback: if a sample has no hard negatives, use the first num_hard_negatives
            of its other negatives instead
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample["positive_ctxs"]
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                positive_ctx = sample["positive_ctxs"][0]

            all_ctxs, hard_neg_start, hard_neg_end = cls._compose_ctxs(
                positive_ctx,
                sample["negative_ctxs"],
                sample["hard_negative_ctxs"],
                num_hard_negatives,
                num_other_negatives,
                shuffle,
                hard_neg_fallback,
            )

            current_ctxs_len = len(ctx_tensors)
            ctx_tensors.extend(
                tensorizer.text_to_tensor(
                    ctx["text"],
                    title=ctx["title"] if (insert_title and "title" in ctx) else None,
                )
                for ctx in all_ctxs
            )
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                list(
                    range(
                        current_ctxs_len + hard_neg_start,
                        current_ctxs_len + hard_neg_end,
                    )
                )
            )

            # NOTE(review): assumes the sample dict carries its question text under
            # "question" (DPR json layout) — confirm against the data files.
            question = sample["question"]
            question_tensors.append(tensorizer.text_to_tensor(question))

        return cls._build_batch(
            question_tensors, ctx_tensors, positive_ctx_indices, hard_neg_ctx_indices
        )

    @classmethod
    def create_biencoder_input2(
        cls,
        samples: List[BiEncoderSample],
        tensorizer: Tensorizer,
        insert_title: bool,
        num_hard_negatives: int = 0,
        num_other_negatives: int = 0,
        shuffle: bool = True,
        shuffle_positives: bool = False,
        hard_neg_fallback: bool = True,
        query_token: str = None,
    ) -> BiEncoderBatch:
        """
        Creates a batch of the biencoder training tuple.
        :param samples: list of BiEncoderSample-s to create the batch for
        :param tensorizer: components to create model input tensors from a text sequence
        :param insert_title: enables title insertion at the beginning of the context sequences
        :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
        :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
        :param shuffle: shuffles negative passages pools (in place, on the samples' own lists)
        :param shuffle_positives: shuffles positive passages pools
        :param hard_neg_fallback: if a sample has no hard negatives, use the first num_hard_negatives
            of its other negatives instead
        :param query_token: optional marker for the query. "[START_ENT]" selects a span of the
            query around that token; any other value is prepended to the query text with a space
        :return: BiEncoderBatch tuple
        """
        question_tensors = []
        ctx_tensors = []
        positive_ctx_indices = []
        hard_neg_ctx_indices = []

        for sample in samples:
            # ctx+ & [ctx-] composition
            # as of now, take the first(gold) ctx+ only
            if shuffle and shuffle_positives:
                positive_ctxs = sample.positive_passages
                positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
            else:
                positive_ctx = sample.positive_passages[0]

            question = sample.query

            all_ctxs, hard_neg_start, hard_neg_end = cls._compose_ctxs(
                positive_ctx,
                sample.negative_passages,
                sample.hard_negative_passages,
                num_hard_negatives,
                num_other_negatives,
                shuffle,
                hard_neg_fallback,
            )

            current_ctxs_len = len(ctx_tensors)
            ctx_tensors.extend(
                tensorizer.text_to_tensor(
                    ctx.text, title=ctx.title if (insert_title and ctx.title) else None
                )
                for ctx in all_ctxs
            )
            positive_ctx_indices.append(current_ctxs_len)
            hard_neg_ctx_indices.append(
                list(
                    range(
                        current_ctxs_len + hard_neg_start,
                        current_ctxs_len + hard_neg_end,
                    )
                )
            )

            if query_token:
                # TODO: tmp workaround for EL, remove or revise
                if query_token == "[START_ENT]":
                    query_span = _select_span_with_token(
                        question, tensorizer, token_str=query_token
                    )
                    question_tensors.append(query_span)
                else:
                    question_tensors.append(
                        tensorizer.text_to_tensor(" ".join([query_token, question]))
                    )
            else:
                question_tensors.append(tensorizer.text_to_tensor(question))

        return cls._build_batch(
            question_tensors, ctx_tensors, positive_ctx_indices, hard_neg_ctx_indices
        )

    def load_state(self, saved_state: CheckpointState):
        """Load ``saved_state.model_dict`` into this module (strict ``load_state_dict``)."""
        # TODO: make a long term HF compatibility fix
        model_dict = saved_state.model_dict
        if "question_model.embeddings.position_ids" in model_dict:
            del model_dict["question_model.embeddings.position_ids"]
            # Tolerate checkpoints where only the question side carries this buffer.
            model_dict.pop("ctx_model.embeddings.position_ids", None)
        self.load_state_dict(model_dict)

    def get_state_dict(self):
        """Return this module's ``state_dict()``."""
        return self.state_dict()
class BiEncoderNllLoss(object):
    """NLL loss over question-vs-every-context similarity scores ('all-in-batch' training)."""

    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negative_idx_per_question: list = None,
        loss_scale: float = None,
    ) -> Tuple[T, T]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question in not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.

        :param q_vectors: question vectors, one row per question
        :param ctx_vectors: context vectors scored against every question
        :param positive_idx_per_question: for each question, the index of its positive context
            among ctx_vectors' rows
        :param hard_negative_idx_per_question: currently unused
        :param loss_scale: optional multiplier applied to the loss when truthy
        :return: a tuple of loss value and amount of correct predictions per batch (0-dim tensor)
        """
        scores = self.get_scores(q_vectors, ctx_vectors)

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        log_probs = F.log_softmax(scores, dim=1)
        # Built once and reused for the loss and the accuracy count.
        targets = torch.tensor(positive_idx_per_question, device=log_probs.device)

        loss = F.nll_loss(log_probs, targets, reduction="mean")

        _, max_idxs = torch.max(log_probs, 1)
        correct_predictions_count = (max_idxs == targets).sum()

        if loss_scale:
            loss = loss * loss_scale

        return loss, correct_predictions_count

    @classmethod
    def get_scores(cls, q_vector: T, ctx_vectors: T) -> T:
        """Score questions against contexts with the (possibly subclass-overridden) similarity function."""
        f = cls.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        """Similarity used by the loss; override in a subclass to change it."""
        return dot_product_scores
def _select_span_with_token(
text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]"
) -> T:
id = tensorizer.get_token_id(token_str)
query_tensor = tensorizer.text_to_tensor(text)
if id not in query_tensor:
query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False)
token_indexes = (query_tensor_full == id).nonzero()
if token_indexes.size(0) > 0:
start_pos = token_indexes[0, 0].item()
# add some randomization to avoid overfitting to a specific token position
left_shit = int(tensorizer.max_length / 2)
rnd_shift = int((rnd.random() - 0.5) * left_shit / 2)
left_shit += rnd_shift
query_tensor = query_tensor_full[start_pos - left_shit :]
cls_id = tensorizer.tokenizer.cls_token_id
if query_tensor[0] != cls_id:
query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0)
from dpr.models.reader import _pad_to_len
query_tensor = _pad_to_len(
query_tensor, tensorizer.get_pad_id(), tensorizer.max_length
)
query_tensor[-1] = tensorizer.tokenizer.sep_token_id
# logger.info('aligned query_tensor %s', query_tensor)
assert id in query_tensor, "query_tensor={}".format(query_tensor)
return query_tensor
else:
raise RuntimeError(
"[START_ENT] toke not found for Entity Linking sample query={}".format(
text
)
)
else:
return query_tensor