Add a new method: trace

ignorejjj committed Jul 6, 2024
1 parent 2e8e17b commit 7a728a2
Showing 10 changed files with 4,677 additions and 21 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -59,6 +59,9 @@ FlashRAG is still under development and there are many issues and room for impro


## :page_with_curl: Changelog

[24/07/06] We add support for a new method: [<u>Trace</u>](https://arxiv.org/abs/2406.11460), which refines text by constructing a knowledge graph. See its [<u>results</u>](#robot-supporting-methods) and [<u>details</u>](./docs/baseline_details.md).

[24/06/19] We add support for a new method: [<u>IRCoT</u>](https://arxiv.org/abs/2212.10509), and update the [<u>result table</u>](#robot-supporting-methods).

[24/06/15] We provide a [<u>demo</u>](./examples/quick_start/demo_en.py) to perform the RAG process using our toolkit.
@@ -262,7 +265,7 @@ In FlashRAG, we have built a series of common RAG components, including retrieve
<td>Calculate matching score using cross-encoder</td>
</tr>
<tr>
<td rowspan="4">Refiner</td>
<td rowspan="5">Refiner</td>
<td>Extractive Refiner</td>
<td>Refine input by extracting important context</td>
</tr>
@@ -278,6 +281,9 @@ In FlashRAG, we have built a series of common RAG components, including retrieve
<td>SelectiveContext Refiner</td>
<td><a href="https://arxiv.org/abs/2310.06201">Selective-Context</a> prompt compressor</td>
</tr>
<tr>
    <td>KG Refiner</td>
    <td>Use <a href="https://arxiv.org/abs/2406.11460">Trace</a> method to construct a knowledge graph</td>
  </tr>
  <tr>
<td rowspan="4">Generator</td>
<td>Encoder-Decoder Generator</td>
@@ -384,6 +390,7 @@ It’s important to note that, to ensure consistency, we have utilized a uniform
| [RECOMP-abstractive](https://arxiv.org/pdf/2310.04408) | Sequential | 33.1 | 56.4 | 37.5 | 32.4 | 39.9| 20.2| |
| [Selective-Context](https://arxiv.org/abs/2310.06201) | Sequential | 30.5 | 55.6 | 34.4 |18.5| 33.5| 17.3| Compress Ratio=0.5|
| [Ret-Robust](https://arxiv.org/abs/2310.01558) | Sequential | 42.9 | 68.2 | 35.8 |43.4|57.2|33.7| Use LLAMA2-13B with trained lora|
| [Trace](https://arxiv.org/abs/2406.11460) | Sequential | 30.7 | 50.2 | 34.0 | 15.5 | 37.4 | 19.9 | |
| [SuRe](https://arxiv.org/abs/2404.13081) | Branching | 37.1 | 53.2 | 33.4 |20.6|48.1|24.2| Use provided prompt|
| [REPLUG](https://arxiv.org/abs/2301.12652) | Branching | 28.9 | 57.7 | 31.2 |21.1|27.8|20.2| |
| [SKR](https://aclanthology.org/2023.findings-emnlp.691.pdf) | Conditional | 33.2 | 56.0 | 32.4 | 23.4 |31.7|17.0| Use inference-time training data|
2 changes: 2 additions & 0 deletions docs/baseline_details.md
@@ -46,3 +46,5 @@ In addition to the general settings mentioned above, each method often has its o
**SKR**: We implement the SKR-knn method, which requires an encoder model and inference-time training data. Specifically, it identifies the most similar queries from the training data based on the input query, determining whether the input query needs retrieval. Our library includes the training data provided by the authors; the corresponding encoder model can be downloaded [here](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased).
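To make the decision rule concrete, here is a minimal sketch of a kNN-style retrieval judgment (not our actual implementation; the helper names and the majority-vote rule are illustrative), assuming labeled training queries and the SimCSE encoder linked above:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# SimCSE encoder referenced above
tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
encoder = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")

def embed(texts):
    # Encode sentences and use the pooled [CLS] output as embeddings
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        return encoder(**inputs).pooler_output

def needs_retrieval(query, train_queries, train_labels, k=5):
    # Majority vote over the k training queries most similar to the input query
    sims = torch.nn.functional.cosine_similarity(embed([query]), embed(train_queries))
    topk = sims.topk(k).indices
    return sum(train_labels[i] for i in topk) > k / 2
```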

**Self-RAG**: We use the Llama2-7B checkpoint provided by Self-RAG [here](https://huggingface.co/selfrag/selfrag_llama2_7b), setting the max output tokens to 100 to ensure proper operation. The temperature is set to 0, and `top_p` is set to 1.
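In config terms, these settings correspond to roughly the following (a sketch only; the exact key names in our config files may differ):

```python
# Sketch of the Self-RAG generation settings above; key names are assumed
config_dict = {
    'generation_params': {
        'max_tokens': 100,   # max output tokens, needed for proper operation
        'temperature': 0,
        'top_p': 1,
    },
}
```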

**Trace**: This method first extracts triples from the retrieved documents and then constructs reasoning chains from them. Both steps rely on few-shot prompting of an LLM. Following the original work, we use Llama3-8B-instruct for these steps, with 3 exemplars in each prompt. For datasets that lack exemplars, we use the exemplars from 2WikiMultihopQA as a substitute. Other hyperparameters follow the default settings in our code.
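For illustration, the two few-shot steps might look like the following sketch (the `generate` callable stands in for a Llama3-8B-instruct call; the `parse_triples` helper, triple format, and prompt wording are all hypothetical, not our actual prompts):

```python
def parse_triples(text):
    # Hypothetical parser: expects one "(head; relation; tail)" triple per line
    return [tuple(part.strip() for part in line.strip("() ").split(";"))
            for line in text.splitlines() if ";" in line]

def extract_triples(generate, docs, examplars):
    # Step 1: few-shot extraction of (head, relation, tail) triples per document
    prefix = "\n\n".join(examplars) + "\n\nExtract knowledge triples from the text:\n"
    triples = []
    for doc in docs:
        triples.extend(parse_triples(generate(prefix + doc)))
    return triples

def build_chain(generate, question, triples, examplars, max_chain_length=4):
    # Step 2: grow a reasoning chain by repeatedly choosing the next triple
    chain = []
    for _ in range(max_chain_length):
        prompt = ("\n\n".join(examplars)
                  + f"\n\nQuestion: {question}\nChain so far: {chain}\n"
                  + f"Candidate triples: {triples}\nNext triple:")
        chain.append(generate(prompt))
    return chain
```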
2 changes: 1 addition & 1 deletion docs/reproduce_experiment.md
@@ -103,5 +103,5 @@ python run_exp.py --method_name 'naive' \

The method can be selected from the following:
```
-naive zero-shot AAR-contriever llmlingua recomp selective-context sure replug skr flare iterretgen
+naive zero-shot AAR-contriever llmlingua recomp selective-context sure replug skr flare iterretgen trace
```
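For example, to reproduce Trace (the dataset and split below are illustrative; the flags follow the `naive` example above):

```
python run_exp.py --method_name 'trace' \
                  --split 'test' \
                  --dataset_name 'nq' \
                  --gpu_id '0'
```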
30 changes: 29 additions & 1 deletion examples/methods/run_exp.py
@@ -413,6 +413,33 @@ def ircot(args):

result = pipeline.run(test_data)

def trace(args):
save_note = 'trace'
trace_config = {
'num_examplars': 3,
'max_chain_length': 4,
'topk_triple_select': 5, # num of candidate triples
'num_choices': 20,
'min_triple_prob': 1e-4,
        'num_beams': 5, # beam width: number of candidates kept at each step when constructing chains
'num_chains': 20, # number of generated chains
'n_context': 5, # number of used chains in generation
'context_type': 'triples', # triples/triple-doc
}
config_dict = {'save_note': save_note,
'gpu_id':args.gpu_id,
'dataset_name':args.dataset_name,
'trace_config': trace_config}

# preparation
config = Config('my_config.yaml',config_dict)
all_split = get_dataset(config)
test_data = all_split[args.split]
from flashrag.pipeline import SequentialPipeline
pipeline = SequentialPipeline(config)

result = pipeline.run(test_data)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description = "Running exp")
parser.add_argument('--method_name', type=str)
@@ -435,7 +462,8 @@ def ircot(args):
'selfrag': selfrag,
'flare': flare,
'iterretgen': iterretgen,
-    'ircot': ircot
+    'ircot': ircot,
+    'trace': trace,
}

args = parser.parse_args()
2 changes: 1 addition & 1 deletion flashrag/pipeline/pipeline.py
@@ -59,9 +59,9 @@ def __init__(self, config, prompt_template = None):

self.generator = None
if config['refiner_name'] is not None:
-            self.refiner = get_refiner(config)
+            if 'kg' in config['refiner_name'].lower():
+                self.generator = get_generator(config)
+            self.refiner = get_refiner(config, self.retriever, self.generator)
else:
self.refiner = None
self.generator = get_generator(config)
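            # Any refiner whose configured name contains 'kg' (e.g. a KG/Trace
            # refiner) takes the branch above: the generator is built first and
            # shared with the refiner via get_refiner(config, retriever, generator).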
78 changes: 65 additions & 13 deletions flashrag/prompt/base_prompt.py
@@ -7,12 +7,13 @@ class PromptTemplate:
"\nThe following are given documents.\n\n{reference}"
base_user_prompt = "Question: {question}"

-    def __init__(self,
-                 config,
-                 system_prompt = "",
-                 user_prompt = "",
-                 reference_template = None,
-                 enable_chat = True
+    def __init__(
+        self,
+        config,
+        system_prompt = "",
+        user_prompt = "",
+        reference_template = None,
+        enable_chat = True
):

self.config = config
@@ -37,7 +38,7 @@ def __init__(self,
self.enable_chat = enable_chat
self.reference_template = reference_template

-        self._check_placeholder()
+        #self._check_placeholder()

def _check_placeholder(self):
# check placeholder in prompt
@@ -51,12 +52,13 @@ def _check_placeholder(self):
if not flag and holder != 'reference':
assert False

-    def get_string(self,
-                   question,
-                   retrieval_result = None,
-                   formatted_reference = None,
-                   previous_gen = None,
-                   **params
+    def get_string(
+        self,
+        question,
+        retrieval_result = None,
+        formatted_reference = None,
+        previous_gen = None,
+        **params
):

if formatted_reference is None:
@@ -94,6 +96,56 @@ def get_string(self,

return input

    def get_string_with_varying_examplars(
        self,
        question,
        retrieval_result = None,
        formatted_reference = None,
        previous_gen = None,
        examplars = [],
        tokenizer = None,
        max_length = 2048,
        **params
    ):
        """
        Select the maximum number of examplars that can be placed in the prompt
        without exceeding `max_length` tokens.
        """
        final_examplars = None
        while len(examplars) > 0:
            # Try progressively fewer examplars until the encoded prompt fits
            for num in range(len(examplars), 0, -1):
                possible_prompt = self.get_string(
                    question = question,
                    retrieval_result = retrieval_result,
                    formatted_reference = formatted_reference,
                    previous_gen = previous_gen,
                    examplars = "\n\n".join(examplars[:num]),
                    **params
                )

                possible_prompt_tokens = tokenizer.encode(possible_prompt)
                if len(possible_prompt_tokens) <= max_length:
                    final_examplars = examplars[:num]
                    break
            if final_examplars is None:
                # Even a single examplar is too long: drop the first and retry
                examplars = examplars[1:]
            else:
                break
        if final_examplars is None:
            final_examplars = []

        final_prompt = self.get_string(
            question = question,
            retrieval_result = retrieval_result,
            formatted_reference = formatted_reference,
            previous_gen = previous_gen,
            examplars = "\n\n".join(final_examplars),
            **params
        )

        return final_prompt
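    # Illustrative usage (a sketch, not part of this commit), assuming a
    # HuggingFace tokenizer and a prompt template containing an {examplars}
    # placeholder:
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    #   prompt = template.get_string_with_varying_examplars(
    #       question,
    #       formatted_reference=refs,
    #       examplars=examplar_list,   # list of few-shot example strings
    #       tokenizer=tokenizer,
    #       max_length=2048,
    #   )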


def format_reference(self, retrieval_result):
format_reference = ''