
Commit 18a727e

Fix bug in saving retrieval cache
ignorejjj committed Aug 1, 2024
1 parent d0f403e commit 18a727e
Showing 1 changed file with 47 additions and 15 deletions.
62 changes: 47 additions & 15 deletions flashrag/retriever/retriever.py
@@ -33,7 +33,9 @@ def wrapper(self, query_list, num=None, return_score=False):
                 if query in self.cache:
                     cache_res = self.cache[query]
                     if len(cache_res) < num:
-                        warnings.warn(f"The number of cached retrieval results is less than topk ({num})")
+                        warnings.warn(
+                            f"The number of cached retrieval results is less than topk ({num})"
+                        )
                     cache_res = cache_res[:num]
                     # separate the doc score
                     doc_scores = [item.pop("score") for item in cache_res]
@@ -44,27 +44,41 @@ def wrapper(self, query_list, num=None, return_score=False):

             if no_cache_query != []:
                 # use batch search without decorator
-                no_cache_results, no_cache_scores = self._batch_search_with_rerank(no_cache_query, num, True)
+                no_cache_results, no_cache_scores = (
+                    self._batch_search_with_rerank(no_cache_query, num, True)
+                )
                 no_cache_idx = 0
                 for idx, res in enumerate(cache_results):
                     if res is None:
-                        assert new_query_list[idx] == no_cache_query[no_cache_idx]
-                        cache_results[idx] = (no_cache_results[no_cache_idx], no_cache_scores[no_cache_idx])
+                        assert (
+                            new_query_list[idx] == no_cache_query[no_cache_idx]
+                        )
+                        cache_results[idx] = (
+                            no_cache_results[no_cache_idx],
+                            no_cache_scores[no_cache_idx],
+                        )
                         no_cache_idx += 1

-            results, scores = ([t[0] for t in cache_results], [t[1] for t in cache_results])
+            results, scores = (
+                [t[0] for t in cache_results],
+                [t[1] for t in cache_results],
+            )

         else:
             results, scores = func(self, query_list, num, True)

         if self.save_cache:
             # merge result and score
+            save_results = results.copy()
+            save_scores = scores.copy()
             if isinstance(query_list, str):
                 query_list = [query_list]
             if "batch" not in func.__name__:
-                results = [results]
-                scores = [scores]
-            for query, doc_items, doc_scores in zip(query_list, results, scores):
+                save_results = [save_results]
+                save_scores = [save_scores]
+            for query, doc_items, doc_scores in zip(
+                query_list, save_results, save_scores
+            ):
                 for item, score in zip(doc_items, doc_scores):
                     item["score"] = score
                 self.cache[query] = doc_items
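
The hunk above is the core of the fix: the old code reused the `results`/`scores` names when reshaping data for the cache, so a non-batch call returned results nested one level too deep once caching was enabled. Below is a minimal, self-contained sketch of the before/after behavior with made-up documents; everything except the names taken from the diff is illustrative, not FlashRAG code.

    def search_and_cache_old(query, cache):
        # Hypothetical single-query search result (stand-in data).
        results = [{"title": "t", "text": "x"}]
        scores = [0.9]
        # Pre-fix flow: reshape the return values themselves for caching.
        results = [results]
        scores = [scores]
        for q, doc_items, doc_scores in zip([query], results, scores):
            for item, score in zip(doc_items, doc_scores):
                item["score"] = score
            cache[q] = doc_items
        return results, scores  # bug: nested one level too deep

    def search_and_cache_new(query, cache):
        results = [{"title": "t", "text": "x"}]
        scores = [0.9]
        # Post-fix flow: reshape copies, keep the originals for the caller.
        save_results = results.copy()
        save_scores = scores.copy()
        save_results = [save_results]
        save_scores = [save_scores]
        for q, doc_items, doc_scores in zip([query], save_results, save_scores):
            for item, score in zip(doc_items, doc_scores):
                item["score"] = score
            cache[q] = doc_items
        return results, scores  # flat, as the caller expects

    print(type(search_and_cache_old("q", {})[0][0]))  # <class 'list'> (wrong)
    print(type(search_and_cache_new("q", {})[0][0]))  # <class 'dict'> (right)

Note that `list.copy()` is shallow: the cached entries still reference the same per-document dicts, so the merged `score` key also stays visible on the returned documents; what the copy prevents is the cache-side list re-wrapping leaking into the return value.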
@@ -118,7 +134,9 @@ def __init__(self, config):
             self.reranker = get_reranker(config)

         if self.save_cache:
-            self.cache_save_path = os.path.join(config["save_dir"], "retrieval_cache.json")
+            self.cache_save_path = os.path.join(
+                config["save_dir"], "retrieval_cache.json"
+            )
             self.cache = {}
         if self.use_cache:
             assert self.cache_path is not None
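
For orientation, here is a small sketch of how the cache wiring above might be exercised, assuming a plain dict config with a `save_dir` key; the dict literal and the file handling are illustrative stand-ins, not FlashRAG's config API.

    import json
    import os

    config = {"save_dir": "./output"}  # illustrative stand-in for the real config

    cache_save_path = os.path.join(config["save_dir"], "retrieval_cache.json")

    # Mirroring the use_cache branch: a previous run's cache file is expected
    # (the real code asserts that cache_path is not None before loading).
    cache_path = cache_save_path
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cache = json.load(f)
    else:
        cache = {}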
@@ -129,7 +147,9 @@ def _save_cache(self):
         with open(self.cache_save_path, "w") as f:
             json.dump(self.cache, f, indent=4)

-    def _search(self, query: str, num: int, return_score: bool) -> List[Dict[str, str]]:
+    def _search(
+        self, query: str, num: int, return_score: bool
+    ) -> List[Dict[str, str]]:
         r"""Retrieve topk relevant documents in corpus.

         Return:
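
Given the decorator logic (each cached document carries a merged `score` field before `self.cache[query] = doc_items`) and the `json.dump(self.cache, f, indent=4)` call above, the saved cache is a query-keyed JSON object. A hedged sketch of its plausible shape and of re-separating scores on load, mirroring the `item.pop("score")` step; field names other than `score` are illustrative:

    import json

    # Plausible shape of retrieval_cache.json, inferred from the decorator.
    cache = {
        "what is dense retrieval?": [
            {"title": "DPR", "text": "Dense passage retrieval ...", "score": 12.3},
            {"title": "BM25", "text": "A sparse baseline ...", "score": 9.8},
        ]
    }

    query, num = "what is dense retrieval?", 2
    cache_res = cache[query][:num]
    doc_scores = [item.pop("score") for item in cache_res]  # separate the doc score
    print(doc_scores)  # [12.3, 9.8]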
Expand Down Expand Up @@ -181,7 +201,9 @@ def _check_contain_doc(self):
r"""Check if the index contains document content"""
return self.searcher.doc(0).raw() is not None

def _search(self, query: str, num: int = None, return_score=False) -> List[Dict[str, str]]:
def _search(
self, query: str, num: int = None, return_score=False
) -> List[Dict[str, str]]:
if num is None:
num = self.topk
hits = self.searcher.search(query, num)
@@ -198,7 +220,10 @@ def _search(self, query: str, num: int = None, return_score=False) -> List[Dict[str, str]]:
         hits = hits[:num]

         if self.contain_doc:
-            all_contents = [json.loads(self.searcher.doc(hit.docid).raw())["contents"] for hit in hits]
+            all_contents = [
+                json.loads(self.searcher.doc(hit.docid).raw())["contents"]
+                for hit in hits
+            ]
             results = [
                 {
                     "title": content.split("\n")[0].strip('"'),
@@ -276,7 +301,9 @@ def _search(self, query: str, num: int = None, return_score=False):
         else:
             return results

-    def _batch_search(self, query_list: List[str], num: int = None, return_score=False):
+    def _batch_search(
+        self, query_list: List[str], num: int = None, return_score=False
+    ):
         if isinstance(query_list, str):
             query_list = [query_list]
         if num is None:
@@ -287,7 +314,9 @@ def _batch_search(self, query_list: List[str], num: int = None, return_score=False):
         results = []
         scores = []

-        for start_idx in tqdm(range(0, len(query_list), batch_size), desc="Retrieval process: "):
+        for start_idx in tqdm(
+            range(0, len(query_list), batch_size), desc="Retrieval process: "
+        ):
             query_batch = query_list[start_idx : start_idx + batch_size]
             batch_emb = self.encoder.encode(query_batch)
             batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
@@ -296,7 +325,10 @@ def _batch_search(self, query_list: List[str], num: int = None, return_score=False):

             flat_idxs = sum(batch_idxs, [])
             batch_results = load_docs(self.corpus, flat_idxs)
-            batch_results = [batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))]
+            batch_results = [
+                batch_results[i * num : (i + 1) * num]
+                for i in range(len(batch_idxs))
+            ]

             scores.extend(batch_scores)
             results.extend(batch_results)
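
The slicing above relies on `load_docs` returning one flat list of `len(query_batch) * num` documents in query order, which `[i * num : (i + 1) * num]` then re-chunks per query. A toy run of the same regrouping with stand-in ids; `load_docs` itself is replaced by a list comprehension here:

    # 2 queries, num=3 hits each; ids as FAISS might return them.
    num = 3
    batch_idxs = [[4, 7, 1], [9, 0, 2]]
    flat_idxs = sum(batch_idxs, [])                  # [4, 7, 1, 9, 0, 2]
    batch_results = [f"doc{i}" for i in flat_idxs]   # stand-in for load_docs(...)

    batch_results = [
        batch_results[i * num : (i + 1) * num] for i in range(len(batch_idxs))
    ]
    print(batch_results)  # [['doc4', 'doc7', 'doc1'], ['doc9', 'doc0', 'doc2']]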
