-
Notifications
You must be signed in to change notification settings - Fork 4
/
model_factory.py
80 lines (65 loc) · 2.55 KB
/
model_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import codecs
import json
import os
import pickle
from collections import OrderedDict
from pprint import pprint
# Folder layout: everything is resolved relative to the directory that
# contains this module.
DIR_ROOT = os.path.dirname(__file__)
DIR_MODELS = os.path.join(DIR_ROOT, 'models/')

# English resources: model folder and JSON config file.
DIR_MODELS_EN = os.path.join(DIR_MODELS, 'english/')
CFG_ENGLISH = os.path.join(DIR_ROOT, 'configs/english.json')

# Italian resources: model folder and JSON config file.
DIR_MODELS_IT = os.path.join(DIR_MODELS, 'italian/')
CFG_ITALIAN = os.path.join(DIR_ROOT, 'configs/italian.json')

# Keys looked up in a parsed language config.
CFG_LANG = 'lang'
CFG_FEATURES = 'features'
CFG_MODEL = 'model'
# One config section per classification target.
CFG_PBR = 'problem_report'
CFG_INQ = 'inquiry'
CFG_IRR = 'irrelevant'

# Module-level slots consulted by ModelFactory.create before it builds a
# language model; both start out empty.
LANGUAGE_MODEL_EN = None
LANGUAGE_MODEL_IT = None
def get_models_dir(lang):
    """Return the model folder for a language code.

    ``'en'`` maps to ``DIR_MODELS_EN`` and ``'it'`` to ``DIR_MODELS_IT``;
    any other value yields ``None``.
    """
    if lang == 'en':
        folder = DIR_MODELS_EN
    elif lang == 'it':
        folder = DIR_MODELS_IT
    else:
        # Unsupported (or missing) language code.
        folder = None
    return folder
class ModelFactory:
    """Factory that hands out the per-language ``LanguageModel`` instances."""

    @staticmethod
    def _load_config(cfg_path):
        """Parse the JSON file at ``cfg_path`` into an ``OrderedDict``."""
        utf8_reader = codecs.getreader("utf-8")
        # Read raw bytes and decode them explicitly as UTF-8 so parsing does
        # not depend on the platform's default text encoding; ``with`` makes
        # sure the handle is closed even when the JSON is malformed.
        with open(cfg_path, 'rb') as cfg_file:
            return json.load(utf8_reader(cfg_file), object_pairs_hook=OrderedDict)

    @staticmethod
    def create(lang=None):
        """Return the language model for ``lang``.

        Args:
            lang: language code, ``'en'`` or ``'it'``.

        Returns:
            The ``LanguageModel`` for the requested language, or ``None``
            when ``lang`` is ``None`` or not one of the supported codes.

        The model is built on the first request for a language and stored
        in the module-level ``LANGUAGE_MODEL_EN`` / ``LANGUAGE_MODEL_IT``
        slot, so later calls return the same instance instead of re-reading
        the config and classifier files.
        """
        global LANGUAGE_MODEL_EN, LANGUAGE_MODEL_IT

        if lang == 'en':
            if LANGUAGE_MODEL_EN is None:
                cfg = ModelFactory._load_config(CFG_ENGLISH)
                # Assigned only after a successful load, so a failed attempt
                # is retried on the next call.
                LANGUAGE_MODEL_EN = LanguageModel(DIR_MODELS_EN, cfg)
            return LANGUAGE_MODEL_EN

        if lang == 'it':
            if LANGUAGE_MODEL_IT is None:
                cfg = ModelFactory._load_config(CFG_ITALIAN)
                LANGUAGE_MODEL_IT = LanguageModel(DIR_MODELS_IT, cfg)
            return LANGUAGE_MODEL_IT

        # No language given, or a language without a configured model.
        return None
class LanguageModel:
    """Configs and classifiers for the three targets of one language.

    When built from a config, an instance exposes:
      * ``lang`` - the value stored under ``CFG_LANG``.
      * ``cfg_pbr`` / ``cfg_inq`` / ``cfg_irr`` - ``OrderedDict`` copies of
        the ``CFG_PBR`` / ``CFG_INQ`` / ``CFG_IRR`` config sections.
      * ``clf_pbr`` / ``clf_inq`` / ``clf_irr`` - the objects unpickled from
        the file each section names under ``CFG_MODEL``.
      * ``are_feature_vectors_identical`` - whether all three sections have
        equal ``CFG_FEATURES`` entries.
    """

    def __init__(self, model_folder, cfg):
        """Load the per-target configs and pickled classifiers.

        Args:
            model_folder: directory that contains the pickled classifier
                files referenced by the config.
            cfg: parsed language config (a mapping). If ``None``, nothing is
                loaded and the instance is left without any of the
                attributes listed on the class.
        """
        if cfg is None:
            # Existing behaviour, kept as is: a missing config produces an
            # empty instance rather than an error.
            return

        self.lang = cfg[CFG_LANG]

        self.cfg_pbr = OrderedDict(cfg[CFG_PBR])
        self.cfg_inq = OrderedDict(cfg[CFG_INQ])
        self.cfg_irr = OrderedDict(cfg[CFG_IRR])

        self.clf_pbr = self._load_classifier(model_folder, cfg[CFG_PBR])
        self.clf_inq = self._load_classifier(model_folder, cfg[CFG_INQ])
        self.clf_irr = self._load_classifier(model_folder, cfg[CFG_IRR])

        self.are_feature_vectors_identical = self.__are_features_identical_acrosstargets()

    @staticmethod
    def _load_classifier(model_folder, target_cfg):
        """Unpickle the file named by ``target_cfg[CFG_MODEL]`` in ``model_folder``."""
        model_path = os.path.join(model_folder, target_cfg[CFG_MODEL])
        # NOTE: unpickling can execute arbitrary code, so the model files
        # must come from a trusted location.
        # ``with`` closes the file handle as soon as loading finishes or fails.
        with open(model_path, 'rb') as model_file:
            return pickle.load(model_file)

    def __are_features_identical_acrosstargets(self):
        """Return True if every target uses an equal feature configuration.

        Compares the ``CFG_FEATURES`` entries of the problem-report, inquiry
        and irrelevant configs. When they are all equal (same features, same
        order, same settings), the feature vector only needs to be extracted
        once for all targets.
        """
        return self.cfg_pbr[CFG_FEATURES] == self.cfg_inq[CFG_FEATURES] == self.cfg_irr[CFG_FEATURES]