Commit 785a901
Format code files for readability
ignorejjj committed Jul 30, 2024
1 parent 95d9fcc commit 785a901
Showing 38 changed files with 2,280 additions and 2,463 deletions.
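Across these files the changes follow one pattern: single-quoted string literals become double-quoted, dictionaries and argument lists gain trailing commas, and long calls are rewrapped one argument per line. The commit does not name a tool, but this matches the default style of the black formatter. A minimal sketch of reproducing the style, assuming black:

```python
# Minimal sketch, assuming black was used; the commit does not say which
# formatter produced these changes, the quote/comma/wrapping edits merely
# match black's defaults.
import black

src = "config = Config('my_config.yaml', config_dict = config_dict)"
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # config = Config("my_config.yaml", config_dict=config_dict)
```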
392 changes: 206 additions & 186 deletions examples/methods/run_exp.py

Large diffs are not rendered by default.

40 changes: 17 additions & 23 deletions examples/methods/utils/get_lm_probs_dataset.py
```diff
@@ -40,15 +40,9 @@ def run(self, dataset):
                 input_prompt = self.prompt_template.get_string(question=q, retrieval_result=[res])
                 score = self.calculate_prob(input_prompt, answer)
                 scores.append(score)
-                docs.append(res['contents'])
+                docs.append(res["contents"])
             scores = torch.softmax(torch.tensor(scores), dim=-1).tolist()
-            data_ls.append(
-                {
-                    "query": q,
-                    "pos": docs,
-                    "scores": scores
-                }
-            )
+            data_ls.append({"query": q, "pos": docs, "scores": scores})
         return data_ls
@@ -57,31 +51,31 @@ def calculate_prob(self, prompt, answer):


 def main(
-    dataset_name='nq',  # qa dataset
-    split='test',  # split
-    num=4000,  # number of query-document pairs
-    gpu_id='0',
-    output="lmsft.jsonl",  # output path
-    topk=20,  # number of retrieved documents
+    dataset_name="nq",  # qa dataset
+    split="test",  # split
+    num=4000,  # number of query-document pairs
+    gpu_id="0",
+    output="lmsft.jsonl",  # output path
+    topk=20,  # number of retrieved documents
 ):
     config_dict = {
-        'save_note': "replug_lsr",
-        'gpu_id': gpu_id,
-        'dataset_name': dataset_name,
-        'test_sample_num': num,
-        'split': ['test', 'dev'],
+        "save_note": "replug_lsr",
+        "gpu_id": gpu_id,
+        "dataset_name": dataset_name,
+        "test_sample_num": num,
+        "split": ["test", "dev"],
         "retrieval_topk": topk,
-        "framework": "hf"
+        "framework": "hf",
     }
-    config = Config('my_config.yaml', config_dict)
+    config = Config("my_config.yaml", config_dict)
     all_split = get_dataset(config)
     test_data = all_split[split]
     lm_prob_calculator = LMProbCalculator(config)
     data_ls = lm_prob_calculator.run(test_data)
-    with open(output, 'w') as f:
+    with open(output, "w") as f:
         for data in data_ls:
             f.write(json.dumps(data, ensure_ascii=False) + "\n")


-if __name__ == '__main__':
+if __name__ == "__main__":
     fire.Fire(main)
```
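Since the script ends in `fire.Fire(main)`, every keyword parameter of `main` doubles as a command-line flag. A usage sketch; the import path is an assumption based on the file location:

```python
# Hypothetical programmatic use; the module path is assumed from where the
# file lives in the repo. fire.Fire(main) exposes the same parameters on the
# command line, e.g.:
#   python get_lm_probs_dataset.py --dataset_name=nq --split=test --topk=20
from examples.methods.utils.get_lm_probs_dataset import main

main(dataset_name="nq", split="test", num=4000, gpu_id="0", output="lmsft.jsonl", topk=20)
```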
57 changes: 33 additions & 24 deletions examples/quick_start/demo_en.py
```diff
@@ -4,27 +4,32 @@
 from flashrag.prompt import PromptTemplate


-config_dict = {"save_note":"demo",
-    'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
-    "retrieval_method":"e5",
-    'generator_model': 'llama3-8B-instruct',
-    "corpus_path":"indexes/general_knowledge.jsonl",
-    "index_path":"indexes/e5_Flat.index"}
+config_dict = {
+    "save_note": "demo",
+    "model2path": {"e5": "intfloat/e5-base-v2", "llama3-8B-instruct": "meta-llama/Meta-Llama-3-8B-Instruct"},
+    "retrieval_method": "e5",
+    "generator_model": "llama3-8B-instruct",
+    "corpus_path": "indexes/general_knowledge.jsonl",
+    "index_path": "indexes/e5_Flat.index",
+}


 @st.cache_resource
 def load_retriever(_config):
     return get_retriever(_config)


 @st.cache_resource
 def load_generator(_config):
     return get_generator(_config)


 custom_theme = {
     "primaryColor": "#ff6347",
     "backgroundColor": "#f0f0f0",
     "secondaryBackgroundColor": "#d3d3d3",
     "textColor": "#121212",
-    "font": "sans serif"
+    "font": "sans serif",
 }
 st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")
@@ -40,15 +45,19 @@ def load_generator(_config):
 query = st.text_area("Enter your prompt:")

-config = Config('my_config.yaml', config_dict=config_dict)
+config = Config("my_config.yaml", config_dict=config_dict)
 retriever = load_retriever(config)
 generator = load_generator(config)

-system_prompt_rag = "You are a friendly AI Assistant." \
-    "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable." \
-    "\nThe following are provided references. You can use them for answering question.\n\n{reference}"
-system_prompt_no_rag = "You are a friendly AI Assistant." \
-    "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable.\n"
+system_prompt_rag = (
+    "You are a friendly AI Assistant."
+    "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable."
+    "\nThe following are provided references. You can use them for answering question.\n\n{reference}"
+)
+system_prompt_no_rag = (
+    "You are a friendly AI Assistant."
+    "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable.\n"
+)
 base_user_prompt = "{question}"

 prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
@@ -57,27 +66,27 @@ def load_generator(_config):
 if st.button("Generate Responses"):
     with st.spinner("Retrieving and Generating..."):
-        retrieved_docs = retriever.search(query,num=topk)
+        retrieved_docs = retriever.search(query, num=topk)

-        st.subheader("References",divider='gray')
+        st.subheader("References", divider="gray")
         for i, doc in enumerate(retrieved_docs):
-            doc_title = doc.get('title','No Title')
-            doc_text = "\n".join(doc['contents'].split("\n")[1:])
+            doc_title = doc.get("title", "No Title")
+            doc_text = "\n".join(doc["contents"].split("\n")[1:])
             expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
             with expander:
                 st.markdown(doc_text, unsafe_allow_html=True)

-        st.subheader("Generated Responses:",divider='gray')
+        st.subheader("Generated Responses:", divider="gray")

         input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
-        response_with_rag = generator.generate(input_prompt_with_rag,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens)[0]
+        response_with_rag = generator.generate(
+            input_prompt_with_rag, temperature=temperature, max_new_tokens=max_new_tokens
+        )[0]
         st.subheader("Response with RAG:")
         st.write(response_with_rag)
         input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
-        response_without_rag = generator.generate(input_prompt_without_rag,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens)[0]
+        response_without_rag = generator.generate(
+            input_prompt_without_rag, temperature=temperature, max_new_tokens=max_new_tokens
+        )[0]
         st.subheader("Response without RAG:")
         st.markdown(response_without_rag)
```
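A note on the pattern above: the demo is launched with `streamlit run demo_en.py` rather than `python`, and `st.cache_resource` keys its cache on the function's arguments. The leading underscore in `_config` tells Streamlit to skip hashing that argument, which matters because a `Config` object is generally not hashable. A minimal sketch of the same pattern with a hypothetical loader:

```python
# Minimal sketch of the demo's caching pattern (hypothetical loader, not a
# FlashRAG API). st.cache_resource builds the object once and reuses it across
# Streamlit reruns; the underscore prefix excludes `_config` from hashing.
import streamlit as st


@st.cache_resource
def load_resource(_config, name: str = "default"):
    # Expensive construction runs only on the first call per `name`.
    return {"name": name, "config": _config}
```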
56 changes: 32 additions & 24 deletions examples/quick_start/demo_zh.py
```diff
@@ -3,27 +3,32 @@
 from flashrag.utils import get_retriever, get_generator
 from flashrag.prompt import PromptTemplate

-config_dict = {"save_note":"demo",
-    "generator_model":"qwen-14B",
-    "retrieval_method":"bge-zh",
-    "model2path": {"bge-zh":"BAAI/bge-large-zh-v1.5", "qwen-14B":"Qwen/Qwen1.5-14B-Chat"},
-    "corpus_path":"/data00/jiajie_jin/rd_corpus.jsonl",
-    "index_path":"/data00/jiajie_jin/flashrag_indexes/rd_corpus/bge_Flat.index"}
+config_dict = {
+    "save_note": "demo",
+    "generator_model": "qwen-14B",
+    "retrieval_method": "bge-zh",
+    "model2path": {"bge-zh": "BAAI/bge-large-zh-v1.5", "qwen-14B": "Qwen/Qwen1.5-14B-Chat"},
+    "corpus_path": "/data00/jiajie_jin/rd_corpus.jsonl",
+    "index_path": "/data00/jiajie_jin/flashrag_indexes/rd_corpus/bge_Flat.index",
+}


 @st.cache_resource
 def load_retriever(_config):
     return get_retriever(_config)


 @st.cache_resource
 def load_generator(_config):
     return get_generator(_config)


 custom_theme = {
     "primaryColor": "#ff6347",
     "backgroundColor": "#f0f0f0",
     "secondaryBackgroundColor": "#d3d3d3",
     "textColor": "#121212",
-    "font": "sans serif"
+    "font": "sans serif",
 }
 st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")
@@ -40,15 +45,18 @@ def load_generator(_config):
 query = st.text_area("Enter your prompt:")

-config = Config('my_config.yaml', config_dict=config_dict)
+config = Config("my_config.yaml", config_dict=config_dict)
 retriever = load_retriever(config)
 generator = load_generator(config)

-system_prompt_rag = "你是一个友好的人工智能助手。" \
-    "请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。" \
-    "\n下面是一些可供参考的文档,你可以使用它们来回答问题。\n\n{reference}"
-system_prompt_no_rag = "你是一个友好的人工智能助手。" \
-    "请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。\n"
+system_prompt_rag = (
+    "你是一个友好的人工智能助手。"
+    "请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。"
+    "\n下面是一些可供参考的文档,你可以使用它们来回答问题。\n\n{reference}"
+)
+system_prompt_no_rag = (
+    "你是一个友好的人工智能助手。" "请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。\n"
+)
 base_user_prompt = "{question}"

 prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
@@ -57,27 +65,27 @@ def load_generator(_config):
 if st.button("Generate Responses"):
     with st.spinner("Retrieving and Generating..."):
-        retrieved_docs = retriever.search(query,num=topk)
+        retrieved_docs = retriever.search(query, num=topk)

-        st.subheader("References",divider='gray')
+        st.subheader("References", divider="gray")
         for i, doc in enumerate(retrieved_docs):
-            doc_title = doc.get('title','No Title')
-            doc_text = "\n".join(doc['contents'].split("\n")[1:])
+            doc_title = doc.get("title", "No Title")
+            doc_text = "\n".join(doc["contents"].split("\n")[1:])
             expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
             with expander:
                 st.markdown(doc_text, unsafe_allow_html=True)

-        st.subheader("Generated Responses:",divider='gray')
+        st.subheader("Generated Responses:", divider="gray")

         input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
-        response_with_rag = generator.generate(input_prompt_with_rag,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens)[0]
+        response_with_rag = generator.generate(
+            input_prompt_with_rag, temperature=temperature, max_new_tokens=max_new_tokens
+        )[0]
         st.subheader("Response with RAG:")
         st.write(response_with_rag)
         input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
-        response_without_rag = generator.generate(input_prompt_without_rag,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens)[0]
+        response_without_rag = generator.generate(
+            input_prompt_without_rag, temperature=temperature, max_new_tokens=max_new_tokens
+        )[0]
         st.subheader("Response without RAG:")
         st.markdown(response_without_rag)
```
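demo_zh.py mirrors demo_en.py with Chinese models and prompts, but its `corpus_path` and `index_path` are absolute paths on the author's machine (`/data00/jiajie_jin/...`), so a local run needs to point them elsewhere. For example (the paths below are placeholders, not repo files):

```python
# Hypothetical local override; both paths are placeholders for your own
# corpus JSONL and FAISS index.
config_dict.update(
    {
        "corpus_path": "indexes/my_zh_corpus.jsonl",
        "index_path": "indexes/my_bge_Flat.index",
    }
)
```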
36 changes: 18 additions & 18 deletions examples/quick_start/simple_pipeline.py
```diff
@@ -5,35 +5,35 @@
 from flashrag.prompt import PromptTemplate

 parser = argparse.ArgumentParser()
-parser.add_argument('--model_path', type=str)
-parser.add_argument('--retriever_path', type=str)
+parser.add_argument("--model_path", type=str)
+parser.add_argument("--retriever_path", type=str)
 args = parser.parse_args()

 config_dict = {
-    'data_dir': 'dataset/',
-    'index_path': 'indexes/e5_Flat.index',
-    'corpus_path': 'indexes/general_knowledge.jsonl',
-    'model2path': {'e5': args.retriever_path, 'llama3-8B-instruct': args.model_path},
-    'generator_model': 'llama3-8B-instruct',
-    'retrieval_method': 'e5',
-    'metrics': ['em','f1','acc'],
-    'retrieval_topk': 1,
-    'save_intermediate_data': True
-}
+    "data_dir": "dataset/",
+    "index_path": "indexes/e5_Flat.index",
+    "corpus_path": "indexes/general_knowledge.jsonl",
+    "model2path": {"e5": args.retriever_path, "llama3-8B-instruct": args.model_path},
+    "generator_model": "llama3-8B-instruct",
+    "retrieval_method": "e5",
+    "metrics": ["em", "f1", "acc"],
+    "retrieval_topk": 1,
+    "save_intermediate_data": True,
+}

-config = Config(config_dict = config_dict)
+config = Config(config_dict=config_dict)

 all_split = get_dataset(config)
-test_data = all_split['test']
+test_data = all_split["test"]
 prompt_templete = PromptTemplate(
     config,
-    system_prompt = "Answer the question based on the given document. \
+    system_prompt="Answer the question based on the given document. \
 Only give me the answer and do not output any other words. \
 \nThe following are given documents.\n\n{reference}",
-    user_prompt = "Question: {question}\nAnswer:"
+    user_prompt="Question: {question}\nAnswer:",
 )
 pipeline = SequentialPipeline(config, prompt_template=prompt_templete)

-output_dataset = pipeline.run(test_data,do_eval=True)
+output_dataset = pipeline.run(test_data, do_eval=True)
 print("---generation output---")
-print(output_dataset.pred)
+print(output_dataset.pred)
```
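The script takes the generator and retriever locations via argparse, so a run supplies both flags. A sketch; the two model identifiers below are borrowed from the demo config above and are only examples:

```python
# Hypothetical invocation of simple_pipeline.py; argparse maps these flags to
# args.model_path and args.retriever_path. The identifiers are examples taken
# from the demo config, not requirements.
import subprocess

subprocess.run(
    [
        "python", "simple_pipeline.py",
        "--model_path", "meta-llama/Meta-Llama-3-8B-Instruct",
        "--retriever_path", "intfloat/e5-base-v2",
    ],
    check=True,
)
```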
0 comments on commit 785a901
