Skip to content

Commit

Permalink
Fix mozilla#364; Compare Bidirectional Recurrence vs Unidirectional Recurrence & Row Convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
andi4191 committed Jul 10, 2017
1 parent ed36dc8 commit 7588de8
Showing 1 changed file with 90 additions and 4 deletions.
94 changes: 90 additions & 4 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,88 @@ def variable_on_worker_level(name, shape, initializer):
return var



def UiRNN(batch_x, seq_length, dropout):
    r'''
    UiRNN - unidirectional RNN model, intended as a benchmark variation to
    compare results against the bidirectional ``BiRNN``.

    The graph built here is: three feed-forward layers with clipped RELU
    activation and dropout, a stack of two forward-only LSTM layers, one more
    clipped RELU layer with dropout, and a final linear layer producing the logits.

    Args:
        batch_x: input batch; handled here as a tensor of shape
            ``[batch_size, n_steps, n_input + 2*n_input*n_context]``.
        seq_length: per-sample sequence lengths, forwarded to ``tf.nn.dynamic_rnn``.
        dropout: indexable collection of dropout rates. Entries 0, 1 and 2 are
            used for the first three layers, entry 3 for the LSTM layers and
            entry 5 for the layer after the recurrence.

    Returns:
        The logits as a time-major tensor of shape ``[n_steps, batch_size, n_hidden_6]``.
    '''

    # Number of stacked forward LSTM layers.
    n_lstm_layers = 2

    # Input shape: [batch_size, n_steps, n_input + 2*n_input*n_context]
    batch_x_shape = tf.shape(batch_x)

    # Reshaping `batch_x` to a tensor with shape `[n_steps*batch_size, n_input + 2*n_input*n_context]`.
    # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

    # Permute n_steps and batch_size
    batch_x = tf.transpose(batch_x, [1, 0, 2])
    # Reshape to prepare input for first layer
    batch_x = tf.reshape(batch_x, [-1, n_input + 2*n_input*n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)

    # The next three blocks will pass `batch_x` through three hidden layers with
    # clipped RELU activation and dropout.

    # 1st layer
    b1 = variable_on_worker_level('b1', [n_hidden_1], tf.random_normal_initializer(stddev=FLAGS.b1_stddev))
    h1 = variable_on_worker_level('h1', [n_input + 2*n_input*n_context, n_hidden_1], tf.contrib.layers.xavier_initializer(uniform=False))
    layer_1 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(batch_x, h1), b1)), FLAGS.relu_clip)
    layer_1 = tf.nn.dropout(layer_1, (1.0 - dropout[0]))

    # 2nd layer
    b2 = variable_on_worker_level('b2', [n_hidden_2], tf.random_normal_initializer(stddev=FLAGS.b2_stddev))
    h2 = variable_on_worker_level('h2', [n_hidden_1, n_hidden_2], tf.random_normal_initializer(stddev=FLAGS.h2_stddev))
    layer_2 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_1, h2), b2)), FLAGS.relu_clip)
    layer_2 = tf.nn.dropout(layer_2, (1.0 - dropout[1]))

    # 3rd layer
    b3 = variable_on_worker_level('b3', [n_hidden_3], tf.random_normal_initializer(stddev=FLAGS.b3_stddev))
    h3 = variable_on_worker_level('h3', [n_hidden_2, n_hidden_3], tf.random_normal_initializer(stddev=FLAGS.h3_stddev))
    layer_3 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(layer_2, h3), b3)), FLAGS.relu_clip)
    layer_3 = tf.nn.dropout(layer_3, (1.0 - dropout[2]))

    # Now we create the forward LSTM units, each with `n_cell_dim` units
    # and bias `1.0` for the forget gate of the LSTM.

    # `reuse` is only passed when the installed BasicLSTMCell accepts it
    # (required for TF 1.0 and 1.1 compat).
    lstm_kwargs = {'forget_bias': 1.0, 'state_is_tuple': True}
    if 'reuse' in inspect.getargspec(tf.contrib.rnn.BasicLSTMCell.__init__).args:
        lstm_kwargs['reuse'] = tf.get_variable_scope().reuse

    def make_lstm_layer():
        # Build one dropout-wrapped forward LSTM cell for a single layer of the stack.
        cell = tf.contrib.rnn.BasicLSTMCell(n_cell_dim, **lstm_kwargs)
        return tf.contrib.rnn.DropoutWrapper(cell,
                                             input_keep_prob=1.0 - dropout[3],
                                             output_keep_prob=1.0 - dropout[3],
                                             seed=FLAGS.random_seed)

    # Every layer of the stack gets its own cell object. Putting the very same
    # cell object into the list several times (`[cell]*2`) makes the layers
    # share one cell, which newer TF 1.x releases either reject or turn into
    # unintended weight sharing between the layers.
    lstm_cell_stack = tf.contrib.rnn.MultiRNNCell([make_lstm_layer() for _ in range(n_lstm_layers)],
                                                  state_is_tuple=True)

    # `layer_3` is now reshaped into `[n_steps, batch_size, n_hidden_3]`,
    # as the time-major LSTM RNN expects its input to be of shape `[max_time, batch_size, input_size]`.
    layer_3 = tf.reshape(layer_3, [-1, batch_x_shape[0], n_hidden_3])

    # Now we feed `layer_3` into the LSTM cell stack and obtain the unidirectional RNN output.
    # The final states are not needed.
    outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell_stack,
                                   inputs=layer_3,
                                   dtype=tf.float32,
                                   time_major=True,
                                   sequence_length=seq_length)

    # `outputs` is a single tensor here (there is no backward direction to
    # concatenate), so it is only flattened to rank `2` with `n_cell_dim`
    # columns for the next layer.
    outputs = tf.reshape(outputs, [-1, n_cell_dim])

    # Now we feed `outputs` to the fifth hidden layer with clipped RELU activation and dropout
    b5 = variable_on_worker_level('b5', [n_hidden_5], tf.random_normal_initializer(stddev=FLAGS.b5_stddev))
    h5 = variable_on_worker_level('h5', [n_cell_dim, n_hidden_5], tf.random_normal_initializer(stddev=FLAGS.h5_stddev))
    layer_5 = tf.minimum(tf.nn.relu(tf.add(tf.matmul(outputs, h5), b5)), FLAGS.relu_clip)
    layer_5 = tf.nn.dropout(layer_5, (1.0 - dropout[5]))

    # Now we apply the weight matrix `h6` and bias `b6` to the output of `layer_5`
    # creating the `n_hidden_6` dimensional vectors, the logits.
    b6 = variable_on_worker_level('b6', [n_hidden_6], tf.random_normal_initializer(stddev=FLAGS.b6_stddev))
    h6 = variable_on_worker_level('h6', [n_hidden_5, n_hidden_6], tf.contrib.layers.xavier_initializer(uniform=False))
    layer_6 = tf.add(tf.matmul(layer_5, h6), b6)

    # Finally we reshape layer_6 from a tensor of shape [n_steps*batch_size, n_hidden_6]
    # to the slightly more useful shape [n_steps, batch_size, n_hidden_6].
    # Note, that this differs from the input in that it is time-major.
    layer_6 = tf.reshape(layer_6, [-1, batch_x_shape[0], n_hidden_6])

    # Output shape: [n_steps, batch_size, n_hidden_6]
    return layer_6

def BiRNN(batch_x, seq_length, dropout):
r'''
That done, we will define the learned variables, the weights and biases,
Expand Down Expand Up @@ -466,8 +548,8 @@ def calculate_mean_edit_distance_and_loss(batch_set, dropout):
# Obtain the next batch of data
batch_x, batch_seq_len, batch_y = batch_set.next_batch()

# Calculate the logits of the batch using BiRNN
logits = BiRNN(batch_x, tf.to_int64(batch_seq_len), dropout)
# Calculate the logits of the batch using UiRNN
logits = UiRNN(batch_x, tf.to_int64(batch_seq_len), dropout)

# Compute the CTC loss using either TensorFlow's `ctc_loss` or Baidu's `warp_ctc_loss`.
if FLAGS.use_warpctc:
Expand Down Expand Up @@ -889,6 +971,7 @@ def __init__(self, index, num_jobs, set_name='train', report=False):
self.jobs_running = []
self.jobs_done = []
self.samples = []
self._time = stopwatch()
for i in range(self.num_jobs):
self.jobs_open.append(WorkerJob(self.id, self.index, self.set_name, FLAGS.iters_per_worker, self.report))

Expand Down Expand Up @@ -969,6 +1052,9 @@ def done(self):

self.loss = agg_loss / num_jobs

self._time = stopwatch(self._time)
log_info('job type: %s, inference time: %s ' % (self.set_name, format_duration(self._time)))

# if the job was for validation dataset then append it to the COORD's _loss for early stop verification
if (FLAGS.early_stop is True) and (self.set_name == 'dev'):
COORD._dev_losses.append(self.loss)
Expand Down Expand Up @@ -1607,8 +1693,8 @@ def export():

seq_length = tf.placeholder(tf.int32, [None], name='input_lengths')

# Calculate the logits of the batch using BiRNN
logits = BiRNN(input_tensor, tf.to_int64(seq_length), no_dropout)
# Calculate the logits of the batch using UiRNN
logits = UiRNN(input_tensor, tf.to_int64(seq_length), no_dropout)

# Beam search decode the batch
decoded, _ = tf.nn.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
Expand Down

0 comments on commit 7588de8

Please sign in to comment.