Skip to content

Commit

Permalink
refactor bert parsing, add sentence embedder function
Browse files Browse the repository at this point in the history
  • Loading branch information
rknaebel committed Jul 23, 2021
1 parent d715216 commit 397110c
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 73 deletions.
8 changes: 3 additions & 5 deletions cli/bert/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import click
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModel

from discopy.components.nn.bert import get_sentence_embeddings
from discopy.parsers.pipeline import ParserPipeline
from discopy.utils import init_logger
from discopy_data.data.doc import Document
from discopy_data.nn.bert import get_sentence_embedder


@click.command()
Expand All @@ -19,10 +18,9 @@
def main(bert_model, model_path, src, tgt, limit):
logger = init_logger()
logger.info('Init Parser...')
get_sentence_embeddings = get_sentence_embedder(bert_model)
parser = ParserPipeline.from_config(model_path)
parser.load(model_path)
tokenizer = AutoTokenizer.from_pretrained(bert_model)
model = TFAutoModel.from_pretrained(bert_model)
logger.info('Load pre-trained Parser...')
for line_i, line in tqdm(enumerate(src)):
if limit and line_i >= limit:
Expand All @@ -32,7 +30,7 @@ def main(bert_model, model_path, src, tgt, limit):
continue
for sent_i, sent in enumerate(doc.sentences):
sent_words = sent.tokens
embeddings = get_sentence_embeddings(sent_words, tokenizer, model)
embeddings = get_sentence_embeddings(sent_words)
doc.sentences[sent_i].embeddings = embeddings
doc = parser(doc)
tgt.write(json.dumps(doc.to_json()) + '\n')
Expand Down
63 changes: 0 additions & 63 deletions discopy/components/nn/bert.py

This file was deleted.

3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,5 @@ sklearn
sklearn-crfsuite
tensorflow>=2.1.0
transformers==4.2.1
fastapi==0.61.2
uvicorn==0.11.3
fastapi==0.67.0
git+git://github.com:rknaebel/discopy-data
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
long_description = fh.read()

setup(name='discopy-rknaebel',
version='1.0.0',
version='1.0.1',
description='Shallow Discourse Parser',
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -21,8 +21,7 @@
'sklearn-crfsuite',
'tensorflow>=2.1.0',
'transformers>=3.5.0',
'fastapi==0.61.2',
'uvicorn==0.11.3',
'fastapi==0.67.0',
'discopy-data-rknaebel',
],
zip_safe=False,
Expand All @@ -31,6 +30,9 @@
'discopy-train=cli.train:main',
'discopy-eval=cli.eval:main',
'discopy-parse=cli.parse:main',
'discopy-nn-train=cli.bert.train:main',
'discopy-nn-parse=cli.bert.parse:main',
'discopy-nn-predict=cli.bert.predict:main',
],
},
python_requires='>=3.7',
Expand Down

0 comments on commit 397110c

Please sign in to comment.