-
Notifications
You must be signed in to change notification settings - Fork 287
/
Copy pathretrieval.py
293 lines (247 loc) · 10.8 KB
/
retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# Copyright 2022 The TensorFlow Recommenders Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint-as: python3
"""A factorized retrieval task."""
from typing import Optional, Sequence, Union, Text, List
import tensorflow as tf
from tensorflow_recommenders import layers
from tensorflow_recommenders import metrics as tfrs_metrics
from tensorflow_recommenders.tasks import base
class Retrieval(tf.keras.layers.Layer, base.Task):
  """A factorized retrieval task.

  Recommender systems are often composed of two components:
  - a retrieval model, retrieving O(thousands) candidates from a corpus of
    O(millions) candidates.
  - a ranker model, scoring the candidates retrieved by the retrieval model to
    return a ranked shortlist of a few dozen candidates.

  This task defines models that facilitate efficient retrieval of candidates
  from large corpora by maintaining a two-tower, factorized structure: separate
  query and candidate representation towers, joined at the top via a
  lightweight scoring function.
  """

  def __init__(self,
               loss: Optional[tf.keras.losses.Loss] = None,
               metrics: Optional[Union[
                   Sequence[tfrs_metrics.Factorized],
                   tfrs_metrics.Factorized
               ]] = None,
               batch_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
               loss_metrics: Optional[List[tf.keras.metrics.Metric]] = None,
               temperature: Optional[float] = None,
               num_hard_negatives: Optional[int] = None,
               remove_accidental_hits: bool = False,
               name: Optional[Text] = None) -> None:
    """Initializes the task.

    Args:
      loss: Loss function. Defaults to
        `tf.keras.losses.CategoricalCrossentropy(from_logits=True,
        reduction=SUM)`.
      metrics: A single object, or a sequence of objects, for evaluating
        top-K metrics over a corpus of candidates. These metrics measure how
        good the model is at picking the true candidate out of all possible
        candidates in the system. Note, because the metrics range over the
        entire candidate set, they are usually much slower to compute.
        Consider setting `compute_metrics=False` during training to save the
        time in computing the metrics.
      batch_metrics: Metrics measuring how good the model is at picking out the
        true candidate for a query from other candidates in the batch. For
        example, a batch AUC metric would measure the probability that the true
        candidate is scored higher than the other candidates in the batch.
      loss_metrics: List of Keras metrics used to summarize the loss.
      temperature: Temperature of the softmax. When set, scores are divided by
        this value before the loss is computed.
      num_hard_negatives: If positive, the `num_hard_negatives` negative
        examples with largest logits are kept when computing cross-entropy
        loss. If larger than batch size or non-positive, all the negative
        examples are kept.
      remove_accidental_hits: When True, enables removing accidental hits of
        examples used as negatives. An accidental hit is defined as a candidate
        that is used as an in-batch negative but has the same id as the
        positive candidate. Requires `candidate_ids` to be passed to `call`.
      name: Optional task name.
    """
    super().__init__(name=name)

    if loss is None:
      # Sum (not mean) reduction over the batch is the task's default.
      loss = tf.keras.losses.CategoricalCrossentropy(
          from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    self._loss = loss

    self._factorized_metrics = self._to_metric_list(metrics)
    self._batch_metrics = batch_metrics or []
    self._loss_metrics = loss_metrics or []
    self._temperature = temperature
    self._num_hard_negatives = num_hard_negatives
    self._remove_accidental_hits = remove_accidental_hits

  @staticmethod
  def _to_metric_list(
      metrics: Optional[Union[Sequence[tfrs_metrics.Factorized],
                              tfrs_metrics.Factorized]]
  ) -> Sequence[tfrs_metrics.Factorized]:
    """Normalizes `None`, a single metric, or a sequence into a sequence."""
    if metrics is None:
      return []
    if not isinstance(metrics, Sequence):
      # A single metric object: wrap it rather than discarding it.
      return [metrics]
    return metrics

  @property
  def factorized_metrics(self) -> Optional[
      Sequence[tfrs_metrics.Factorized]]:
    """The metrics objects used to compute retrieval metrics."""
    return self._factorized_metrics

  @factorized_metrics.setter
  def factorized_metrics(self,
                         value: Optional[Union[
                             Sequence[tfrs_metrics.Factorized],
                             tfrs_metrics.Factorized
                         ]]) -> None:
    """Sets factorized metrics.

    Accepts the same forms as the `metrics` constructor argument: `None`
    (no metrics), a single `Factorized` metric, or a sequence of them.
    """
    self._factorized_metrics = self._to_metric_list(value)

  def call(self,
           query_embeddings: tf.Tensor,
           candidate_embeddings: tf.Tensor,
           sample_weight: Optional[tf.Tensor] = None,
           candidate_sampling_probability: Optional[tf.Tensor] = None,
           candidate_ids: Optional[tf.Tensor] = None,
           compute_metrics: bool = True,
           compute_batch_metrics: bool = True) -> tf.Tensor:
    """Computes the task loss and metrics.

    The main arguments are pairs of query and candidate embeddings: the first
    row of query_embeddings denotes a query for which the candidate from the
    first row of candidate embeddings was selected by the user.

    The task will try to maximize the affinity of these query, candidate pairs
    while minimizing the affinity between the query and candidates belonging
    to other queries in the batch.

    Args:
      query_embeddings: [num_queries, embedding_dim] tensor of query
        representations.
      candidate_embeddings: [num_queries, embedding_dim] tensor of candidate
        representations. May contain additional rows beyond `num_queries`
        (extra negatives); those rows are excluded from factorized metrics.
      sample_weight: [num_queries] tensor of sample weights.
      candidate_sampling_probability: Optional tensor of candidate sampling
        probabilities. When given will be used to correct the logits to
        reflect the sampling probability of negative candidates.
      candidate_ids: Optional tensor containing candidate ids, one per row of
        `candidate_embeddings`. When given, factorized top-K evaluation will be
        id-based rather than score-based.
      compute_metrics: Whether to compute metrics. Set this to False
        during training for faster training.
      compute_batch_metrics: Whether to compute batch level metrics.
        In-batch loss_metrics will still be computed.

    Returns:
      loss: Tensor of loss values.

    Raises:
      ValueError: If accidental hit removal is enabled but `candidate_ids` is
        not supplied.
    """
    scores = tf.linalg.matmul(
        query_embeddings, candidate_embeddings, transpose_b=True)

    num_queries = tf.shape(scores)[0]
    num_candidates = tf.shape(scores)[1]

    # Row i's positive is candidate i; every other column is a negative.
    labels = tf.eye(num_queries, num_candidates)

    if self._temperature is not None:
      scores = scores / self._temperature

    if candidate_sampling_probability is not None:
      scores = layers.loss.SamplingProbablityCorrection()(
          scores, candidate_sampling_probability)

    if self._remove_accidental_hits:
      if candidate_ids is None:
        raise ValueError(
            "When accidental hit removal is enabled, candidate ids "
            "must be supplied."
        )
      scores = layers.loss.RemoveAccidentalHits()(labels, scores, candidate_ids)

    if self._num_hard_negatives is not None:
      scores, labels = layers.loss.HardNegativeMining(self._num_hard_negatives)(
          scores,
          labels)

    loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)

    update_ops = []
    for metric in self._loss_metrics:
      update_ops.append(
          metric.update_state(loss, sample_weight=sample_weight))

    if compute_metrics:
      # Only the first `num_queries` candidates are the true positives; any
      # further rows are extra negatives. Slice both the embeddings and the
      # ids so they stay aligned with `query_embeddings`.
      true_candidate_embeddings = candidate_embeddings[:num_queries]
      true_candidate_ids = (
          candidate_ids[:num_queries] if candidate_ids is not None else None)
      for metric in self._factorized_metrics:
        update_ops.append(
            metric.update_state(
                query_embeddings,
                true_candidate_embeddings,
                true_candidate_ids=true_candidate_ids)
        )

    if compute_batch_metrics:
      for metric in self._batch_metrics:
        update_ops.append(metric.update_state(labels, scores))

    # Make the returned loss depend on all metric updates so they run in graph
    # mode.
    with tf.control_dependencies(update_ops):
      return tf.identity(loss)
def _cross_replica_concat(values: tf.Tensor) -> tf.Tensor:
  """Concatenates `values` from every replica into one tensor along axis 0.

  Each replica (e.g. each TPU core) contributes its own `values`. The gathered
  result is rotated so that, on every replica, the first rows are the ones that
  replica contributed itself. With replicas numbered 0..N-1, replica i sees:

    tf.Tensor([
      values from replica i,
      values from replica i+1,
      ...
      values from replica N-1,
      values from replica 0,
      ...
      values from replica i-1
    ]).

  Intended to be called inside a replica context, e.g. via `strategy.run` as
  in the example below, which is meant to run on 4 TPU cores:

  >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  >>> tf.config.experimental_connect_to_cluster(resolver)
  >>> tf.tpu.experimental.initialize_tpu_system(resolver)
  >>> strategy = tf.distribute.experimental.TPUStrategy(resolver)
  >>> data = np.array([
  ...     [0, 0, 0, 0],
  ...     [1, 1, 1, 1],
  ...     [2, 2, 2, 2],
  ...     [3, 3, 3, 3]
  ... ])
  >>> dataset = tf.data.Dataset.from_tensor_slices(data).repeat().batch(4)
  >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
  >>> distributed_values = next(dataset_iterator)
  >>> strategy.run(tf.function(_cross_replica_concat), (distributed_values,))
  PerReplica: {
    0: <tf.Tensor: shape=(4, 4), dtype=int64, numpy=array([
      [0, 0, 0, 0],
      [1, 1, 1, 1],
      [2, 2, 2, 2],
      [3, 3, 3, 3]
    ], dtype=int64)>,
    1: <tf.Tensor: shape=(4, 4), dtype=int64, numpy=array([
      [1, 1, 1, 1],
      [2, 2, 2, 2],
      [3, 3, 3, 3],
      [0, 0, 0, 0]
    ], dtype=int64)>,
    2: <tf.Tensor: shape=(4, 4), dtype=int64, numpy=array([
      [2, 2, 2, 2],
      [3, 3, 3, 3],
      [0, 0, 0, 0],
      [1, 1, 1, 1]
    ], dtype=int64)>,
    3: <tf.Tensor: shape=(4, 4), dtype=int64, numpy=array([
      [3, 3, 3, 3],
      [0, 0, 0, 0],
      [1, 1, 1, 1],
      [2, 2, 2, 2]
    ], dtype=int64)>
  }

  Args:
    values: This replica's contribution. Its first dimension must be
      statically known.

  Returns:
    The gathered tensor, rotated so this replica's own rows come first.

  Raises:
    ValueError: If the first (batch) dimension of `values` is not statically
      known.
  """
  rows_per_replica = values.shape[0]
  if rows_per_replica is None:
    raise ValueError(
        f"Tensor {values} should not be dynamically shaped in the batch "
        "dimension."
    )

  replica_ctx = tf.distribute.get_replica_context()
  all_values = replica_ctx.all_gather(values, axis=0)

  # Shift left by this replica's offset so its own block lands at row 0.
  shift = -replica_ctx.replica_id_in_sync_group * rows_per_replica
  return tf.roll(all_values, shift, axis=0)