This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

feat(helper): batching decorator supports tuple
hanhxiao committed Aug 26, 2019
commit ce0e65a · 1 parent 928574c
Showing 2 changed files with 26 additions and 9 deletions.

gnes/encoder/text/transformer.py (5 additions, 4 deletions)

```diff
@@ -17,9 +17,9 @@
 from typing import List, Tuple
 
 import torch
 from pytorch_transformers import *
 
-from gnes.encoder.base import BaseTextEncoder
-
+from ..base import BaseTextEncoder
+from ...helper import batching
 
 class PyTorchTransformers(BaseTextEncoder):
@@ -58,16 +58,17 @@ def load_model_tokenizer(x):
             self.logger.warning('cannot deserialize model/tokenizer from %s, will download from web' % self.work_dir)
         self.model, self.tokenizer = load_model_tokenizer(pretrained_weights)
 
+    @batching
     def encode(self, text: List[str], *args, **kwargs) -> Tuple:
         # encoding and padding
         ids = [self.tokenizer.encode(t) for t in text]
         max_len = max(len(t) for t in ids)
         ids = [t + [0] * (max_len - len(t)) for t in ids]
         m_ids = [[1] * len(t) + [0] * (max_len - len(t)) for t in ids]
         seq_ids = torch.tensor(ids)
-        mask_ids = torch.tensor(m_ids)
+        mask_ids = torch.tensor(m_ids, dtype=torch.float32)
 
-        if self.use_cuda:
+        if self.on_gpu:
             seq_ids = seq_ids.cuda()
 
         with torch.no_grad():
```
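For reference, a minimal self-contained sketch of the padding-and-masking step inside `encode` (`pad_and_mask` is a hypothetical helper name, and the token ids are invented for illustration; note the mask is built from the original, pre-padding sequence lengths so that padded positions actually receive 0):

```python
from typing import List, Tuple

import torch


def pad_and_mask(ids: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    # pad every sequence with zeros up to the longest one in the batch
    max_len = max(len(t) for t in ids)
    # 1 marks a real token, 0 marks padding; computed before ids is padded
    m_ids = [[1] * len(t) + [0] * (max_len - len(t)) for t in ids]
    ids = [t + [0] * (max_len - len(t)) for t in ids]
    return torch.tensor(ids), torch.tensor(m_ids, dtype=torch.float32)


seq_ids, mask_ids = pad_and_mask([[101, 7592, 102], [101, 102]])
# seq_ids  -> [[101, 7592, 102], [101, 102, 0]]
# mask_ids -> [[1., 1., 1.], [1., 1., 0.]]
```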

gnes/helper.py (21 additions, 5 deletions)

```diff
@@ -403,11 +403,27 @@ def arg_wrapper(self, data, label=None, *args, **kwargs):
                 if r is not None:
                     final_result.append(r)
 
-            if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray):
-                final_result = np.concatenate(final_result, concat_axis)
-
-            if chunk_dim != -1:
-                final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))
+            if len(final_result) == 1:
+                # the only result of one batch
+                return final_result[0]
+
+            if len(final_result) and concat_axis is not None:
+                if isinstance(final_result[0], np.ndarray):
+                    final_result = np.concatenate(final_result, concat_axis)
+                    if chunk_dim != -1:
+                        final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1]))
+                elif isinstance(final_result[0], tuple):
+                    reduced_result = []
+                    num_cols = len(final_result[0])
+                    for col in range(num_cols):
+                        reduced_result.append(np.concatenate([row[col] for row in final_result], concat_axis))
+                    if chunk_dim != -1:
+                        for col in range(num_cols):
+                            reduced_result[col] = reduced_result[col].reshape(
+                                (-1, chunk_dim, reduced_result[col].shape[1]))
+                    final_result = tuple(reduced_result)
+                else:
+                    raise TypeError('dont know how to reduce %s' % type(final_result[0]))
 
             if len(final_result):
                 return final_result
```
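With this change, a function wrapped by `@batching` may return a tuple of arrays per mini-batch (as the `Tuple` return annotation on `encode` above suggests), and the decorator reduces the results column by column instead of raising. A minimal standalone sketch of just that tuple-reduction path, leaving aside the single-batch shortcut and the `chunk_dim` reshape (`reduce_batches` is a hypothetical name, not part of `gnes.helper`; the names `final_result` and `concat_axis` mirror the diff):

```python
from typing import List, Tuple

import numpy as np


def reduce_batches(final_result: List[Tuple[np.ndarray, ...]],
                   concat_axis: int = 0) -> Tuple[np.ndarray, ...]:
    # concatenate column-wise: the c-th elements of every per-batch tuple
    # are stacked into one array, preserving the tuple structure
    num_cols = len(final_result[0])
    return tuple(
        np.concatenate([row[col] for row in final_result], concat_axis)
        for col in range(num_cols))


# two mini-batches, each yielding an (embeddings, mask) pair
b1 = (np.zeros((2, 768)), np.ones((2, 5)))
b2 = (np.zeros((3, 768)), np.ones((3, 5)))
emb, mask = reduce_batches([b1, b2])
print(emb.shape, mask.shape)  # (5, 768) (5, 5)
```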