Skip to content

Commit

Permalink
Merge pull request #21 from R1j1t/dev
Browse files Browse the repository at this point in the history
- updated README to correctly represent the spacy extension response and minor changes
- Added validation for max_edit_dist param and raise ValueError if it fails
- update vocab loading from static vocab file to load from transformers model
  • Loading branch information
R1j1t authored Sep 6, 2020
2 parents 621abe8 + 7bd844d commit be9c064
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,4 @@ peter's code/
contextualSpellCheck/tests/debugFile.txt

# vs code ignore
.vscode/
.vscode/
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Also, please install the dependencies from requirements.txt

## Usage

**Note:** For other language examples check `examples` folder.
**Note:** For other language examples check [`examples`](https://github.com/R1j1t/contextualSpellCheck/tree/master/examples) folder.

### How to load the package in spacy pipeline

Expand Down Expand Up @@ -118,21 +118,21 @@ To make the usage simpler spacy provides custom extensions which a library can u
| doc._.performed_spellCheck | `Boolean` | To check whether contextualSpellCheck identified any misspells and performed correction | `False` |
| doc._.suggestions_spellCheck | `{Spacy.Token:str}` | if corrections are performed, it returns the mapping of misspell token (`spaCy.Token`) with suggested word(`str`) | `{}` |
| doc._.outcome_spellCheck | `str` | corrected sentence(`str`) as output | `""` |
| doc._.score_spellCheck | `{Spacy.Token:List(str,float)}` | if corrections are performed, it returns the mapping of misspell token (`spaCy.Token`) with suggested words(`str`) and probability of that correction | `None` |
| doc._.score_spellCheck | `{Spacy.Token:List(str,float)}` | if corrections are identified, it returns the mapping of misspell token (`spaCy.Token`) with suggested words(`str`) and probability of that correction | `None` |

### `spaCy.Span` level extensions
| Extension | Type | Description | Default |
|-------------------------------|---------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------|
| span._.get_has_spellCheck | `Boolean` | To check whether contextualSpellCheck identified any misspells and performed correction in this span | `False` |
| span._.score_spellCheck | `{Spacy.Token:List(str,float)}` | if corrections are performed, it returns the mapping of misspell token (`spaCy.Token`) with suggested words(`str`) and probability of that correction for tokens in this `span` | `{spaCy.Token: []}` |
| span._.score_spellCheck | `{Spacy.Token:List(str,float)}` | if corrections are identified, it returns the mapping of misspell token (`spaCy.Token`) with suggested words(`str`) and probability of that correction for tokens in this `span` | `{spaCy.Token: []}` |

### `spaCy.Token` level extensions

| Extension | Type | Description | Default |
|-----------------------------------|-----------------|-------------------------------------------------------------------------------------------------------------|---------|
| token._.get_require_spellCheck | `Boolean` | To check whether contextualSpellCheck identified any misspells and performed correction on this `token` | `False` |
| token._.get_suggestion_spellCheck | `str` | if corrections are performed, it returns the suggested word(`str`) | `""` |
| token._.score_spellCheck | `[(str,float)]` | if corrections are performed, it returns suggested words(`str`) and probability(`float`) of that correction | `[]` |
| token._.score_spellCheck | `[(str,float)]` | if corrections are identified, it returns suggested words(`str`) and probability(`float`) of that correction | `[]` |

## API

Expand Down Expand Up @@ -189,6 +189,8 @@ Response:
- [ ] better candidate generation (maybe by fine tuning the model?)
- [ ] add metric by testing on datasets
- [ ] Improve documentation
- [ ] Add examples for other languages
- [ ] use piece wise tokeniser when identifying the misspell

## Support and contribution

Expand Down
28 changes: 14 additions & 14 deletions contextualSpellCheck/contextualSpellCheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def __init__(
"Please check datatype provided. vocab_path should be str,"
" debug and performance should be bool"
)
try:
int(float(max_edit_dist))
except ValueError as identifier:
raise ValueError(
f"cannot convert {max_edit_dist} to int. Please provide a valid integer"
)

if vocab_path != "":
try:
Expand Down Expand Up @@ -101,23 +107,14 @@ def __init__(
vocab_path = ""
words = []

if vocab_path == "":
current_path = os.path.dirname(__file__)
vocab_path = os.path.join(current_path, "data/vocab.txt")
with open(vocab_path, encoding="utf8") as f:
# if want to remove '[unusedXX]' from vocab
# words = [
# line.rstrip()
# for line in f
# if not line.startswith("[unused")
# ]
words = [line.strip() for line in f]

self.max_edit_dist = max_edit_dist
self.max_edit_dist = int(float(max_edit_dist))
self.model_name = model_name
self.BertTokenizer = AutoTokenizer.from_pretrained(self.model_name)

if vocab_path == "":
words = list(self.BertTokenizer.get_vocab().keys())
self.vocab = Vocab(strings=words)
logging.getLogger("transformers").setLevel(logging.ERROR)
self.BertTokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.BertModel = AutoModelForMaskedLM.from_pretrained(self.model_name)
self.mask = self.BertTokenizer.mask_token
self.debug = debug
Expand Down Expand Up @@ -234,6 +231,9 @@ def misspell_identify(self, doc, query=""):
`tuple`: returns `List[`Spacy.Token`]` and `Spacy.Doc`
"""

# deep copy is required to preserve individual token info
# from objects in pipeline which can modify token info
# like merge_entities
docCopy = copy.deepcopy(doc)

misspell = []
Expand Down
56 changes: 55 additions & 1 deletion contextualSpellCheck/tests/test_contextualSpellCheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,21 @@ def test_warning():
e
== "Please check datatype provided. vocab_path should be str, debug and performance should be bool"
)
max_edit_distance = "non_int_or_float"
with pytest.raises(ValueError) as e:
ContextualSpellCheck(max_edit_dist=max_edit_distance)
assert (
e
== f"cannot convert {max_edit_distance} to int. Please provide a valid integer"
)

try:
ContextualSpellCheck(max_edit_dist="3.1")
except Exception as uncatched_error:
pytest.fail(str(uncatched_error))

def test_vocabFile():

def test_vocab_file():
with warnings.catch_warnings(record=True) as w:
ContextualSpellCheck(vocab_path="testing.txt")
assert any([issubclass(i.category, UserWarning) for i in w])
Expand All @@ -577,3 +589,45 @@ def test_bert_model_name():
with pytest.raises(OSError) as e:
ContextualSpellCheck(model_name=model_name)
assert e == error_message


def test_correct_model_name():
    """A valid transformers model name should load without raising.

    Uses a non-default (Finnish BERT) model to confirm that
    ContextualSpellCheck accepts arbitrary hub model names.
    """
    model_name = "TurkuNLP/bert-base-finnish-cased-v1"
    try:
        ContextualSpellCheck(model_name=model_name)
    except OSError:
        # Fix: corrected typo "Specificed" -> "Specified" in the failure message.
        pytest.fail("Specified model is not present in transformers")
    except Exception as uncaught_error:
        # Any other exception type is also a test failure; surface its message.
        pytest.fail(str(uncaught_error))


@pytest.mark.parametrize(
    "max_edit_distance,expected_spell_check_flag",
    [(0, False), (1, False), (2, True), (3, True)],
)
def test_max_edit_dist(max_edit_distance, expected_spell_check_flag):
    """Corrections should only fire when max_edit_dist admits the candidate.

    'milion' -> 'million' is edit distance 2, so the checker must stay
    silent for max_edit_dist in {0, 1} and correct for {2, 3}.
    """
    # Start from a clean pipeline so each parametrized run is independent.
    pipe_name = "contextual spellchecker"
    if pipe_name in nlp.pipe_names:
        nlp.remove_pipe(pipe_name)
    nlp.add_pipe(ContextualSpellCheck(max_edit_dist=max_edit_distance))

    doc = nlp(
        "Income was $9.4 milion compared to the prior year of $2.7 milion."
    )

    # The spell-check flag must agree at token, span, and doc level.
    observed_flags = (
        doc[4]._.get_require_spellCheck,
        doc[3:5]._.get_has_spellCheck,
        doc._.performed_spellCheck,
    )
    for flag in observed_flags:
        assert flag == expected_spell_check_flag

    # Suggestion extensions hold content only when a correction happened;
    # otherwise they fall back to empty strings.
    if expected_spell_check_flag:
        expected_sentence = (
            "Income was $9.4 million compared to the prior year of $2.7 million."
        )
        expected_token = "million"
    else:
        expected_sentence = ""
        expected_token = ""
    print("gold_outcome:", expected_sentence, "gold_token:", expected_token)
    assert doc[4]._.get_suggestion_spellCheck == expected_token
    assert doc._.outcome_spellCheck == expected_sentence

    nlp.remove_pipe(pipe_name)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="contextualSpellCheck",
version="0.2.1",
version="0.3.0",
author="R1j1t",
author_email="[email protected]",
description="Contextual spell correction using BERT (bidirectional representations)",
Expand Down

0 comments on commit be9c064

Please sign in to comment.