Skip to content

Commit

Permalink
feat: check saved model params and fix lowercase for interact
Browse files Browse the repository at this point in the history
* fix: lowercase text while interact

* feat: check saved model params

* fix: rm extra configs
  • Loading branch information
vikmary authored and seliverstov committed Mar 15, 2018
1 parent 284c2d9 commit c7a131c
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 24 deletions.
2 changes: 1 addition & 1 deletion deeppavlov/configs/go_bot/gobot_dstc2.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
"save_path": "go_bot/model",
"learning_rate": 0.002,
"dropout_rate": 0.8,
"hidden_dim": 128,
"hidden_size": 128,
"dense_size": 64,
"obs_size": 530,
"action_size": 45
Expand Down
2 changes: 1 addition & 1 deletion deeppavlov/configs/go_bot/gobot_dstc2_minimal.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"save_path": "go_bot_minimal/model",
"learning_rate": 0.01,
"dropout_rate": 0.9,
"hidden_dim": 128,
"hidden_size": 128,
"dense_size": 64,
"obs_size": 502,
"action_size": 45
Expand Down
2 changes: 1 addition & 1 deletion deeppavlov/skills/go_bot/go_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, template_path,

def _encode_context(self, context, db_result=None):
# tokenize input
tokenized = ' '.join(self.tokenizer([context])[0]).strip()
tokenized = ' '.join(self.tokenizer([context])[0]).lower().strip()
if self.debug:
log.debug("Text tokens = `{}`".format(tokenized))

Expand Down
69 changes: 48 additions & 21 deletions deeppavlov/skills/go_bot/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer
Expand All @@ -27,11 +28,10 @@

@register('go_bot_rnn')
class GoalOrientedBotNetwork(TFModel):
GRAPH_PARAMS = ["hidden_size", "action_size", "dense_size", "obs_size"]
def __init__(self, **params):
self.opt = params

# initialize parameters
self._init_params()
self._init_params(params)
# build computational graph
self._build_graph()
# initialize session
Expand Down Expand Up @@ -71,15 +71,17 @@ def train_on_batch(self, x: list, y: list):
action = y
self._train_step(features, action, action_mask)

def _init_params(self, params=None):
params = params or self.opt
self.learning_rate = params['learning_rate']
self.dropout_rate = params.get('dropout_rate', 1.)
self.n_hidden = params['hidden_dim']
self.n_actions = params['action_size']
#TODO: try obs_size=None or as a placeholder
self.obs_size = params['obs_size']
self.dense_size = params.get('dense_size', params['hidden_dim'])
def _init_params(self, params):
    """Read hyperparameters from the config dict and expose them as attributes.

    Args:
        params: configuration dictionary; must contain 'learning_rate',
            'hidden_size', 'action_size' and 'obs_size'. 'dropout_rate'
            defaults to 1. (no dropout) and 'dense_size' defaults to
            'hidden_size'.

    Raises:
        KeyError: if a required hyperparameter is missing.
    """
    # Keep a private copy so that filling in the defaults below does not
    # mutate the caller's dictionary.
    self.opt = dict(params)
    self.opt.setdefault('dropout_rate', 1.)
    self.opt.setdefault('dense_size', self.opt['hidden_size'])

    # self.opt is what gets serialized by save_params(), so the attributes
    # below are plain mirrors of it for convenient access.
    self.learning_rate = self.opt['learning_rate']
    self.dropout_rate = self.opt['dropout_rate']
    self.hidden_size = self.opt['hidden_size']
    self.action_size = self.opt['action_size']
    self.obs_size = self.opt['obs_size']
    self.dense_size = self.opt['dense_size']

def _build_graph(self):

Expand All @@ -105,18 +107,18 @@ def _add_placeholders(self):
# TODO: make batch_size != 1
self._dropout = tf.placeholder_with_default(1.0, shape=[])
_initial_state_c = \
tf.placeholder_with_default(np.zeros([1, self.n_hidden], np.float32),
shape=[1, self.n_hidden])
tf.placeholder_with_default(np.zeros([1, self.hidden_size], np.float32),
shape=[1, self.hidden_size])
_initial_state_h = \
tf.placeholder_with_default(np.zeros([1, self.n_hidden], np.float32),
shape=[1, self.n_hidden])
tf.placeholder_with_default(np.zeros([1, self.hidden_size], np.float32),
shape=[1, self.hidden_size])
self._initial_state = tf.nn.rnn_cell.LSTMStateTuple(_initial_state_c,
_initial_state_h)
self._features = tf.placeholder(tf.float32, [1, None, self.obs_size],
name='features')
self._action = tf.placeholder(tf.int32, [1, None],
name='ground_truth_action')
self._action_mask = tf.placeholder(tf.float32, [None, None, self.n_actions],
self._action_mask = tf.placeholder(tf.float32, [None, None, self.action_size],
name='action_mask')

def _build_body(self):
Expand All @@ -127,22 +129,47 @@ def _build_body(self):
kernel_initializer=xavier_initializer())

# recurrent network unit
_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.n_hidden)
_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.hidden_size)
_output, _state = tf.nn.dynamic_rnn(_lstm_cell,
_units,
initial_state=self._initial_state)

# output projection
# TODO: try multiplying logits to action_mask
_logits = tf.layers.dense(_output,
self.n_actions,
self.action_size,
kernel_initializer=xavier_initializer())
return _logits, _state

def load(self, *args, **kwargs):
    """Restore model weights, then validate stored hyperparameters.

    Order matters: the TF checkpoint is restored first, then
    load_params() checks that the graph-defining parameters saved next
    to the checkpoint match the current configuration (and raises if not).
    """
    super().load(*args, **kwargs)
    self.load_params()

def save(self, *args, **kwargs):
    """Persist model weights, then write hyperparameters alongside them.

    save_params() stores self.opt as a JSON file next to the checkpoint
    so that load() can later verify graph compatibility.
    """
    super().save(*args, **kwargs)
    self.save_params()

def save_params(self):
    """Serialize self.opt to a JSON file stored next to the checkpoint.

    The file path is `self.save_path` with its suffix replaced by '.json'.
    """
    path = str(self.save_path.with_suffix('.json').resolve())
    log.info('[saving parameters to {}]'.format(path))
    # Explicit encoding: default locale encoding may not handle
    # non-ASCII values in the config on some platforms.
    with open(path, 'w', encoding='utf-8') as fp:
        json.dump(self.opt, fp)

def load_params(self):
    """Load hyperparameters saved next to the checkpoint and verify them.

    Every parameter listed in GRAPH_PARAMS (those that fix the shape of
    the computational graph) must be identical in the saved file and in
    the current configuration, otherwise the restored weights would not
    fit the built graph.

    Raises:
        ValueError: if any graph-defining parameter differs from (or is
            missing in) the saved model's parameters.
    """
    path = str(self.load_path.with_suffix('.json').resolve())
    log.info('[loading parameters from {}]'.format(path))
    with open(path, 'r', encoding='utf-8') as fp:
        params = json.load(fp)
    for p in self.GRAPH_PARAMS:
        # .get() instead of [] so a missing key produces the informative
        # ValueError below rather than an opaque KeyError.
        if self.opt.get(p) != params.get(p):
            raise ValueError("`{}` parameter must be equal to "
                             "saved model parameter value `{}`, "
                             "but is `{}`"
                             .format(p, params.get(p), self.opt.get(p)))

def reset_state(self):
    """Reset the LSTM state to zeros (e.g. at the start of a new dialog).

    The pasted diff left both the removed `self.n_hidden` lines and the
    added `self.hidden_size` lines in this span; only the new version is
    kept — the old lines reference an attribute that no longer exists.
    """
    # Cell state and hidden state of the LSTM; batch size is fixed to 1.
    self.state_c = np.zeros([1, self.hidden_size], dtype=np.float32)
    self.state_h = np.zeros([1, self.hidden_size], dtype=np.float32)

def _train_step(self, features, action, action_mask):
_, loss_value, prediction = \
Expand Down

0 comments on commit c7a131c

Please sign in to comment.