Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(base): support loading external modules from py and yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Jul 19, 2019
1 parent 07ce8e9 commit 66f26e0
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
3 changes: 3 additions & 0 deletions tests/contrib/dummy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
!FooContribEncoder
parameter:
bar: 531
12 changes: 12 additions & 0 deletions tests/contrib/dummy_contrib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from gnes.encoder.base import BaseTextEncoder


class FooContribEncoder(BaseTextEncoder):

def __init__(self, bar: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_trained = True
self.bar = bar

def encode(self, text, **kwargs):
return 'hello %d' % self.bar
4 changes: 2 additions & 2 deletions tests/test_contrib_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ class TestYaml(unittest.TestCase):

def setUp(self):
dirname = os.path.dirname(__file__)
module_path = os.path.join(dirname, 'contrib', 'toy22.py')
module_path = os.path.join(dirname, 'contrib', 'dummy_contrib.py')
cls_name = 'FooContribEncoder'
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, module_path)
self.yaml_path = os.path.join(os.path.dirname(__file__),
'contrib', 'toy22.yml')
'contrib', 'dummy.yml')

def test_load_contrib(self):
from gnes.encoder.base import BaseEncoder, BaseTextEncoder
Expand Down
4 changes: 2 additions & 2 deletions tests/test_contrib_module_negative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ class TestYaml(unittest.TestCase):

def setUp(self):
dirname = os.path.dirname(__file__)
module_path = os.path.join(dirname, 'contrib', 'toy22.py')
module_path = os.path.join(dirname, 'contrib', 'dummy_contrib.py')
cls_name = 'FooContribEncoder'
os.environ['GNES_CONTRIB_MODULE'] = '%s:%s' % (cls_name, module_path)
self.yaml_path = os.path.join(os.path.dirname(__file__),
'contrib', 'toy22.yml')
'contrib', 'dummy.yml')

def test_broken_contrib(self):
os.environ['GNES_CONTRIB_MODULE'] = ''
Expand Down

0 comments on commit 66f26e0

Please sign in to comment.