Merge pull request shmsw25#20 from shmsw25/atomic-refactor
Moving abstain detection to a separate module.
shmsw25 authored Jun 25, 2023
2 parents c0e98d2 + 6b24a12 commit 2e608ea
Showing 4 changed files with 79 additions and 70 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -53,7 +53,7 @@ This command does the following.

## Running FActScore using a command line

We expect running FActScore to cost about $1 in OpenAI API charges per 100 sentences. For instance, if you have 100 generations, each with 5 sentences on average, it costs $5 in total.

```bash
python -m factscore.factscorer --input_path {input_path} --model_name {estimator_name} --openai_key {openai_key}
@@ -72,6 +72,7 @@ python -m factscore.factscorer --input_path {input_path} --model_name {estimator
- `--verbose`: If specified, it shows the progress bar.
- `--print_rate_limit_error`: If specified, it prints out rate limit errors from the OpenAI API.
- `--cost_estimate`: This flag decides the type of OpenAI API cost estimation reported before the API is called. It can be `"consider_cache"` (default) or `"ignore_cache"`.
- `--abstain_detection_type`: Optionally enables automatic detection of abstained responses. Disabled by default; we recommend adding your own function tailored to your model. The currently supported detectors are `"generic"` and `"perplexity_ai"`, and their implementations can be found in [`factscore/abstain_detection.py`](factscore/abstain_detection.py). There are two ways to add your own abstain function: a) clone our GitHub repository, install `factscore` locally (`pip install --editable .`), and add your function to [`factscore/abstain_detection.py`](factscore/abstain_detection.py) directly; or b) run your abstain detection outside our package and write empty strings to the `output` key of the JSONL file passed to `--input_path` (see the sketch below).
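
Below is a minimal sketch of option (b): run a detector of your own over the input file before scoring, and blank out the `output` field of abstained generations. The file names and the `my_abstain_detect` heuristic are hypothetical placeholders, not part of this package.

```python
import json

def my_abstain_detect(text):
    # Hypothetical heuristic; replace with logic tailored to your model.
    return text.startswith("I'm sorry") or "cannot answer" in text

with open("generations.jsonl") as fin, open("generations.filtered.jsonl", "w") as fout:
    for line in fin:
        dp = json.loads(line)
        if my_abstain_detect(dp["output"]):
            dp["output"] = ""  # an empty string marks the response as abstained
        fout.write(json.dumps(dp) + "\n")
```

Then pass `generations.filtered.jsonl` to `--input_path` and leave `--abstain_detection_type` unset.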

This command uses the English Wikipedia from 2023/04/01 as a knowledge source. See [this section](#To-use-a-custom-knowledge-source) to use your own database as a knowledge source!

58 changes: 58 additions & 0 deletions factscore/abstain_detection.py
@@ -0,0 +1,58 @@
import numpy as np
import re

# Phrases indicating that perplexity.ai abstained from answering.
invalid_ppl_mentions = [
    "I could not find any information",
    "The search results do not provide",
    "There is no information",
    "There are no search results",
    "there are no provided search results",
    "not provided in the search results",
    "is not mentioned in the provided search results",
    "There seems to be a mistake in the question",
    "Not sources found",
    "No sources found",
    "Try a more general question"
]

def remove_citation(text):
    # Strip inline citation markers such as " [1] ", then repair the
    # "According to , " artifact this can leave behind.
    text = re.sub(r"\s*\[\d+\]\s*", "", text)
    if text.startswith("According to , "):
        text = text.replace("According to , ", "According to the search results, ")
    return text

def is_invalid_ppl(text):
    # The whole response is an abstention if it starts with a known phrase.
    return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions])

def is_invalid_paragraph_ppl(text):
    # A paragraph is invalid if it is empty or mentions a known abstain phrase.
    return len(text.strip()) == 0 or np.any([mention.lower() in text.lower() for mention in invalid_ppl_mentions])

def perplexity_ai_abstain_detect(generation):
    output = remove_citation(generation)
    if is_invalid_ppl(output):
        return True
    valid_paras = []
    # Keep paragraphs up to (but not including) the first invalid one.
    for para in output.split("\n\n"):
        if is_invalid_paragraph_ppl(para):
            break
        valid_paras.append(para.strip())

    # Abstained if no valid paragraph survives.
    if len(valid_paras) == 0:
        return True
    else:
        return False

def generic_abstain_detect(generation):
    return generation.startswith("I'm sorry") or "provide more" in generation

def is_response_abstained(generation, fn_type):
    if fn_type == "perplexity_ai":
        return perplexity_ai_abstain_detect(generation)

    elif fn_type == "generic":
        return generic_abstain_detect(generation)

    else:
        # No (or unknown) detector type: treat the response as not abstained.
        return False
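
A quick usage sketch of the new module (the example string and printed values are illustrative):

```python
from factscore.abstain_detection import is_response_abstained

gen = "I could not find any information about this person."
print(is_response_abstained(gen, "perplexity_ai"))  # True: starts with a known abstain phrase
print(is_response_abstained(gen, "generic"))        # False: no "I'm sorry" / "provide more"
print(is_response_abstained(gen, None))             # False: no detector type disables detection
```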

72 changes: 4 additions & 68 deletions factscore/atomic_facts.py
@@ -18,13 +18,7 @@


class AtomicFactGenerator(object):
    def __init__(self, key_path, demon_dir, model_name=None, gpt3_cache_file=None):

        self.model = model_name
        if model_name:
            self.preprocess_fn = functools.partial(preprocess_fn, model=model_name)
        else:
            self.preprocess_fn = None
    def __init__(self, key_path, demon_dir, gpt3_cache_file=None):
        self.nlp = spacy.load("en_core_web_sm")
        self.is_bio = True
        self.demon_path = os.path.join(demon_dir, "demons.json" if self.is_bio else "demons_complex.json")
@@ -43,11 +37,8 @@ def save_cache(self):

    def run(self, generation, cost_estimate=None):
        """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None."""
        if self.preprocess_fn:
            paragraphs = self.preprocess(generation)
        else:
            paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]

        assert isinstance(generation, str), "generation must be a string"
        paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
        return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate)

    def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None):
@@ -154,35 +145,6 @@ def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
        return atoms


def preprocess_fn(generation, model):
    if model in ["instruct", "gpt4", "vicuna-7b", "vicuna-13b", "chatgpt"]:
        if not generation.startswith("I'm sorry") and not "provide more" in generation:
            paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
        else:
            return None

    elif model == "perplexity":
        output = remove_citation(generation)
        if is_invalid_ppl(output):
            return None
        paragraphs = []
        for para in output.split("\n\n"):
            if is_invalid_paragraph_ppl(para):
                break
            paragraphs.append(para.strip())

        if len(paragraphs) == 0:
            return None

    elif model in ["mpt-7b", "stablelm-alpha-7b"]:
        if not "sorry" in generation and not "provide" in generation.split(" "):
            paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]

    else:
        paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]

    return paragraphs

def best_demos(query, bm25, demons_sents, k):
    tokenized_query = query.split(" ")
    top_machings = bm25.get_top_n(tokenized_query, demons_sents, k)
@@ -333,32 +295,6 @@ def is_integer(s):
    except Exception:
        return False

def remove_citation(text):
    # text = re.sub(r'\[\d+\]', '', text)
    text = re.sub(r"\s*\[\d+\]\s*","", text)
    if text.startswith("According to , "):
        text = text.replace("According to , ", "According to the search results, ")
    return text

invalid_ppl_mentions = [
    "I could not find any information",
    "The search results do not provide",
    "There is no information",
    "There are no search results",
    "there are no provided search results",
    "not provided in the search results",
    "is not mentioned in the provided search results",
    "There seems to be a mistake in the question",
    "Not sources found",
    "Try a more general question"
]

def is_invalid_ppl(text):
    return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions])

def is_invalid_paragraph_ppl(text):
    return len(text.strip())==0 or np.any([mention.lower() in text.lower() for mention in invalid_ppl_mentions])

def detect_initials(text):
    pattern = r"[A-Z]\. ?[A-Z]\."
    match = re.findall(pattern, text)
@@ -399,7 +335,7 @@ def fix_sentence_splitter(curr_sentences, initials):


def main():
    generator = AtomicFactGenerator("api.key", "demos", model_name=None, gpt3_cache_dir=None)
    generator = AtomicFactGenerator("api.key", "demos", gpt3_cache_file=None)
    atomic_facts, para_breaks = generator.run("Thierry Henry (born 17 August 1977) is a French professional football coach, pundit, and former player. He is considered one of the greatest strikers of all time, and one the greatest players of the Premier League history. He has been named Arsenal F.C's greatest ever player.\n\nHenry made his professional debut with Monaco in 1994 before signing for defending Serie A champions Juventus. However, limited playing time, coupled with disagreements with the club's hierarchy, led to him signing for Premier League club Arsenal for £11 million in 1999.")

    print(atomic_facts)
16 changes: 15 additions & 1 deletion factscore/factscorer.py
@@ -6,6 +6,7 @@
import logging

from tqdm import tqdm
from factscore.abstain_detection import is_response_abstained
from factscore.atomic_facts import AtomicFactGenerator
from factscore.clm import CLM
from factscore.npm import NPM
@@ -21,6 +22,7 @@ def __init__(self,
                 cache_dir=".cache/factscore",
                 openai_key="api.key",
                 cost_estimate="consider_cache",
                 abstain_detection_type=None,
                 batch_size=256):
        assert model_name in ["retrieval+llama", "retrieval+llama+npm", "retrieval+ChatGPT", "npm", "retrieval+ChatGPT+npm"]
        self.model_name = model_name
@@ -30,6 +32,7 @@ def __init__(self,
        self.npm = {}
        self.batch_size = batch_size # batch size for retrieval
        self.openai_key = openai_key
        self.abstain_detection_type = abstain_detection_type

        self.data_dir = data_dir
        self.cache_dir = cache_dir
@@ -141,6 +144,12 @@ def get_score(self,

        atomic_facts = []
        for topic, gen in zip(topics, generations):
            # optionally, first detect if the response is abstained
            response_abstained = is_response_abstained(gen, self.abstain_detection_type)
            if response_abstained:
                atomic_facts.append(None)
                continue
            # continue only when the response is not abstained
            curr_afs, _ = self.af_generator.run(gen)
            curr_afs = [fact for _, facts in curr_afs for fact in facts]
            if len(curr_afs)==0:
@@ -271,6 +280,10 @@ def _get_score(self, topic, generation, atomic_facts, knowledge_source, cost_est
                        type=str,
                        default="consider_cache",
                        choices=["consider_cache", "ignore_cache"])
    parser.add_argument('--abstain_detection_type',
                        type=str,
                        default=None,
                        choices=["perplexity_ai", "generic", "none"])
    parser.add_argument('--use_atomic_facts',
                        action="store_true")
    parser.add_argument('--verbose',
@@ -294,7 +307,8 @@ def _get_score(self, topic, generation, atomic_facts, knowledge_source, cost_est
                    model_dir=args.model_dir,
                    cache_dir=args.cache_dir,
                    openai_key=args.openai_key,
                    cost_estimate=args.cost_estimate)
                    cost_estimate=args.cost_estimate,
                    abstain_detection_type=args.abstain_detection_type)

    tot = 0
    topics, generations, atomic_facts = [], [], []
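
For reference, abstain detection can also be enabled from Python rather than the command line. This is a minimal sketch, assuming a valid OpenAI key in `api.key`, the default cache directories, and the `respond_ratio` key that FActScore reports in its output dictionary:

```python
from factscore.factscorer import FactScorer

fs = FactScorer(model_name="retrieval+ChatGPT",
                openai_key="api.key",
                abstain_detection_type="generic")

out = fs.get_score(topics=["Thierry Henry", "Lionel Messi"],
                   generations=["Thierry Henry is a French former footballer who played for Arsenal.",
                                "I'm sorry, I can't answer that question."])
print(out["respond_ratio"])  # 0.5: one of the two responses was abstained
```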
