From ce0e65aebcc4f779972fb59593542826467d27e5 Mon Sep 17 00:00:00 2001 From: hanhxiao Date: Mon, 26 Aug 2019 16:51:39 +0800 Subject: [PATCH] feat(helper): batching decorator supports tuple --- gnes/encoder/text/transformer.py | 9 +++++---- gnes/helper.py | 26 +++++++++++++++++++++----- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/gnes/encoder/text/transformer.py b/gnes/encoder/text/transformer.py index 0317e32e..512fc7d2 100644 --- a/gnes/encoder/text/transformer.py +++ b/gnes/encoder/text/transformer.py @@ -17,9 +17,9 @@ from typing import List, Tuple import torch -from pytorch_transformers import * -from gnes.encoder.base import BaseTextEncoder +from ..base import BaseTextEncoder +from ...helper import batching class PyTorchTransformers(BaseTextEncoder): @@ -58,6 +58,7 @@ def load_model_tokenizer(x): self.logger.warning('cannot deserialize model/tokenizer from %s, will download from web' % self.work_dir) self.model, self.tokenizer = load_model_tokenizer(pretrained_weights) + @batching def encode(self, text: List[str], *args, **kwargs) -> Tuple: # encoding and padding ids = [self.tokenizer.encode(t) for t in text] @@ -65,9 +66,9 @@ def encode(self, text: List[str], *args, **kwargs) -> Tuple: ids = [t + [0] * (max_len - len(t)) for t in ids] m_ids = [[1] * len(t) + [0] * (max_len - len(t)) for t in ids] seq_ids = torch.tensor(ids) - mask_ids = torch.tensor(m_ids) + mask_ids = torch.tensor(m_ids, dtype=torch.float32) - if self.use_cuda: + if self.on_gpu: seq_ids = seq_ids.cuda() with torch.no_grad(): diff --git a/gnes/helper.py b/gnes/helper.py index aa2f0691..55f71cfa 100644 --- a/gnes/helper.py +++ b/gnes/helper.py @@ -403,11 +403,27 @@ def arg_wrapper(self, data, label=None, *args, **kwargs): if r is not None: final_result.append(r) - if len(final_result) and concat_axis is not None and isinstance(final_result[0], np.ndarray): - final_result = np.concatenate(final_result, concat_axis) - - if chunk_dim != -1: - final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1])) + if len(final_result) == 1: + # the only result of one batch + return final_result[0] + + if len(final_result) and concat_axis is not None: + if isinstance(final_result[0], np.ndarray): + final_result = np.concatenate(final_result, concat_axis) + if chunk_dim != -1: + final_result = final_result.reshape((-1, chunk_dim, final_result.shape[1])) + elif isinstance(final_result[0], tuple): + reduced_result = [] + num_cols = len(final_result[0]) + for col in range(num_cols): + reduced_result.append(np.concatenate([row[col] for row in final_result], concat_axis)) + if chunk_dim != -1: + for col in range(num_cols): + reduced_result[col] = reduced_result[col].reshape( + (-1, chunk_dim, reduced_result[col].shape[1])) + final_result = tuple(reduced_result) + else: + raise TypeError('dont know how to reduce %s' % type(final_result[0])) if len(final_result): return final_result