
[Findings of EMNLP'2024] Unified Active Retrieval for Retrieval Augmented Generation


UAR: Unified Active Retrieval for Retrieval Augmented Generation

This is the official repository for the paper Unified Active Retrieval for Retrieval Augmented Generation.

Abstract

In Retrieval-Augmented Generation (RAG), retrieval is not always helpful and applying it to every instruction is sub-optimal. Therefore, determining whether to retrieve is crucial for RAG, which is usually referred to as Active Retrieval.
We propose Unified Active Retrieval (UAR). UAR contains four orthogonal criteria and casts them into plug-and-play classification tasks, achieving multifaceted retrieval-timing judgements with negligible extra inference cost. We further introduce the Unified Active Retrieval Criteria (UAR-Criteria), designed to handle diverse active retrieval scenarios through a standardized procedure. Experiments on four representative types of user instructions show that UAR significantly outperforms existing work on both retrieval-timing judgement and downstream task performance, demonstrating the effectiveness of UAR and its benefit to downstream tasks.
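Conceptually, the four criteria act as independent binary judgements over an instruction, and retrieval is triggered when any criterion demands it. The sketch below is only an illustration of that idea; the classifier names, the toy rules, and the ordering are hypothetical and do not reflect the paper's exact procedure or the trained models:

```python
# Hypothetical sketch of a UAR-style decision: four binary classifiers
# (intent-, time-, knowledge-, self-aware) are queried in turn, and
# retrieval fires as soon as any criterion says it is needed.

def should_retrieve(instruction, classifiers):
    """Return True if any active-retrieval criterion fires."""
    for name in ("intent_aware", "time_aware", "knowledge_aware", "self_aware"):
        if classifiers[name](instruction):
            return True
    return False

# Toy keyword rules standing in for the trained classifier heads:
toy = {
    "intent_aware": lambda q: "search" in q.lower(),
    "time_aware": lambda q: "latest" in q.lower(),
    "knowledge_aware": lambda q: "who" in q.lower(),
    "self_aware": lambda q: False,
}

print(should_retrieve("Who wrote Hamlet?", toy))    # True
print(should_retrieve("Write a short poem.", toy))  # False
```

In the actual system each criterion is a lightweight classification head on the same backbone, which is why the combined judgement adds negligible inference cost.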


Training

Preparing the Environment

git clone https://github.com/xiami2019/UAR.git
cd UAR
pip install -U pip setuptools
pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .

Prepare the Dataset

unzip training_data.zip
unzip benchmarks.zip
unzip results.zip

Classifier Training

python llama_recipes/finetuning.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --output_dir time_aware_llama2-7b-chat \
    --dataset time_aware_cls_ce \
    --batch_size_training 32 \
    --batching_strategy padding \
    --lr 5e-5 \
    --num_epochs 10 \
    --reward_model_loss_type "ce" \
    --only_cls_for_rmce
python llama_recipes/finetuning.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --output_dir knowledge_aware_llama2-7b-chat \
    --dataset knowledge_aware_cls_ce \
    --batch_size_training 32 \
    --batching_strategy padding \
    --lr 5e-5 \
    --num_epochs 10 \
    --reward_model_loss_type "ce" \
    --only_cls_for_rmce
python llama_recipes/finetuning.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --output_dir self_aware_llama2-7b-chat \
    --dataset self_aware_cls_ce_llama2_7b_chat \
    --batch_size_training 32 \
    --batching_strategy padding \
    --lr 5e-5 \
    --num_epochs 10 \
    --reward_model_loss_type "ce" \
    --only_cls_for_rmce
python llama_recipes/finetuning.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --output_dir intent_aware_llama2-7b-chat \
    --dataset intent_aware_cls_ce \
    --batch_size_training 32 \
    --batching_strategy padding \
    --lr 5e-5 \
    --num_epochs 10 \
    --reward_model_loss_type "ce" \
    --only_cls_for_rmce
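The four training commands above differ only in the dataset name and output directory. As a convenience (not part of the repo), one loop can generate all four commands; pipe the output to `bash` to run them:

```shell
# Sketch: emit the four per-criterion training commands from one loop.
# Run them with: gen_cmds | bash
gen_cmds() {
  for crit in time knowledge self intent; do
    ds="${crit}_aware_cls_ce"
    # the self-aware dataset name carries the backbone suffix
    [ "$crit" = self ] && ds="self_aware_cls_ce_llama2_7b_chat"
    echo python llama_recipes/finetuning.py \
      --model_name meta-llama/Llama-2-7b-chat-hf \
      --output_dir "${crit}_aware_llama2-7b-chat" \
      --dataset "$ds" \
      --batch_size_training 32 --batching_strategy padding \
      --lr 5e-5 --num_epochs 10 \
      --reward_model_loss_type ce --only_cls_for_rmce
  done
}

gen_cmds
```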

Evaluation

Active Retrieval Judgement

AR-Bench Inference

python uar_infer.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --prompt_file benchmarks/AR_bench/ar_bench_llama2-7b-chat.json \
    --save_name results/AR_bench/my_ar_bench_7b_uar_output.json \
    --data_type normal \
    --batch_size 8

GSM8K Inference

python uar_infer.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --prompt_file benchmarks/downstream_tasks/gsm8k_test_with_ret.json \
    --save_name results/downstream_tasks/my_gsm8k_test_llama2_7b_chat_uar.json \
    --data_type gsm8k \
    --batch_size 8

Drop Inference

python uar_infer.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --prompt_file benchmarks/downstream_tasks/drop_dataset_dev_passage_qa_with_ret.json \
    --save_name results/downstream_tasks/my_drop_output_llama2_7b_chat_uar.json \
    --data_type drop \
    --batch_size 8

TriviaQA & WQ & TAQA & FreshQA (the TriviaQA command is shown; substitute the corresponding prompt and save files for the other datasets)

python uar_infer.py \
    --model_name meta-llama/Llama-2-7b-chat-hf \
    --prompt_file benchmarks/downstream_tasks/triviaqa_with_ret.json \
    --save_name results/downstream_tasks/my_triviaqa_llama2_7b_chat_results_uar.json \
    --data_type normal \
    --batch_size 8

Response Generation

GSM8K

python vllm_infer.py \
    --model_path meta-llama/Llama-2-7b-chat-hf \
    --input_file results/downstream_tasks/gsm8k_test_llama2_7b_chat_uar.json \
    --output_file results/downstream_tasks/gsm8k_test_llama2_7b_chat_uar_generation_results.json \
    --data_type gsm8k

Drop

python vllm_infer.py \
    --model_path meta-llama/Llama-2-7b-chat-hf \
    --input_file results/downstream_tasks/drop_output_llama2_7b_chat_uar.json \
    --output_file results/downstream_tasks/drop_output_llama2_7b_chat_uar_generation_results.json \
    --data_type drop

TriviaQA & WQ & TAQA & FreshQA (the TriviaQA command is shown; substitute the corresponding input and output files for the other datasets)

python vllm_infer.py \
    --model_path meta-llama/Llama-2-7b-chat-hf \
    --input_file results/downstream_tasks/triviaqa_llama2_7b_chat_results_uar.json \
    --output_file results/downstream_tasks/triviaqa_llama2_7b_chat_results_uar_generation_results.json \
    --data_type normal

Compute Results

Downstream Tasks

Drop

python evaluations/drop_eval.py \
    --gold_path benchmarks/downstream_tasks/drop_dataset_dev.json \
    --prediction_path results/downstream_tasks/drop_output_llama2_7b_chat_uar.json \
    --output_path results/downstream_tasks/drop_output_llama2_7b_chat_uar_eval_output.json

GSM8K

python evaluations/gsm8k_eval.py \
    --file_name results/downstream_tasks/gsm8k_test_llama2_7b_chat_uar.json

TriviaQA & WQ

python evaluations/em_eval.py \
    --file_name results/downstream_tasks/triviaqa_llama2_7b_chat_results_uar.json
TAQA & FreshQA

Directly compute accuracy on the TAQA dataset using our provided ChatGPT evaluation results:

python evaluations/chatgpt_acc.py \
    --input_file results/downstream_tasks/freshqa_without_false_premise_time_change_llama2_7b_chat_als_ret.json \
    --only_cal_acc

Evaluate using the ChatGPT API: first provide your openai_api_key and base_url in evaluations/api_keys_config.json, then:

python evaluations/chatgpt_acc.py \
    --input_file results/downstream_tasks/freshqa_without_false_premise_time_change_llama2_7b_chat_als_ret.json \
    --output_file results/downstream_tasks/test.json
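A plausible shape for evaluations/api_keys_config.json, with the key names taken from the instructions above and placeholder values (the repo's actual schema may differ):

```json
{
  "openai_api_key": "sk-your-key-here",
  "base_url": "https://api.openai.com/v1"
}
```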

We use GPTWrapper for calling the ChatGPT API. Thanks to Mianqiu~

AR-Bench

python evaluations/cal_ar_acc.py \
    --file_name results/AR_bench/ar_bench_llama2-7b-chat_uar_output.json
