Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(encoder):
Browse files Browse the repository at this point in the history
fix bug caused by batching in inception_mixture
  • Loading branch information
Larryjianfeng committed Aug 6, 2019
1 parent 4edb8ee commit fd37e6d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions gnes/encoder/video/incep_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,21 @@ def encode(self, data: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
in img]

@batching(concat_axis=None)
def _encode1(_, data):
def _encode1(self, data):
_, end_points_ = self.sess.run((self.logits, self.end_points),
feed_dict={self.inputs: data})
return end_points_[self.select_layer]

v = [_ for vi in _encode1(None, img) for _ in vi]
v = [_ for vi in _encode1(self, img) for _ in vi]
print(np.array(v).shape)

v_input = [v[s:e] for s, e in zip(pos_start, pos_end)]
v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input]
v_input = [np.array(vi, dtype=np.float32) for vi in v_input]

@batching
def _encode2(_, data):
def _encode2(self, data):
return self.sess2.run(self.mix_model.repre,
feed_dict={self.mix_model.feeds: data})

return _encode2(None, v_input).astype(np.float32)
return _encode2(self, v_input).astype(np.float32)

0 comments on commit fd37e6d

Please sign in to comment.