From 0fae3067aa56921f8fce57732b698aa98aaa1d08 Mon Sep 17 00:00:00 2001 From: zche4846 <59648772+zche4846@users.noreply.github.com> Date: Thu, 9 Jun 2022 14:55:55 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20No.77=E3=80=91=E5=9F=BA?= =?UTF-8?q?=E4=BA=8EERNIE=5FDOC=E5=AE=8C=E6=88=90=E9=95=BF=E6=96=87?= =?UTF-8?q?=E6=9C=AC=E5=88=86=E7=B1=BB=E4=BB=BB=E5=8A=A1=20(#1845)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add ernie_doc example * add README file * bug fix * bug fix * space to tabs * trim trailing space * trim trailing space * trim trailing space * trim trailing space * end of file line * end of file line * end of file line * yapf * yapf * yapf * yapf * change copyright info * delete redundant code & modify copyright info & compatibility improvement. * editing docstrings & rewriting memories initiation method for static mode. * code style fix * editing docstrings for variable 'memories' Co-authored-by: yingyibiao Co-authored-by: Jack Zhou --- .../text_classification/ernie_doc/README.md | 86 ++ .../text_classification/ernie_doc/__init__.py | 0 .../text_classification/ernie_doc/data.py | 1248 +++++++++++++++++ .../ernie_doc/export_model.py | 57 + .../text_classification/ernie_doc/metrics.py | 314 +++++ .../text_classification/ernie_doc/modeling.py | 987 +++++++++++++ .../text_classification/ernie_doc/predict.py | 301 ++++ .../text_classification/ernie_doc/train.py | 346 +++++ 8 files changed, 3339 insertions(+) create mode 100644 examples/text_classification/ernie_doc/README.md create mode 100644 examples/text_classification/ernie_doc/__init__.py create mode 100644 examples/text_classification/ernie_doc/data.py create mode 100644 examples/text_classification/ernie_doc/export_model.py create mode 100644 examples/text_classification/ernie_doc/metrics.py create mode 100644 examples/text_classification/ernie_doc/modeling.py create mode 100644 examples/text_classification/ernie_doc/predict.py create mode 100644 examples/text_classification/ernie_doc/train.py diff --git a/examples/text_classification/ernie_doc/README.md b/examples/text_classification/ernie_doc/README.md new file mode 100644 index 000000000000..9a93b5e7580c --- /dev/null +++ b/examples/text_classification/ernie_doc/README.md @@ -0,0 +1,86 @@ +# Ernie_doc 在iflytek数据集上的使用 + +## 简介 + +本示例将使用ERNIE-DOC模型,演示如何在长文本数据集上(e.g. iflytek)完成分类任务的训练,预测以及动转静过程。以下是本例的简要目录结构及说明: + +```shell +. 
+├── LICENSE +├── README.md #文档 +├── data.py #数据处理 +├── export_model.py #将动态图参数导出成静态图参数 +├── metrics.py #ERNIE-Doc下游任务指标 +├── modeling.py #ERNIE-Doc模型实现(针对实现静态图修改) +├── predict.py #分类任务预测脚本(包括动态图预测和动转静) +└── train.py #分类任务训练脚本(包括数据下载,模型导出和测试集结果导出) +``` + +## 快速开始 + +### 通用参数释义 + +除[ERNIE_DOC](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/examples/language_model/ernie-doc/run_classifier.py) +展示的通用参数之外,本例还有如下参数: + +- `static_mode` 在 `predict.py` 表示是否使用静态图进行预测。 +- `test_results_file` 在`train.py`和`predict.py`中表示测试集预测结果所存储的地址,默认为`./test_restuls.json`。 +- `static_path` 在`export_model.py`和`predict.py`中表示要将转化完成的静态图存储的地址,如果改地址已经有静态图模型参数,`predict.py` + 会直接读取该模型参数,而`export_model.py`会覆盖掉该模型参数。默认路径为`{HOME}/.paddlenlp/static/inference`。 + +### 分类任务训练 + +iflytek的数据示例如下: + +```shell +{"label": "110", "label_des": "社区超市", "sentence": "朴朴快送超市创立于2016年,专注于打造移动端30分钟即时配送一站式购物平台,商品品类包含水果、蔬菜、肉禽蛋奶、海鲜水产、粮油调味、酒水饮料、休闲食品、日用品、外卖等。朴朴公司希望能以全新的商业模式,更高效快捷的仓储配送模式,致力于成为更快、更好、更多、更省的在线零售平台,带给消费者更好的消费体验,同时推动中国食品安全进程,成为一家让社会尊敬的互联网公司。,朴朴一下,又好又快,1.配送时间提示更加清晰友好2.保障用户隐私的一些优化3.其他提高使用体验的调整4.修复了一些已知bug"} +``` + +该数据集共有1.7万多条关于app应用描述的长文本标注数据,包含和日常生活相关的各类应用主题,共119个类别。 使用训练脚本 + +```shell +python train.py --batch_size 16 \ + --model_name_or_path ernie-doc-base-zh \ + --epoch 5 \ + --output_dir ./checkpoints/ +``` + +根据通用参数释义可自行更改训练超参数和模型保存地址。 + +### 模型导出和预测 + +可以使用模型导出脚本将动态图模型转化成静态图: + +```shell +python export_model.py --batch_size 16 \ + --model_name_or_path finetuned_model \ + --max_seq_lenght 512 \ + --memory_length 128 \ + --static_path ./my_static_model/ +``` + +也可以直接使用预测脚本将`static_mode`设为True (设置成False则使用动态图预测),直接完成转化静态图和使用静态图预测的步骤: + +```shell +python predict.py --static_mode True \ + --dataset iflytek \ + --batch_size 16 \ + --model_name_or_path finetuned_model \ + --max_seq_lenght 512 \ + --memory_length 128 \ + --static_path ./my_static_model/ \ + --test_results_file ./test_results.json +``` + +模型输出的`test_results_file`示例: + +```shell +{"id": "2590", "label": "70"} +{"id": "2591", "label": "91"} +{"id": "2592", "label": "20"} +{"id": "2593", "label": "28"} +{"id": "2594", "label": "95"} +{"id": "2595", "label": "116"} +{"id": "2596", "label": "59"} +{"id": "2597", "label": "22"} +``` diff --git a/examples/text_classification/ernie_doc/__init__.py b/examples/text_classification/ernie_doc/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/text_classification/ernie_doc/data.py b/examples/text_classification/ernie_doc/data.py new file mode 100644 index 000000000000..5b54f048adcc --- /dev/null +++ b/examples/text_classification/ernie_doc/data.py @@ -0,0 +1,1248 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
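+# Usage sketch (illustrative only): the batch iterators defined in this module are
+# normally driven from train.py / predict.py. Names such as `train_ds` are assumed
+# to be supplied by the caller, e.g. via paddlenlp.datasets.load_dataset.
+#
+#   from paddlenlp.transformers import ErnieDocTokenizer
+#   tokenizer = ErnieDocTokenizer.from_pretrained("ernie-doc-base-zh")
+#   train_iter = ClassifierIterator(train_ds, batch_size=16, tokenizer=tokenizer,
+#                                   trainer_num=1, trainer_id=0,
+#                                   max_seq_length=512, memory_len=128,
+#                                   mode="train")
+#   for batch in train_iter():
+#       # each batch packs: token ids, position ids, task ids, input mask,
+#       # labels, qids, gather_idx, need_cal_loss
+#       ...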
+import itertools +import json +from collections import namedtuple +import numpy as np +from paddle.utils import try_import +from paddlenlp.transformers import tokenize_chinese_chars +from paddlenlp.utils.log import logger + + +def get_related_pos(insts, seq_len, memory_len=128): + """generate relative postion ids""" + beg = seq_len + seq_len + memory_len + r_position = [list(range(beg - 1, seq_len - 1, -1)) + \ + list(range(0, seq_len)) for i in range(len(insts))] + return np.array(r_position).astype('int64').reshape([len(insts), beg, 1]) + + +def pad_batch_data(insts, + insts_data_type="int64", + pad_idx=0, + final_cls=False, + pad_max_len=None, + return_pos=False, + return_input_mask=False, + return_max_len=False, + return_num_token=False, + return_seq_lens=False): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + return_list = [] + if pad_max_len: + max_len = pad_max_len + else: + max_len = max(len(inst) for inst in insts) + # Any token included in dict can be used to pad, since the paddings' loss + # will be masked out by weights and make no effect on parameter gradients. + + # Input id + if final_cls: + inst_data = np.array([ + inst[:-1] + list([pad_idx] * (max_len - len(inst))) + [inst[-1]] + for inst in insts + ]) + else: + inst_data = np.array( + [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts]) + return_list += [inst_data.astype(insts_data_type).reshape([-1, max_len, 1])] + + # Position id + if return_pos: + inst_pos = np.array([ + list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst)) + for inst in insts + ]) + + return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])] + + if return_input_mask: + # This is used to avoid attention on paddings. + if final_cls: + input_mask_data = np.array([[1] * len(inst[:-1]) + [0] * + (max_len - len(inst)) + [1] + for inst in insts]) + else: + input_mask_data = np.array([[1] * len(inst) + [0] * + (max_len - len(inst)) + for inst in insts]) + input_mask_data = np.expand_dims(input_mask_data, axis=-1) + return_list += [input_mask_data.astype("float32")] + + if return_max_len: + return_list += [max_len] + + if return_num_token: + num_token = 0 + for inst in insts: + num_token += len(inst) + return_list += [num_token] + + if return_seq_lens: + seq_lens_type = [-1] + seq_lens = np.array([len(inst) for inst in insts]) + return_list += [seq_lens.astype("int64").reshape(seq_lens_type)] + + return return_list if len(return_list) > 1 else return_list[0] + + +class TextPreprocessor(object): + def __call__(self, text): + raise NotImplementedError("TextPreprocessor object can't be called") + + +class ImdbTextPreprocessor(TextPreprocessor): + def __call__(self, text): + text = text.strip().replace('
<br />
', ' ') + text = text.replace('\t', '') + return text + + +class HYPTextPreprocessor(TextPreprocessor): + def __init__(self): + self.bs4 = try_import('bs4') + + def __call__(self, text): + text = self.bs4.BeautifulSoup(text, "html.parser").get_text() + text = text.strip().replace('\n', '').replace('\t', '') + return text + + +class ClassifierIterator(object): + def __init__(self, + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length=512, + memory_len=128, + repeat_input=False, + in_tokens=False, + mode="train", + random_seed=None, + preprocess_text_fn=None): + self.batch_size = batch_size + self.tokenizer = tokenizer + self.trainer_num = trainer_num + self.trainer_id = trainer_id + self.max_seq_length = max_seq_length + self.memory_len = memory_len + self.repeat_input = repeat_input + self.in_tokens = in_tokens + self.dataset = [data for data in dataset] + self.num_examples = None + self.mode = mode + self.shuffle = True if mode == "train" else False + if random_seed is None: + random_seed = 12345 + self.random_seed = random_seed + self.preprocess_text_fn = preprocess_text_fn + + def shuffle_sample(self): + if self.shuffle: + self.global_rng = np.random.RandomState(self.random_seed) + self.global_rng.shuffle(self.dataset) + + def _cnt_list(self, inp): + """Cnt_list""" + cnt = 0 + for lit in inp: + if lit: + cnt += 1 + return cnt + + def _convert_to_features(self, example, qid): + """ + Convert example to features fed into model + """ + if "text" in example: # imdb + text = example["text"] + elif "sentence" in example: # iflytek + text = example["sentence"] + + if self.preprocess_text_fn: + text = self.preprocess_text_fn(text) + if "label" in example: + label = example["label"] + else: + label = "-1" + doc_spans = [] + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + start_offset = 0 + max_tokens_for_doc = self.max_seq_length - 2 + tokens_a = self.tokenizer.tokenize(text) + while start_offset < len(tokens_a): + length = len(tokens_a) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(tokens_a): + break + start_offset += min(length, self.memory_len) + + features = [] + Feature = namedtuple("Feature", + ["src_ids", "label_id", "qid", "cal_loss"]) + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = tokens_a[doc_span.start:doc_span.start + + doc_span.length] + ["[SEP]"] + ["[CLS]"] + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + features.append( + Feature( + src_ids=token_ids, label_id=label, qid=qid, cal_loss=1)) + + if self.repeat_input: + features_repeat = features + features = list(map(lambda x: x._replace(cal_loss=0), features)) + features = features + features_repeat + return features + + def _get_samples(self, pre_batch_list, is_last=False): + if is_last: + # Pad batch + len_doc = [len(doc) for doc in pre_batch_list] + max_len_idx = len_doc.index(max(len_doc)) + dirty_sample = pre_batch_list[max_len_idx][-1]._replace(cal_loss=0) + for sample_list in pre_batch_list: + sample_list.extend([dirty_sample] * + (max(len_doc) - len(sample_list))) + + samples = [] + min_len = min([len(doc) for doc in pre_batch_list]) + for cnt in range(min_len): + for batch_idx in range(self.batch_size * self.trainer_num): + sample = pre_batch_list[batch_idx][cnt] + samples.append(sample) + + for idx in range(len(pre_batch_list)): + pre_batch_list[idx] = pre_batch_list[idx][min_len:] + return samples + + def 
_pad_batch_records(self, batch_records, gather_idx=[]): + batch_token_ids = [record.src_ids for record in batch_records] + if batch_records[0].label_id is not None: + batch_labels = [record.label_id for record in batch_records] + batch_labels = np.array(batch_labels).astype("int64").reshape( + [-1, 1]) + else: + batch_labels = np.array([]).astype("int64").reshape([-1, 1]) + # Qid + if batch_records[-1].qid is not None: + batch_qids = [record.qid for record in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + else: + batch_qids = np.array([]).astype("int64").reshape([-1, 1]) + + if gather_idx: + batch_gather_idx = np.array(gather_idx).astype("int64").reshape( + [-1, 1]) + need_cal_loss = np.array([1]).astype("int64") + else: + batch_gather_idx = np.array(list(range(len(batch_records)))).astype( + "int64").reshape([-1, 1]) + need_cal_loss = np.array([0]).astype("int64") + + # Padding + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids, pad_idx=self.tokenizer.pad_token_id, pad_max_len=self.max_seq_length, \ + final_cls=True, return_input_mask=True) + padded_task_ids = np.zeros_like(padded_token_ids, dtype="int64") + padded_position_ids = get_related_pos(padded_token_ids, \ + self.max_seq_length, self.memory_len) + + return_list = [ + padded_token_ids, padded_position_ids, padded_task_ids, input_mask, + batch_labels, batch_qids, batch_gather_idx, need_cal_loss + ] + return return_list + + def _prepare_batch_data(self, examples): + batch_records, max_len, gather_idx = [], 0, [] + for index, example in enumerate(examples): + max_len = max(max_len, len(example.src_ids)) + if self.in_tokens: + to_append = (len(batch_records) + 1 + ) * max_len <= self.batch_size + else: + to_append = len(batch_records) < self.batch_size + if to_append: + batch_records.append(example) + if example.cal_loss == 1: + gather_idx.append(index % self.batch_size) + else: + yield self._pad_batch_records(batch_records, gather_idx) + batch_records, max_len = [example], len(example.src_ids) + gather_idx = [index % self.batch_size + ] if example.cal_loss == 1 else [] + yield self._pad_batch_records(batch_records, gather_idx) + + def _create_instances(self): + examples = self.dataset + pre_batch_list = [] + insert_idx = [] + for qid, example in enumerate(examples): + features = self._convert_to_features(example, qid) + if self._cnt_list( + pre_batch_list) < self.batch_size * self.trainer_num: + if insert_idx: + pre_batch_list[insert_idx[0]] = features + insert_idx.pop(0) + else: + pre_batch_list.append(features) + if self._cnt_list( + pre_batch_list) == self.batch_size * self.trainer_num: + assert self._cnt_list(pre_batch_list) == len( + pre_batch_list), "the two value must be equal" + assert not insert_idx, "the insert_idx must be null" + sample_batch = self._get_samples(pre_batch_list) + + for idx, lit in enumerate(pre_batch_list): + if not lit: + insert_idx.append(idx) + for batch_records in self._prepare_batch_data(sample_batch): + yield batch_records + + if self.mode != "train": + if self._cnt_list(pre_batch_list): + pre_batch_list += [ + [] + for _ in range(self.batch_size * self.trainer_num - + self._cnt_list(pre_batch_list)) + ] + sample_batch = self._get_samples(pre_batch_list, is_last=True) + for batch_records in self._prepare_batch_data(sample_batch): + yield batch_records + + def __call__(self): + curr_id = 0 + for batch_records in self._create_instances(): + if curr_id == self.trainer_id or self.mode != "train": + yield batch_records + curr_id = (curr_id + 1) % 
self.trainer_num + + def get_num_examples(self): + if self.num_examples is None: + self.num_examples = 0 + for qid, example in enumerate(self.dataset): + self.num_examples += len( + self._convert_to_features(example, qid)) + return self.num_examples + + +class MRCIterator(ClassifierIterator): + """ + Machine Reading Comprehension iterator. Only for answer extraction. + """ + + def __init__(self, + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length=512, + memory_len=128, + repeat_input=False, + in_tokens=False, + mode="train", + random_seed=None, + doc_stride=128, + max_query_length=64): + super(MRCIterator, self).__init__( + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length, + memory_len, + repeat_input, + in_tokens, + mode, + random_seed, + preprocess_text_fn=None) + self.doc_stride = doc_stride + self.max_query_length = max_query_length + self.examples = [] + self.features = [] + self.features_all = [] + self._preprocess_data() + + def shuffle_sample(self): + if self.shuffle: + self.global_rng = np.random.RandomState(self.random_seed) + self.global_rng.shuffle(self.features_all) + + def _convert_qa_to_examples(self): + Example = namedtuple('Example', [ + 'qas_id', 'question_text', 'doc_tokens', 'orig_answer_text', + 'start_position', 'end_position' + ]) + examples = [] + for qa in self.dataset: + qas_id = qa["id"] + question_text = qa["question"] + context = qa["context"] + start_pos = None + end_pos = None + orig_answer_text = None + if self.mode == 'train': + if len(qa["answers"]) != 1: + raise ValueError( + "For training, each question should have exactly 1 answer." + ) + orig_answer_text = qa["answers"][0] + answer_offset = qa["answer_starts"][0] + answer_length = len(orig_answer_text) + doc_tokens = [ + context[:answer_offset], + context[answer_offset:answer_offset + answer_length], + context[answer_offset + answer_length:] + ] + + start_pos = 1 + end_pos = 1 + + actual_text = " ".join(doc_tokens[start_pos:(end_pos + 1)]) + if orig_answer_text.islower(): + actual_text = actual_text.lower() + if actual_text.find(orig_answer_text) == -1: + logger.info("Could not find answer: '%s' vs. 
'%s'" % + (actual_text, orig_answer_text)) + continue + + else: + doc_tokens = tokenize_chinese_chars(context) + + example = Example( + qas_id=qas_id, + question_text=question_text, + doc_tokens=doc_tokens, + orig_answer_text=orig_answer_text, + start_position=start_pos, + end_position=end_pos) + examples.append(example) + return examples + + def _convert_example_to_feature(self, examples): + Feature = namedtuple("Feature", [ + "qid", "example_index", "doc_span_index", "tokens", + "token_to_orig_map", "token_is_max_context", "src_ids", + "start_position", "end_position", "cal_loss" + ]) + features = [] + self.features_all = [] + unique_id = 1000 + is_training = self.mode == "train" + print("total {} examples".format(len(examples)), flush=True) + for (example_index, example) in enumerate(examples): + query_tokens = self.tokenizer.tokenize(example.question_text) + if len(query_tokens) > self.max_query_length: + query_tokens = query_tokens[0:self.max_query_length] + tok_to_orig_index = [] + orig_to_tok_index = [] + all_doc_tokens = [] + for (i, token) in enumerate(example.doc_tokens): + orig_to_tok_index.append(len(all_doc_tokens)) + sub_tokens = self.tokenizer.tokenize(token) + for sub_token in sub_tokens: + tok_to_orig_index.append(i) + all_doc_tokens.append(sub_token) + + tok_start_position = None + tok_end_position = None + if is_training: + tok_start_position = orig_to_tok_index[example.start_position] + if example.end_position < len(example.doc_tokens) - 1: + tok_end_position = orig_to_tok_index[example.end_position + + 1] - 1 + else: + tok_end_position = len(all_doc_tokens) - 1 + (tok_start_position, + tok_end_position) = self._improve_answer_span( + all_doc_tokens, tok_start_position, tok_end_position, + example.orig_answer_text) + + max_tokens_for_doc = self.max_seq_length - len(query_tokens) - 3 + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + while start_offset < len(all_doc_tokens): + length = len(all_doc_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(all_doc_tokens): + break + start_offset += min(length, self.doc_stride) + + features_each = [] + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens = [] + token_to_orig_map = {} + token_is_max_context = {} + tokens.append("[CLS]") + for i in range(doc_span.length): + split_token_index = doc_span.start + i + token_to_orig_map[i + 1] = tok_to_orig_index[ + split_token_index] + is_max_context = self._check_is_max_context( + doc_spans, doc_span_index, split_token_index) + token_is_max_context[i + 1] = is_max_context + tokens += all_doc_tokens[doc_span.start:doc_span.start + + doc_span.length] + tokens.append("[SEP]") + + for token in query_tokens: + tokens.append(token) + tokens.append("[SEP]") + + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + start_position = None + end_position = None + if is_training: + doc_start = doc_span.start + doc_end = doc_span.start + doc_span.length - 1 + out_of_span = False + if not (tok_start_position >= doc_start and + tok_end_position <= doc_end): + out_of_span = True + if out_of_span: + start_position = 0 + end_position = 0 + else: + doc_offset = 1 # len(query_tokens) + 2 + start_position = tok_start_position - doc_start + doc_offset + end_position = tok_end_position - doc_start + doc_offset + + feature = Feature( + qid=unique_id, + example_index=example_index, + doc_span_index=doc_span_index, + 
tokens=tokens, + token_to_orig_map=token_to_orig_map, + token_is_max_context=token_is_max_context, + src_ids=token_ids, + start_position=start_position, + end_position=end_position, + cal_loss=1) + features.append(feature) + features_each.append(feature) + if example_index % 1000 == 0: + print( + "processing {} examples".format(example_index), + flush=True) + + unique_id += 1 + # Repeat + if self.repeat_input: + features_each_repeat = features_each + features_each = list( + map(lambda x: x._replace(cla_loss=0), features_each)) + features_each += features_each_repeat + + self.features_all.append(features_each) + + return features + + def _preprocess_data(self): + # Construct examples + self.examples = self._convert_qa_to_examples() + # Construct features + self.features = self._convert_example_to_feature(self.examples) + + def get_num_examples(self): + if not self.features_all: + self._preprocess_data() + return len(sum(self.features_all, [])) + + def _improve_answer_span(self, doc_tokens, input_start, input_end, + orig_answer_text): + """Improve answer span""" + tok_answer_text = " ".join(self.tokenizer.tokenize(orig_answer_text)) + + for new_start in range(input_start, input_end + 1): + for new_end in range(input_end, new_start - 1, -1): + text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) + if text_span == tok_answer_text: + return (new_start, new_end) + + return (input_start, input_end) + + def _check_is_max_context(self, doc_spans, cur_span_index, position): + """Check is max context""" + best_score = None + best_span_index = None + for (span_index, doc_span) in enumerate(doc_spans): + end = doc_span.start + doc_span.length - 1 + if position < doc_span.start: + break + if position > end: + continue + num_left_context = position - doc_span.start + num_right_context = end - position + score = min(num_left_context, + num_right_context) + 0.01 * doc_span.length + if best_score is None or score > best_score: + best_score = score + best_span_index = span_index + if best_span_index > cur_span_index: + return False + + return cur_span_index == best_span_index + + def _pad_batch_records(self, batch_records, gather_idx=[]): + """Pad batch data""" + batch_token_ids = [record.src_ids for record in batch_records] + + if self.mode == "train": + batch_start_position = [ + record.start_position for record in batch_records + ] + batch_end_position = [ + record.end_position for record in batch_records + ] + batch_start_position = np.array(batch_start_position).astype( + "int64").reshape([-1, 1]) + batch_end_position = np.array(batch_end_position).astype( + "int64").reshape([-1, 1]) + else: + batch_size = len(batch_token_ids) + batch_start_position = np.zeros( + shape=[batch_size, 1], dtype="int64") + batch_end_position = np.zeros(shape=[batch_size, 1], dtype="int64") + + batch_qids = [record.qid for record in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + + if gather_idx: + batch_gather_idx = np.array(gather_idx).astype("int64").reshape( + [-1, 1]) + need_cal_loss = np.array([1]).astype("int64") + else: + batch_gather_idx = np.array(list(range(len(batch_records)))).astype( + "int64").reshape([-1, 1]) + need_cal_loss = np.array([0]).astype("int64") + + # padding + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids, + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length, + return_input_mask=True) + padded_task_ids = np.zeros_like(padded_token_ids, dtype="int64") + padded_position_ids = get_related_pos( + padded_task_ids, 
self.max_seq_length, self.memory_len) + + return_list = [ + padded_token_ids, padded_position_ids, padded_task_ids, input_mask, + batch_start_position, batch_end_position, batch_qids, + batch_gather_idx, need_cal_loss + ] + + return return_list + + def _create_instances(self): + """Generate batch records""" + pre_batch_list = [] + insert_idx = [] + for qid, features in enumerate(self.features_all): + if self._cnt_list( + pre_batch_list) < self.batch_size * self.trainer_num: + if insert_idx: + pre_batch_list[insert_idx[0]] = features + insert_idx.pop(0) + else: + pre_batch_list.append(features) + if self._cnt_list( + pre_batch_list) == self.batch_size * self.trainer_num: + assert self._cnt_list(pre_batch_list) == len( + pre_batch_list), "the two value must be equal" + assert not insert_idx, "the insert_idx must be null" + sample_batch = self._get_samples(pre_batch_list) + + for idx, lit in enumerate(pre_batch_list): + if not lit: + insert_idx.append(idx) + for batch_records in self._prepare_batch_data(sample_batch): + yield batch_records + + if self.mode != "train": + if self._cnt_list(pre_batch_list): + pre_batch_list += [ + [] + for _ in range(self.batch_size * self.trainer_num - + self._cnt_list(pre_batch_list)) + ] + sample_batch = self._get_samples(pre_batch_list, is_last=True) + for batch_records in self._prepare_batch_data(sample_batch): + yield batch_records + + +class MCQIterator(MRCIterator): + """ + Multiple choice question iterator. + """ + + def __init__(self, + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length=512, + memory_len=128, + repeat_input=False, + in_tokens=False, + mode="train", + random_seed=None, + doc_stride=128, + max_query_length=64, + choice_num=4): + self.choice_num = choice_num + super(MCQIterator, self).__init__( + dataset, batch_size, tokenizer, trainer_num, trainer_id, + max_seq_length, memory_len, repeat_input, in_tokens, mode, + random_seed) + + def _truncate_seq_pair(self, tokens_a, tokens_b, max_length): + """Truncates a sequence pair in place to the maximum length.""" + + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. 
+ tokens_a = list(tokens_a) + tokens_b = list(tokens_b) + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_length: + break + if len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + return tokens_a, tokens_b + + def _convert_qa_to_examples(self): + Example = namedtuple( + 'Example', ['qas_id', 'context', 'question', 'choice', 'label']) + examples = [] + for qas_id, qa in enumerate(self.dataset): + context = '\n'.join(qa['context']).lower() + question = qa['question'].lower() + choice = [c.lower() for c in qa['choice']] + # pad empty choice + for k in range(len(choice), self.choice_num): + choice.append('') + label = qa['label'] + + example = Example( + qas_id=qas_id, + context=context, + question=question, + choice=choice, + label=label) + examples.append(example) + return examples + + def _convert_example_to_feature(self, examples): + Feature = namedtuple( + 'Feature', ['qid', 'src_ids', 'segment_ids', 'label', 'cal_loss']) + features = [] + self.features_all = [] + pad_token_id = self.tokenizer.pad_token_id + for (ex_index, example) in enumerate(examples): + context_tokens = self.tokenizer.tokenize(example.context) + question_tokens = self.tokenizer.tokenize(example.question) + choice_tokens_lst = [ + self.tokenizer.tokenize(choice) for choice in example.choice + ] + # nums = 4 + question_choice_pairs = \ + [self._truncate_seq_pair(question_tokens, choice_tokens, self.max_query_length - 2) + for choice_tokens in choice_tokens_lst] + total_qc_num = sum( + [(len(q) + len(c)) for q, c in question_choice_pairs]) + max_tokens_for_doc = self.max_seq_length - total_qc_num - 4 + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + + while start_offset < len(context_tokens): + length = len(context_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(context_tokens): + break + start_offset += min(length, self.doc_stride) + + features_each = [] + for (doc_span_index, doc_span) in enumerate(doc_spans): + qa_features = [] + for q_tokens, c_tokens in question_choice_pairs: + segment_tokens = ['[CLS]'] + token_type_ids = [0] + + segment_tokens += context_tokens[ + doc_span.start:doc_span.start + doc_span.length] + token_type_ids += [0] * doc_span.length + + segment_tokens += ['[SEP]'] + token_type_ids += [0] + + segment_tokens += q_tokens + token_type_ids += [1] * len(q_tokens) + + segment_tokens += ['[SEP]'] + token_type_ids += [1] + + segment_tokens += c_tokens + token_type_ids += [1] * len(c_tokens) + + segment_tokens += ['[SEP]'] + token_type_ids += [1] + + input_ids = self.tokenizer.convert_tokens_to_ids( + segment_tokens) + feature = Feature( + qid=example.qas_id, + label=example.label, + src_ids=input_ids, + segment_ids=token_type_ids, + cal_loss=1) + qa_features.append(feature) + + features.append(qa_features) + features_each.append(qa_features) + + # Repeat + if self.repeat_input: + features_each_repeat = features_each + features_each = list( + map(lambda x: x._replace(cla_loss=0), features_each)) + features_each += features_each_repeat + + self.features_all.append(features_each) + + return features + + def _pad_batch_records(self, batch_records, gather_idx=[]): + batch_token_ids = [[record.src_ids for record in records] + for records in batch_records] + if batch_records[0][0].label is not None: + batch_labels = [[record.label for record in records] + for records in 
batch_records] + batch_labels = np.array(batch_labels).astype("int64").reshape( + [-1, 1]) + else: + batch_labels = np.array([]).astype("int64").reshape([-1, 1]) + # Qid + batch_qids = [[record.qid for record in records] + for records in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + + if gather_idx: + batch_gather_idx = np.array(gather_idx).astype("int64").reshape( + [-1, 1]) + need_cal_loss = np.array([1]).astype("int64") + else: + batch_gather_idx = np.array(list(range(len(batch_records)))).astype( + "int64").reshape([-1, 1]) + need_cal_loss = np.array([0]).astype("int64") + + batch_task_ids = [[record.segment_ids for record in records] + for records in batch_records] + + # Padding + batch_padded_token_ids = [] + batch_input_mask = [] + batch_padded_task_ids = [] + batch_padded_position_ids = [] + batch_size = len(batch_token_ids) + for i in range(batch_size): + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids[i], + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length, + return_input_mask=True) + padded_task_ids = pad_batch_data( + batch_task_ids[i], + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length) + + padded_position_ids = get_related_pos( + padded_task_ids, self.max_seq_length, self.memory_len) + + batch_padded_token_ids.append(padded_token_ids) + batch_input_mask.append(input_mask) + batch_padded_task_ids.append(padded_task_ids) + batch_padded_position_ids.append(padded_position_ids) + + batch_padded_token_ids = np.array(batch_padded_token_ids).astype( + "int64").reshape([batch_size * self.choice_num, -1, 1]) + batch_padded_position_ids = np.array(batch_padded_position_ids).astype( + "int64").reshape([batch_size * self.choice_num, -1, 1]) + batch_padded_task_ids = np.array(batch_padded_task_ids).astype( + "int64").reshape([batch_size * self.choice_num, -1, 1]) + batch_input_mask = np.array(batch_input_mask).astype("float32").reshape( + [batch_size * self.choice_num, -1, 1]) + + return_list = [ + batch_padded_token_ids, batch_padded_position_ids, + batch_padded_task_ids, batch_input_mask, batch_labels, batch_qids, + batch_gather_idx, need_cal_loss + ] + return return_list + + def _prepare_batch_data(self, examples_list): + batch_records, max_len, gather_idx = [], 0, [] + real_batch_size = self.batch_size * self.choice_num + index = 0 + for examples in examples_list: + records = [] + gather_idx_candidate = [] + for example in examples: + if example.cal_loss == 1: + gather_idx_candidate.append(index % real_batch_size) + max_len = max(max_len, len(example.src_ids)) + records.append(example) + index += 1 + + if self.in_tokens: + to_append = (len(batch_records) + 1 + ) * self.choice_num * max_len <= self.batch_size + else: + to_append = len(batch_records) < self.batch_size + if to_append: + batch_records.append(records) + gather_idx += gather_idx_candidate + else: + yield self._pad_batch_records(batch_records, gather_idx) + batch_records, max_len = [records], max( + len(record.src_ids) for record in records) + start_index = index - len(records) + 1 + gather_idx = gather_idx_candidate + if len(batch_records) > 0: + yield self._pad_batch_records(batch_records, gather_idx) + + def _get_samples(self, pre_batch_list, is_last=False): + if is_last: + # Pad batch + len_doc = [[len(doc) for doc in doc_list] + for doc_list in pre_batch_list] + len_doc = list(itertools.chain(*len_doc)) + max_len_idx = len_doc.index(max(len_doc)) + doc_idx = max_len_idx % self.choice_num + doc_list_idx = max_len_idx // 
self.choice_num + dirty_sample = pre_batch_list[doc_list_idx][doc_idx][-1]._replace( + cal_loss=0) + for sample_list in pre_batch_list: + for samples in sample_list: + samples.extend([dirty_sample] * + (max(len_doc) - len(samples))) + samples = [] + min_len = min([len(doc) for doc in pre_batch_list]) + for cnt in range(min_len): + for batch_idx in range(self.batch_size * self.trainer_num): + sample = pre_batch_list[batch_idx][cnt] + samples.append(sample) + + for idx in range(len(pre_batch_list)): + pre_batch_list[idx] = pre_batch_list[idx][min_len:] + return samples + + +class SemanticMatchingIterator(MRCIterator): + def _convert_qa_to_examples(self): + Example = namedtuple('Example', + ['qid', 'text_a', 'text_b', 'text_c', 'label']) + examples = [] + for qid, qa in enumerate(self.dataset): + text_a, text_b, text_c = list( + map(lambda x: x.replace('\n', '').strip(), + [qa["text_a"], qa["text_b"], qa["text_c"]])) + + example = Example( + qid=qid, + text_a=text_a, + text_b=text_b, + text_c=text_c, + label=qa["label"]) + examples += [example] + return examples + + def _create_tokens_and_type_id(self, text_a_tokens, text_b_tokens, start, + length): + tokens = ['[CLS]'] + text_a_tokens[start:start + length] + [ + '[SEP]' + ] + text_b_tokens[start:start + length] + ['[SEP]'] + token_type_ids = [0] + [0] * (length + 1) + [1] * (length + 1) + return tokens, token_type_ids + + def _convert_example_to_feature(self, examples): + Feature = namedtuple('Feature', [ + 'qid', 'src_ids', 'segment_ids', 'pair_src_ids', 'pair_segment_ids', + 'label', 'cal_loss' + ]) + features = [] + self.features_all = [] + pad_token_id = self.tokenizer.pad_token_id + for (ex_index, example) in enumerate(examples): + text_a_tokens = self.tokenizer.tokenize(example.text_a) + text_b_tokens = self.tokenizer.tokenize(example.text_b) + text_c_tokens = self.tokenizer.tokenize(example.text_c) + a_len, b_len, c_len = list( + map(lambda x: len(x), + [text_a_tokens, text_b_tokens, text_c_tokens])) + + # Align 3 text + min_text_len = min([a_len, b_len, c_len]) + text_a_tokens = text_a_tokens[:min_text_len] + text_b_tokens = text_b_tokens[:min_text_len] + text_c_tokens = text_c_tokens[:min_text_len] + + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + doc_spans = [] + start_offset = 0 + + max_tokens_for_doc = (self.max_seq_length - 3) // 2 + + while start_offset < len(text_a_tokens): + length = len(text_a_tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(text_a_tokens): + break + start_offset += min(length, self.doc_stride) + + features_each = [] + for (doc_span_index, doc_span) in enumerate(doc_spans): + tokens1, token_type_ids1 = self._create_tokens_and_type_id( + text_a_tokens, text_b_tokens, doc_span.start, + doc_span.length) + tokens2, token_type_ids2 = self._create_tokens_and_type_id( + text_a_tokens, text_c_tokens, doc_span.start, + doc_span.length) + + input_ids1 = self.tokenizer.convert_tokens_to_ids(tokens1) + input_ids2 = self.tokenizer.convert_tokens_to_ids(tokens2) + feature = Feature( + qid=example.qid, + label=example.label, + src_ids=input_ids1, + segment_ids=token_type_ids1, + pair_src_ids=input_ids2, + pair_segment_ids=token_type_ids2, + cal_loss=1) + + features.append(feature) + features_each.append(feature) + + # Repeat + if self.repeat_input: + features_each_repeat = features_each + features_each = list( + map(lambda x: x._replace(cla_loss=0), features_each)) + 
features_each += features_each_repeat + + self.features_all.append(features_each) + + return features + + def _create_pad_ids(self, batch_records, prefix=""): + src_ids = prefix + "src_ids" + segment_ids = prefix + "segment_ids" + batch_token_ids = [getattr(record, src_ids) for record in batch_records] + batch_task_ids = [ + getattr(record, segment_ids) for record in batch_records + ] + + # Padding + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids, + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length, + return_input_mask=True) + padded_task_ids = pad_batch_data( + batch_task_ids, + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length) + + padded_position_ids = get_related_pos( + padded_task_ids, self.max_seq_length, self.memory_len) + + return [ + padded_token_ids, padded_position_ids, padded_task_ids, input_mask + ] + + def _pad_batch_records(self, batch_records, gather_idx=[]): + if batch_records[0].label is not None: + batch_labels = [record.label for record in batch_records] + batch_labels = np.array(batch_labels).astype("int64").reshape( + [-1, 1]) + else: + batch_labels = np.array([]).astype("int64").reshape([-1, 1]) + # Qid + batch_qids = [record.qid for record in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + + if gather_idx: + batch_gather_idx = np.array(gather_idx).astype("int64").reshape( + [-1, 1]) + need_cal_loss = np.array([1]).astype("int64") + else: + batch_gather_idx = np.array(list(range(len(batch_records)))).astype( + "int64").reshape([-1, 1]) + need_cal_loss = np.array([0]).astype("int64") + + return_list = self._create_pad_ids(batch_records) \ + + self._create_pad_ids(batch_records, "pair_") \ + + [batch_labels, batch_qids, batch_gather_idx, need_cal_loss] + return return_list + + +class SequenceLabelingIterator(ClassifierIterator): + def __init__(self, + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length=512, + memory_len=128, + repeat_input=False, + in_tokens=False, + mode="train", + random_seed=None, + no_entity_id=-1): + super(SequenceLabelingIterator, self).__init__( + dataset, + batch_size, + tokenizer, + trainer_num, + trainer_id, + max_seq_length, + memory_len, + repeat_input, + in_tokens, + mode, + random_seed, + preprocess_text_fn=None) + self.no_entity_id = no_entity_id + + def _convert_to_features(self, example, qid): + """ + Convert example to features fed into model + """ + tokens = example['tokens'] + label = example["labels"] + doc_spans = [] + _DocSpan = namedtuple("DocSpan", ["start", "length"]) + start_offset = 0 + max_tokens_for_doc = self.max_seq_length - 2 + while start_offset < len(tokens): + length = len(tokens) - start_offset + if length > max_tokens_for_doc: + length = max_tokens_for_doc + doc_spans.append(_DocSpan(start=start_offset, length=length)) + if start_offset + length == len(tokens): + break + start_offset += min(length, self.memory_len) + + features = [] + Feature = namedtuple("Feature", + ["src_ids", "label_ids", "qid", "cal_loss"]) + for (doc_span_index, doc_span) in enumerate(doc_spans): + curr_tokens = ["[CLS]"] + tokens[doc_span.start:doc_span.start + + doc_span.length] + ["[SEP]"] + token_ids = self.tokenizer.convert_tokens_to_ids(curr_tokens) + label = [self.no_entity_id + ] + label[doc_span.start:doc_span.start + + doc_span.length] + [self.no_entity_id] + + features.append( + Feature( + src_ids=token_ids, label_ids=label, qid=qid, cal_loss=1)) + + if self.repeat_input: + features_repeat = features + 
features = list(map(lambda x: x._replace(cal_loss=0), features)) + features = features + features_repeat + return features + + def _pad_batch_records(self, batch_records, gather_idx=[]): + batch_token_ids = [record.src_ids for record in batch_records] + batch_length = [len(record.src_ids) for record in batch_records] + batch_length = np.array(batch_length).astype("int64").reshape([-1, 1]) + + if batch_records[0].label_ids is not None: + batch_labels = [record.label_ids for record in batch_records] + else: + batch_labels = np.array([]).astype("int64").reshape([-1, 1]) + # Qid + if batch_records[-1].qid is not None: + batch_qids = [record.qid for record in batch_records] + batch_qids = np.array(batch_qids).astype("int64").reshape([-1, 1]) + else: + batch_qids = np.array([]).astype("int64").reshape([-1, 1]) + + if gather_idx: + batch_gather_idx = np.array(gather_idx).astype("int64").reshape( + [-1, 1]) + need_cal_loss = np.array([1]).astype("int64") + else: + batch_gather_idx = np.array(list(range(len(batch_records)))).astype( + "int64").reshape([-1, 1]) + need_cal_loss = np.array([0]).astype("int64") + # Padding + padded_token_ids, input_mask = pad_batch_data( + batch_token_ids, + pad_idx=self.tokenizer.pad_token_id, + pad_max_len=self.max_seq_length, + return_input_mask=True) + if batch_records[0].label_ids is not None: + padded_batch_labels = pad_batch_data( + batch_labels, + pad_idx=self.no_entity_id, + pad_max_len=self.max_seq_length) + padded_task_ids = np.zeros_like(padded_token_ids, dtype="int64") + padded_position_ids = get_related_pos(padded_token_ids, \ + self.max_seq_length, self.memory_len) + + return_list = [ + padded_token_ids, padded_position_ids, padded_task_ids, input_mask, + padded_batch_labels, batch_length, batch_qids, batch_gather_idx, + need_cal_loss + ] + return return_list + + +def to_json_file(task, label_dict, file_path): + if task == "iflytek": + filename = file_path + + with open(filename, 'w+') as f_obj: + for i, j in label_dict.items(): + tmp = dict() + tmp["id"] = str(i) + tmp["label"] = str(j) + json.dump(tmp, f_obj) + f_obj.write("\n") diff --git a/examples/text_classification/ernie_doc/export_model.py b/examples/text_classification/ernie_doc/export_model.py new file mode 100644 index 000000000000..9026aeda4047 --- /dev/null +++ b/examples/text_classification/ernie_doc/export_model.py @@ -0,0 +1,57 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import paddle +import shutil +from paddlenlp.utils.log import logger +from predict import LongDocClassifier + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", default=16, type=int, + help="Batch size per GPU/CPU for predicting (In static mode, it should be the same as in model training process.)") +parser.add_argument("--model_name_or_path", type=str, default="ernie-doc-base-zh", + help="Pretraining or finetuned model name or path") +parser.add_argument("--max_seq_length", type=int, default=512, + help="The maximum total input sequence length after SentencePiece tokenization.") +parser.add_argument("--memory_length", type=int, default=128, help="Length of the retained previous heads.") +parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "gpu"], + help="Select cpu, gpu devices to train model.") +parser.add_argument("--dataset", default="iflytek", choices=["imdb", "iflytek", "thucnews", "hyp"], type=str, + help="The training dataset") +parser.add_argument("--static_path", default=None, type=str, + help="The path which your static model is at or where you want to save after converting.") + +args = parser.parse_args() +# yapf: enable + +if __name__ == "__main__": + paddle.set_device(args.device) + + if os.path.exists(args.model_name_or_path): + logger.info("init checkpoint from %s" % args.model_name_or_path) + + if args.static_path and os.path.exists(args.static_path): + logger.info("will remove the old model") + shutil.rmtree(args.static_path) + + predictor = LongDocClassifier( + model_name_or_path=args.model_name_or_path, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + memory_len=args.memory_length, + static_mode=True, + static_path=args.static_path) diff --git a/examples/text_classification/ernie_doc/metrics.py b/examples/text_classification/ernie_doc/metrics.py new file mode 100644 index 000000000000..5152f2194431 --- /dev/null +++ b/examples/text_classification/ernie_doc/metrics.py @@ -0,0 +1,314 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
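+# Usage sketch (illustrative only) for the F1 metric defined below; `probs` and
+# `labels` are assumed to be per-batch prediction scores and gold label ids:
+#
+#   metric = F1(positive_label=1)
+#   metric.reset()
+#   for probs, labels in eval_batches:
+#       metric.update(metric.compute(probs, labels))   # accumulates tp / fp / fn
+#   print(metric.accumulate())                         # F1 over all seen batches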
+ +import numpy as np +import collections +import sys +import paddle +from paddle.utils import try_import +from paddlenlp.metrics.dureader import get_final_text, _compute_softmax, _get_best_indexes + +# Metric for ERNIE-DOCs + + +class F1(object): + def __init__(self, positive_label=1): + self.positive_label = positive_label + self.reset() + + def compute(self, preds, labels): + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + elif isinstance(preds, list): + preds = np.array(preds, dtype='float32') + if isinstance(labels, list): + labels = np.array(labels, dtype='int64') + elif isinstance(labels, paddle.Tensor): + labels = labels.numpy() + preds = np.argmax(preds, axis=1) + tp = ((preds == labels) & (labels == self.positive_label)).sum() + fn = ((preds != labels) & (labels == self.positive_label)).sum() + fp = ((preds != labels) & (preds == self.positive_label)).sum() + return tp, fp, fn + + def update(self, statistic): + tp, fp, fn = statistic + self.tp += tp + self.fp += fp + self.fn += fn + + def accumulate(self): + recall = self.tp / (self.tp + self.fn) + precision = self.tp / (self.tp + self.fp) + f1 = 2 * recall * precision / (recall + precision) + return f1 + + def reset(self): + self.tp = 0 + self.fp = 0 + self.fn = 0 + + +class EM_AND_F1(object): + def __init__(self): + self.nltk = try_import('nltk') + self.re = try_import('re') + + def _mixed_segmentation(self, in_str, rm_punc=False): + """mixed_segmentation""" + in_str = in_str.lower().strip() + segs_out = [] + temp_str = "" + sp_char = [ + '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', + ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', + '」', '(', ')', '-', '~', '『', '』' + ] + for char in in_str: + if rm_punc and char in sp_char: + continue + pattern = '[\\u4e00-\\u9fa5]' + if self.re.search(pattern, char) or char in sp_char: + if temp_str != "": + ss = self.nltk.word_tokenize(temp_str) + segs_out.extend(ss) + temp_str = "" + segs_out.append(char) + else: + temp_str += char + + # Handling last part + if temp_str != "": + ss = self.nltk.word_tokenize(temp_str) + segs_out.extend(ss) + + return segs_out + + # Remove punctuation + def _remove_punctuation(self, in_str): + """remove_punctuation""" + in_str = in_str.lower().strip() + sp_char = [ + '-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', + ':', '?', '!', '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', + '」', '(', ')', '-', '~', '『', '』' + ] + out_segs = [] + for char in in_str: + if char in sp_char: + continue + else: + out_segs.append(char) + return ''.join(out_segs) + + # Find longest common string + def _find_lcs(self, s1, s2): + m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)] + mmax = 0 + p = 0 + for i in range(len(s1)): + for j in range(len(s2)): + if s1[i] == s2[j]: + m[i + 1][j + 1] = m[i][j] + 1 + if m[i + 1][j + 1] > mmax: + mmax = m[i + 1][j + 1] + p = i + 1 + return s1[p - mmax:p], mmax + + def _calc_f1_score(self, answers, prediction): + f1_scores = [] + for ans in answers: + ans_segs = self._mixed_segmentation(ans, rm_punc=True) + prediction_segs = self._mixed_segmentation(prediction, rm_punc=True) + lcs, lcs_len = self._find_lcs(ans_segs, prediction_segs) + if lcs_len == 0: + f1_scores.append(0) + continue + precision = 1.0 * lcs_len / len(prediction_segs) + recall = 1.0 * lcs_len / len(ans_segs) + f1 = (2 * precision * recall) / (precision + recall) + f1_scores.append(f1) + return max(f1_scores) + + def _calc_em_score(self, answers, prediction): + em = 0 + for ans in answers: 
+ ans_ = self._remove_punctuation(ans) + prediction_ = self._remove_punctuation(prediction) + if ans_ == prediction_: + em = 1 + break + return em + + def __call__(self, prediction, ground_truth): + f1 = 0 + em = 0 + total_count = 0 + skip_count = 0 + for instance in ground_truth: + total_count += 1 + query_id = instance['id'] + query_text = instance['question'].strip() + answers = instance["answers"] + if query_id not in prediction: + sys.stderr.write('Unanswered question: {}\n'.format(query_id)) + skip_count += 1 + continue + preds = str(prediction[query_id]) + f1 += self._calc_f1_score(answers, preds) + em += self._calc_em_score(answers, preds) + + f1_score = 100.0 * f1 / total_count + em_score = 100.0 * em / total_count + + avg_score = (f1_score + em_score) * 0.5 + return em_score, f1_score, avg_score, total_count + + +def compute_qa_predictions(all_examples, all_features, all_results, n_best_size, + max_answer_length, do_lower_case, tokenizer, + verbose): + """Write final predictions to the json file and log-odds of null if needed.""" + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", [ + "feature_index", "start_index", "end_index", "start_logit", + "end_logit" + ]) + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # Keep track of the minimum score of null start+end of position 0 + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.qid] + start_indexes = _get_best_indexes(result.start_logits, n_best_size) + end_indexes = _get_best_indexes(result.end_logits, n_best_size) + + for start_index in start_indexes: + for end_index in end_indexes: + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. 
+ if start_index >= len(feature.tokens): + continue + if end_index >= len(feature.tokens): + continue + if start_index not in feature.token_to_orig_map: + continue + if end_index not in feature.token_to_orig_map: + continue + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_logit=result.start_logits[start_index], + end_logit=result.end_logits[end_index])) + + prelim_predictions = sorted( + prelim_predictions, + key=lambda x: (x.start_logit + x.end_logit), + reverse=True) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_logit", "end_logit"]) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + if pred.start_index > 0: # this is a non-null prediction + tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1 + )] + orig_doc_start = feature.token_to_orig_map[pred.start_index] + orig_doc_end = feature.token_to_orig_map[pred.end_index] + orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + + 1)] + tok_text = " ".join(tok_tokens) + + # De-tokenize WordPieces that have been split off. + tok_text = tok_text.replace(" ##", "") + tok_text = tok_text.replace("##", "") + + # Clean whitespace + tok_text = tok_text.strip() + tok_text = " ".join(tok_text.split()) + orig_text = "".join(orig_tokens) + + final_text = get_final_text(tok_text, orig_text, tokenizer, + verbose) + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + else: + final_text = "" + seen_predictions[final_text] = True + + nbest.append( + _NbestPrediction( + text=final_text, + start_logit=pred.start_logit, + end_logit=pred.end_logit)) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append( + _NbestPrediction( + text="empty", start_logit=0.0, end_logit=0.0)) + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_logit + entry.end_logit) + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_logit"] = entry.start_logit + output["end_logit"] = entry.end_logit + nbest_json.append(output) + + assert len(nbest_json) >= 1 + + all_predictions[example.qas_id] = nbest_json[0]["text"] + all_nbest_json[example.qas_id] = nbest_json + return all_predictions, all_nbest_json diff --git a/examples/text_classification/ernie_doc/modeling.py b/examples/text_classification/ernie_doc/modeling.py new file mode 100644 index 000000000000..2d6c4259fd0b --- /dev/null +++ b/examples/text_classification/ernie_doc/modeling.py @@ -0,0 +1,987 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddlenlp.transformers.attention_utils import _convert_param_attr_to_list +from paddlenlp.transformers import PretrainedModel, register_base_model + +__all__ = [ + 'ErnieDocModel', + 'ErnieDocPretrainedModel', + 'ErnieDocForSequenceClassification', + 'ErnieDocForTokenClassification', + 'ErnieDocForQuestionAnswering', +] + + +class PointwiseFFN(nn.Layer): + def __init__(self, + d_inner_hid, + d_hid, + dropout_rate, + hidden_act, + weight_attr=None, + bias_attr=None): + super(PointwiseFFN, self).__init__() + self.linear1 = nn.Linear( + d_hid, d_inner_hid, weight_attr, bias_attr=bias_attr) + self.dropout = nn.Dropout(dropout_rate, mode="upscale_in_train") + self.linear2 = nn.Linear( + d_inner_hid, d_hid, weight_attr, bias_attr=bias_attr) + self.activation = getattr(F, hidden_act) + + def forward(self, x): + return self.linear2(self.dropout(self.activation(self.linear1(x)))) + + +class MultiHeadAttention(nn.Layer): + def __init__(self, + d_key, + d_value, + d_model, + n_head=1, + r_w_bias=None, + r_r_bias=None, + r_t_bias=None, + dropout_rate=0., + weight_attr=None, + bias_attr=None): + super(MultiHeadAttention, self).__init__() + self.d_key = d_key + self.d_value = d_value + self.d_model = d_model + self.n_head = n_head + + assert d_key * n_head == d_model, "d_model must be divisible by n_head" + + self.q_proj = nn.Linear( + d_model, + d_key * n_head, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.k_proj = nn.Linear( + d_model, + d_key * n_head, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.v_proj = nn.Linear( + d_model, + d_value * n_head, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.r_proj = nn.Linear( + d_model, + d_key * n_head, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.t_proj = nn.Linear( + d_model, + d_key * n_head, + weight_attr=weight_attr, + bias_attr=bias_attr) + self.out_proj = nn.Linear( + d_model, d_model, weight_attr=weight_attr, bias_attr=bias_attr) + self.r_w_bias = r_w_bias + self.r_r_bias = r_r_bias + self.r_t_bias = r_t_bias + self.dropout = nn.Dropout( + dropout_rate, mode="upscale_in_train") if dropout_rate else None + + def _compute_qkv(self, queries, keys, values, rel_pos, rel_task): + q = self.q_proj(queries) + k = self.k_proj(keys) + v = self.v_proj(values) + r = self.r_proj(rel_pos) + t = self.t_proj(rel_task) + return q, k, v, r, t + + def _split_heads(self, x, d_model, n_head): + # x shape: [B, T, H] + x = x.reshape(shape=[0, 0, n_head, d_model // n_head]) + # shape: [B, N, T, HH] + return paddle.transpose(x=x, perm=[0, 2, 1, 3]) + + def _rel_shift(self, x, klen=-1): + """ + To perform relative attention, it should relatively shift the attention score matrix + See more details on: https://github.com/kimiyoung/transformer-xl/issues/8#issuecomment-454458852 + """ + # input shape: [B, N, T, 2 * T + M] + x_shape = x.shape + x = x.reshape([x_shape[0], x_shape[1], x_shape[3], x_shape[2]]) + x = x[:, :, 1:, :] + x = x.reshape([x_shape[0], x_shape[1], x_shape[2], x_shape[3] - 1]) + # output shape: [B, N, T, T + M] + return x[:, :, :, 
:klen] + + def _scaled_dot_product_attention(self, q, k, v, r, t, attn_mask): + q_w, q_r, q_t = q + score_w = paddle.matmul(q_w, k, transpose_y=True) + score_r = paddle.matmul(q_r, r, transpose_y=True) + score_r = self._rel_shift(score_r, k.shape[2]) + + score_t = paddle.matmul(q_t, t, transpose_y=True) + score = score_w + score_r + score_t + score = score * (self.d_key**-0.5) + + if attn_mask is not None: + score += attn_mask + weights = F.softmax(score) + if self.dropout: + weights = self.dropout(weights) + out = paddle.matmul(weights, v) + return out + + def _combine_heads(self, x): + sign = len(x.shape) == 3 + # Directly using len(tensor.shape) as an if condition + # would not act functionally when applying paddle.jit.save api to save static graph. + if sign: return x + sign = len(x.shape) != 4 + if sign: + raise ValueError("Input(x) should be a 4-D Tensor.") + # x shape: [B, N, T, HH] + x = paddle.transpose(x, [0, 2, 1, 3]) + # target shape:[B, T, H] + return x.reshape([0, 0, x.shape[2] * x.shape[3]]) + + def forward(self, queries, keys, values, rel_pos, rel_task, memory, + attn_mask): + sign = memory is not None and len(memory.shape) > 1 + if sign: + cat = paddle.concat([memory, queries], 1) + else: + cat = queries + keys, values = cat, cat + + sign = (len(queries.shape) == len(keys.shape) == len(values.shape) \ + == len(rel_pos.shape) == len( + rel_task.shape) == 3) + + if not sign: + raise ValueError( + "Inputs: quries, keys, values, rel_pos and rel_task should all be 3-D tensors." + ) + + q, k, v, r, t = self._compute_qkv(queries, keys, values, rel_pos, + rel_task) + q_w, q_r, q_t = list( + map(lambda x: q + x.unsqueeze([0, 1]), + [self.r_w_bias, self.r_r_bias, self.r_t_bias])) + q_w, q_r, q_t = list( + map(lambda x: self._split_heads(x, self.d_model, self.n_head), + [q_w, q_r, q_t])) + k, v, r, t = list( + map(lambda x: self._split_heads(x, self.d_model, self.n_head), + [k, v, r, t])) + ctx_multiheads = self._scaled_dot_product_attention([q_w, q_r, q_t], \ + k, v, r, t, attn_mask) + out = self._combine_heads(ctx_multiheads) + out = self.out_proj(out) + return out + + +class ErnieDocEncoderLayer(nn.Layer): + def __init__(self, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + hidden_act, + normalize_before=False, + epsilon=1e-5, + rel_pos_params_sharing=False, + r_w_bias=None, + r_r_bias=None, + r_t_bias=None, + weight_attr=None, + bias_attr=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + super(ErnieDocEncoderLayer, self).__init__() + if not rel_pos_params_sharing: + r_w_bias, r_r_bias, r_t_bias = \ + list(map(lambda x: self.create_parameter( + shape=[n_head * d_key], dtype="float32"), + ["r_w_bias", "r_r_bias", "r_t_bias"])) + + weight_attrs = _convert_param_attr_to_list(weight_attr, 2) + bias_attrs = _convert_param_attr_to_list(bias_attr, 2) + self.attn = MultiHeadAttention( + d_key, + d_value, + d_model, + n_head, + r_w_bias, + r_r_bias, + r_t_bias, + attention_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0], ) + self.ffn = PointwiseFFN( + d_inner_hid, + d_model, + relu_dropout, + hidden_act, + weight_attr=weight_attrs[1], + bias_attr=bias_attrs[1]) + self.norm1 = nn.LayerNorm(d_model, epsilon=epsilon) + self.norm2 = nn.LayerNorm(d_model, epsilon=epsilon) + self.dropout1 = nn.Dropout( + prepostprocess_dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout( + prepostprocess_dropout, mode="upscale_in_train") + self.d_model = 
d_model
+        self.epsilon = epsilon
+        self.normalize_before = normalize_before
+
+    def forward(self, enc_input, memory, rel_pos, rel_task, attn_mask):
+        residual = enc_input
+        if self.normalize_before:
+            enc_input = self.norm1(enc_input)
+        attn_output = self.attn(enc_input, enc_input, enc_input, rel_pos,
+                                rel_task, memory, attn_mask)
+        attn_output = residual + self.dropout1(attn_output)
+        if not self.normalize_before:
+            attn_output = self.norm1(attn_output)
+        residual = attn_output
+        if self.normalize_before:
+            attn_output = self.norm2(attn_output)
+        ffn_output = self.ffn(attn_output)
+        output = residual + self.dropout2(ffn_output)
+        if not self.normalize_before:
+            output = self.norm2(output)
+        return output
+
+
+class ErnieDocEncoder(nn.Layer):
+    def __init__(self, num_layers, encoder_layer, mem_len):
+        super(ErnieDocEncoder, self).__init__()
+        self.layers = nn.LayerList([(
+            encoder_layer
+            if i == 0 else type(encoder_layer)(**encoder_layer._config))
+            for i in range(num_layers)])
+        self.num_layers = num_layers
+        self.normalize_before = self.layers[0].normalize_before
+        self.mem_len = mem_len
+
+    def _cache_mem(self, curr_out, prev_mem):
+        if self.mem_len is None or self.mem_len == 0:
+            return None
+        if prev_mem is None:
+            new_mem = curr_out[:, -self.mem_len:, :]
+        else:
+            new_mem = paddle.concat([prev_mem, curr_out],
+                                    1)[:, -self.mem_len:, :]
+        new_mem.stop_gradient = True
+        return new_mem
+
+    def forward(self, enc_input, memories, rel_pos, rel_task, attn_mask):
+        # memories shape: [N, B, M, H]
+        # There is no need to normalize enc_input here because it has already
+        # been normalized outside.
+        new_mem = None
+        for _, encoder_layer in enumerate(self.layers):
+            # In static mode the memories are passed as a single tensor, so
+            # paddle.slice is used below to drop each consumed layer memory
+            # explicitly and save GPU memory.
+            enc_input = encoder_layer(enc_input, memories[0], rel_pos, rel_task,
+                                      attn_mask)
+            if new_mem is None:
+                new_mem = paddle.unsqueeze(
+                    self._cache_mem(enc_input, memories[0]), axis=0)
+            else:
+                new_mem = paddle.concat(
+                    [
+                        new_mem, paddle.unsqueeze(
+                            self._cache_mem(enc_input, memories[0]), axis=0)
+                    ],
+                    axis=0)
+            sign = memories.shape[0]
+            if sign > 1:
+                axis = [0]
+                start = [1]
+                end = [memories.shape[0]]
+                memories = paddle.slice(
+                    memories, axes=axis, starts=start, ends=end)
+            else:
+                memories = None
+        return enc_input, new_mem
+
+
+class ErnieDocPretrainedModel(PretrainedModel):
+    """
+    An abstract class for pretrained ErnieDoc models. It provides ErnieDoc related
+    `model_config_file`, `pretrained_init_configuration`, `resource_files_names`,
+    `pretrained_resource_files_map`, `base_model_prefix` for downloading
+    and loading pretrained models.
+    See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
+ """ + model_config_file = "model_config.json" + pretrained_init_configuration = { + "ernie-doc-base-en": { + "attention_dropout_prob": 0.0, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "relu_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "task_type_vocab_size": 3, + "vocab_size": 50265, + "memory_len": 128, + "epsilon": 1e-12, + "pad_token_id": 1 + }, + "ernie-doc-base-zh": { + "attention_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "relu_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "task_type_vocab_size": 3, + "vocab_size": 28000, + "memory_len": 128, + "epsilon": 1e-12, + "pad_token_id": 0 + } + } + resource_files_names = {"model_state": "model_state.pdparams"} + pretrained_resource_files_map = { + "model_state": { + "ernie-doc-base-en": + "https://bj.bcebos.com/paddlenlp/models/transformers/ernie-doc-base-en/ernie-doc-base-en.pdparams", + "ernie-doc-base-zh": + "https://bj.bcebos.com/paddlenlp/models/transformers/ernie-doc-base-zh/ernie-doc-base-zh.pdparams", + } + } + base_model_prefix = "ernie_doc" + + def init_weights(self, layer): + # Initialization hook + if isinstance(layer, (nn.Linear, nn.Embedding)): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.initializer_range + if hasattr(self, "initializer_range") else + self.ernie_doc.config["initializer_range"], + shape=layer.weight.shape)) + + +class ErnieDocEmbeddings(nn.Layer): + def __init__(self, + vocab_size, + d_model, + hidden_dropout_prob, + memory_len, + max_position_embeddings=512, + type_vocab_size=3, + padding_idx=0): + super(ErnieDocEmbeddings, self).__init__() + self.word_emb = nn.Embedding(vocab_size, d_model) + self.pos_emb = nn.Embedding(max_position_embeddings * 2 + memory_len, + d_model) + self.token_type_emb = nn.Embedding(type_vocab_size, d_model) + self.memory_len = memory_len + self.dropouts = nn.LayerList( + [nn.Dropout(hidden_dropout_prob) for i in range(3)]) + self.norms = nn.LayerList([nn.LayerNorm(d_model) for i in range(3)]) + + def forward(self, input_ids, token_type_ids, position_ids): + # input_embeddings: [B, T, H] + input_embeddings = self.word_emb(input_ids.squeeze(-1)) + # position_embeddings: [B, 2 * T + M, H] + position_embeddings = self.pos_emb(position_ids.squeeze(-1)) + batch_size = input_ids.shape[0] + token_type_ids = paddle.concat( + [ + paddle.zeros( + shape=[batch_size, self.memory_len, 1], dtype="int64") + + token_type_ids[0, 0, 0], token_type_ids + ], + axis=1) + token_type_ids.stop_gradient = True + # token_type_embeddings: [B, M + T, H] + token_type_embeddings = self.token_type_emb(token_type_ids.squeeze(-1)) + embs = [input_embeddings, position_embeddings, token_type_embeddings] + for i in range(len(embs)): + embs[i] = self.dropouts[i](self.norms[i](embs[i])) + return embs + + +class ErnieDocPooler(nn.Layer): + """ + get pool output + """ + + def __init__(self, hidden_size, cls_token_idx=-1): + super(ErnieDocPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + self.cls_token_idx = cls_token_idx + + def forward(self, hidden_states): + # We "pool" the model 
by simply taking the hidden state corresponding
+        # to the last token.
+        cls_token_tensor = hidden_states[:, self.cls_token_idx]
+        pooled_output = self.dense(cls_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+@register_base_model
+class ErnieDocModel(ErnieDocPretrainedModel):
+    """
+    The bare ERNIE-Doc Model outputting raw hidden-states.
+
+    This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
+    Refer to the superclass documentation for the generic methods.
+
+    This model is also a `paddle.nn.Layer` subclass. Use it as a regular Paddle Layer
+    and refer to the Paddle documentation for all matters related to general usage and behavior.
+
+    Args:
+        num_hidden_layers (int):
+            The number of hidden layers in the Transformer encoder.
+        num_attention_heads (int):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        hidden_size (int):
+            Dimensionality of the embedding layers, encoder layers and pooler layer.
+        hidden_dropout_prob (float):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_dropout_prob (float):
+            The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention targets.
+        relu_dropout (float):
+            The dropout probability of FFN.
+        hidden_act (str):
+            The non-linear activation function of FFN.
+        memory_len (int):
+            The number of tokens to cache. If not 0, the last `memory_len` hidden states
+            in each layer will be cached into memory.
+        vocab_size (int):
+            Vocabulary size of `inputs_ids` in `ErnieDocModel`. Also is the vocab size of token embedding matrix.
+            Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `ErnieDocModel`.
+        max_position_embeddings (int):
+            The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input
+            sequence. Defaults to `512`.
+        task_type_vocab_size (int, optional):
+            The vocabulary size of the `token_type_ids`. Defaults to `3`.
+        normalize_before (bool, optional):
+            Indicates whether to apply layer normalization as pre-processing of the MHA and FFN sub-layers.
+            If True, pre-processing is layer normalization and post-processing includes dropout and
+            residual connection. Otherwise, there is no pre-processing and post-processing includes dropout,
+            residual connection and layer normalization. Defaults to `False`.
+        epsilon (float, optional):
+            The `epsilon` parameter used in :class:`paddle.nn.LayerNorm` for
+            initializing layer normalization layers. Defaults to `1e-5`.
+        rel_pos_params_sharing (bool, optional):
+            Whether to share the relative position parameters.
+            Defaults to `False`.
+        initializer_range (float, optional):
+            The standard deviation of the normal initializer for initializing all weight matrices.
+            Defaults to `0.02`.
+        pad_token_id (int, optional):
+            The token id of [PAD] token whose parameters won't be updated when training.
+            Defaults to `0`.
+        cls_token_idx (int, optional):
+            The index of the [CLS] token in the input sequence. Defaults to `-1`.
+
+    """
+
+    def __init__(self,
+                 num_hidden_layers,
+                 num_attention_heads,
+                 hidden_size,
+                 hidden_dropout_prob,
+                 attention_dropout_prob,
+                 relu_dropout,
+                 hidden_act,
+                 memory_len,
+                 vocab_size,
+                 max_position_embeddings,
+                 task_type_vocab_size=3,
+                 normalize_before=False,
+                 epsilon=1e-5,
+                 rel_pos_params_sharing=False,
+                 initializer_range=0.02,
+                 pad_token_id=0,
+                 cls_token_idx=-1):
+        super(ErnieDocModel, self).__init__()
+
+        # Compute the per-head dimensions before they are used to create the
+        # shared relative position parameters below.
+        d_key = hidden_size // num_attention_heads
+        d_value = hidden_size // num_attention_heads
+        d_inner_hid = hidden_size * 4
+        r_w_bias, r_r_bias, r_t_bias = None, None, None
+        if rel_pos_params_sharing:
+            r_w_bias, r_r_bias, r_t_bias = \
+                list(map(lambda x: self.create_parameter(
+                    shape=[num_attention_heads * d_key], dtype="float32"),
+                    ["r_w_bias", "r_r_bias", "r_t_bias"]))
+        encoder_layer = ErnieDocEncoderLayer(
+            num_attention_heads,
+            d_key,
+            d_value,
+            hidden_size,
+            d_inner_hid,
+            hidden_dropout_prob,
+            attention_dropout_prob,
+            relu_dropout,
+            hidden_act,
+            normalize_before=normalize_before,
+            epsilon=epsilon,
+            rel_pos_params_sharing=rel_pos_params_sharing,
+            r_w_bias=r_w_bias,
+            r_r_bias=r_r_bias,
+            r_t_bias=r_t_bias)
+        self.n_head = num_attention_heads
+        self.d_model = hidden_size
+        self.memory_len = memory_len
+        self.encoder = ErnieDocEncoder(num_hidden_layers, encoder_layer,
+                                       memory_len)
+        self.pad_token_id = pad_token_id
+        self.embeddings = ErnieDocEmbeddings(
+            vocab_size, hidden_size, hidden_dropout_prob, memory_len,
+            max_position_embeddings, task_type_vocab_size, pad_token_id)
+        self.pooler = ErnieDocPooler(hidden_size, cls_token_idx)
+
+    def _create_n_head_attn_mask(self, attn_mask, batch_size):
+        # attn_mask shape: [B, T, 1]
+        # concat a data_mask, shape: [B, M + T, 1]
+        data_mask = paddle.concat(
+            [
+                paddle.ones(
+                    shape=[batch_size, self.memory_len, 1],
+                    dtype=attn_mask.dtype), attn_mask
+            ],
+            axis=1)
+        data_mask.stop_gradient = True
+        # create a self_attn_mask, shape: [B, T, M + T]
+        self_attn_mask = paddle.matmul(attn_mask, data_mask, transpose_y=True)
+        self_attn_mask = (self_attn_mask - 1) * 1e8
+        n_head_self_attn_mask = paddle.stack(
+            [self_attn_mask] * self.n_head, axis=1)
+        n_head_self_attn_mask.stop_gradient = True
+        return n_head_self_attn_mask
+
+    def forward(self, input_ids, memories, token_type_ids, position_ids,
+                attn_mask):
+        r"""
+        The ErnieDocModel forward method, overrides the `__call__()` special method.
+
+        Args:
+            input_ids (Tensor):
+                Indices of input sequence tokens in the vocabulary. They are
+                numerical representations of tokens that build the input sequence.
+                Its data type should be `int64` and it has a shape of [batch_size, sequence_length, 1].
+            memories (Tensor):
+                Pre-computed hidden-states for each layer.
+                Its data type should be `float32` and it has a shape of [num_hidden_layers, batch_size, memory_len, hidden_size].
+            token_type_ids (Tensor):
+                Segment token indices to indicate first and second portions of the inputs.
+                Indices can be either 0 or 1:
+
+                - 0 corresponds to a **sentence A** token,
+                - 1 corresponds to a **sentence B** token.
+
+                Its data type should be `int64` and it has a shape of [batch_size, sequence_length, 1].
+                Defaults to None, which means no segment embeddings are added to token embeddings.
+            position_ids (Tensor):
+                Indices of positions of each input sequence token in the position embeddings.
+                Its data type should be `int64` and it has a shape of [batch_size, 2 * sequence_length + memory_length, 1].
+ attn_mask (Tensor): + Mask used in multi-head attention to avoid performing attention on to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. + When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values. + It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. + For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], + [batch_size, num_attention_heads, sequence_length, sequence_length]. + We use whole-word-mask in ERNIE, so the whole word will have the same value. For example, "使用" as a word, + "使" and "用" will have the same value. + Defaults to `None`, which means nothing needed to be prevented attention to. + + Returns: + tuple : Returns tuple (``encoder_output``, ``pooled_output``, ``new_mem``). + + With the fields: + + - `encoder_output` (Tensor): + Sequence of hidden-states at the last layer of the model. + It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size]. + + - `pooled_output` (Tensor): + The output of first token (`[CLS]`) in sequence. + We "pool" the model by simply taking the hidden state corresponding to the first token. + Its data type should be float32 and its shape is [batch_size, hidden_size]. + + - `new_mem` (List[Tensor]): + A list of pre-computed hidden-states. The length of the list is `n_layers`. + Each element in the list is a Tensor with dtype `float32` and shape as [batch_size, memory_length, hidden_size]. + + Example: + .. 
code-block:: + + import numpy as np + import paddle + from paddlenlp.transformers import ErnieDocModel + from paddlenlp.transformers import ErnieDocTokenizer + + def get_related_pos(insts, seq_len, memory_len=128): + beg = seq_len + seq_len + memory_len + r_position = [list(range(beg - 1, seq_len - 1, -1)) + \ + list(range(0, seq_len)) for i in range(len(insts))] + return np.array(r_position).astype('int64').reshape([len(insts), beg, 1]) + + tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh') + model = ErnieDocModel.from_pretrained('ernie-doc-base-zh') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()} + + memories = paddle.zeros([12, 1, 128, 768], dtype="float32") + position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128)) + attn_mask = paddle.ones([1, 128, 1]) + + inputs['memories'] = memories + inputs['position_ids'] = position_ids + inputs['attn_mask'] = attn_mask + + outputs = model(**inputs) + + encoder_output = outputs[0] + pooled_output = outputs[1] + new_mem = outputs[2] + + """ + input_embeddings, position_embeddings, token_embeddings = \ + self.embeddings(input_ids, token_type_ids, position_ids) + + batch_size = input_embeddings.shape[0] + # [B, N, T, M + T] + n_head_self_attn_mask = self._create_n_head_attn_mask(attn_mask, + batch_size) + # memories contain n_layer memory whose shape is [B, M, H] + encoder_output, new_mem = self.encoder( + enc_input=input_embeddings, + memories=memories, + rel_pos=position_embeddings, + rel_task=token_embeddings, + attn_mask=n_head_self_attn_mask) + pooled_output = self.pooler(encoder_output) + return encoder_output, pooled_output, new_mem + + +class ErnieDocForSequenceClassification(ErnieDocPretrainedModel): + """ + ErnieDoc Model with a linear layer on top of the output layer, + designed for sequence classification/regression tasks like GLUE tasks. + + Args: + ernie_doc (:class:`ErnieDocModel`): + An instance of :class:`ErnieDocModel`. + num_classes (int): + The number of classes. + dropout (float, optional) + The dropout ratio of last output. Default to `0.1`. + """ + + def __init__(self, ernie_doc, num_classes, dropout=0.1): + super(ErnieDocForSequenceClassification, self).__init__() + self.ernie_doc = ernie_doc + self.linear = nn.Linear(self.ernie_doc.config["hidden_size"], + num_classes) + self.dropout = nn.Dropout(dropout, mode="upscale_in_train") + self.apply(self.init_weights) + + def forward(self, input_ids, memories, token_type_ids, position_ids, + attn_mask): + r""" + The ErnieDocForSequenceClassification forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + See :class:`ErnieDocModel`. + memories (Tensor): + See :class:`ErnieDocModel`. + token_type_ids (Tensor): + See :class:`ErnieDocModel`. + position_ids (Tensor): + See :class:`ErnieDocModel`. + attn_mask (Tensor): + See :class:`ErnieDocModel`. + + Returns: + tuple : Returns tuple (`logits`, `mem`). + + With the fields: + + - `logits` (Tensor): + A tensor containing the [CLS] of hidden-states of the model at the output of last layer. + Each Tensor has a data type of `float32` and has a shape of [batch_size, num_classes]. + + - `mem` (List[Tensor]): + A list of pre-computed hidden-states. The length of the list is `n_layers`. + Each element in the list is a Tensor with dtype `float32` and has a shape of + [batch_size, memory_length, hidden_size]. + + Example: + .. 
code-block:: + + import numpy as np + import paddle + from paddlenlp.transformers import ErnieDocForSequenceClassification + from paddlenlp.transformers import ErnieDocTokenizer + + def get_related_pos(insts, seq_len, memory_len=128): + beg = seq_len + seq_len + memory_len + r_position = [list(range(beg - 1, seq_len - 1, -1)) + \ + list(range(0, seq_len)) for i in range(len(insts))] + return np.array(r_position).astype('int64').reshape([len(insts), beg, 1]) + + tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh') + model = ErnieDocForSequenceClassification.from_pretrained('ernie-doc-base-zh', num_classes=2) + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()} + + memories = paddle.zeros([12, 1, 128, 768], dtype="float32") + position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128)) + attn_mask = paddle.ones([1, 128, 1]) + + inputs['memories'] = memories + inputs['position_ids'] = position_ids + inputs['attn_mask'] = attn_mask + + outputs = model(**inputs) + + logits = outputs[0] + mem = outputs[1] + + """ + _, pooled_output, mem = self.ernie_doc( + input_ids, memories, token_type_ids, position_ids, attn_mask) + pooled_output = self.dropout(pooled_output) + logits = self.linear(pooled_output) + return logits, mem + + +class ErnieDocForTokenClassification(ErnieDocPretrainedModel): + """ + ErnieDoc Model with a linear layer on top of the hidden-states output layer, + designed for token classification tasks like NER tasks. + + Args: + ernie_doc (:class:`ErnieDocModel`): + An instance of :class:`ErnieDocModel`. + num_classes (int): + The number of classes. + dropout (float, optional) + The dropout ratio of last output. Default to 0.1. + """ + + def __init__(self, ernie_doc, num_classes, dropout=0.1): + super(ErnieDocForTokenClassification, self).__init__() + self.num_classes = num_classes + self.ernie_doc = ernie_doc # allow ernie_doc to be config + self.dropout = nn.Dropout(dropout, mode="upscale_in_train") + self.linear = nn.Linear(self.ernie_doc.config["hidden_size"], + num_classes) + self.apply(self.init_weights) + + def forward(self, input_ids, memories, token_type_ids, position_ids, + attn_mask): + r""" + The ErnieDocForTokenClassification forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + See :class:`ErnieDocModel`. + memories (Tensor): + See :class:`ErnieDocModel`. + token_type_ids (Tensor): + See :class:`ErnieDocModel`. + Defaults to None, which means no segment embeddings is added to token embeddings. + position_ids (Tensor): + See :class:`ErnieDocModel`. + attn_mask (Tensor): + See :class:`ErnieDocModel`. + + Returns: + tuple : Returns tuple (`logits`, `mem`). + + With the fields: + + - `logits` (Tensor): + A tensor containing the hidden-states of the model at the output of last layer. + Each Tensor has a data type of `float32` and has a shape of [batch_size, sequence_length, num_classes]. + + - `mem` (List[Tensor]): + A list of pre-computed hidden-states. The length of the list is `n_layers`. + Each element in the list is a Tensor with dtype `float32` and has a shape of + [batch_size, memory_length, hidden_size]. + + Example: + .. 
code-block:: + + import numpy as np + import paddle + from paddlenlp.transformers import ErnieDocForTokenClassification + from paddlenlp.transformers import ErnieDocTokenizer + + def get_related_pos(insts, seq_len, memory_len=128): + beg = seq_len + seq_len + memory_len + r_position = [list(range(beg - 1, seq_len - 1, -1)) + \ + list(range(0, seq_len)) for i in range(len(insts))] + return np.array(r_position).astype('int64').reshape([len(insts), beg, 1]) + + tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh') + model = ErnieDocForTokenClassification.from_pretrained('ernie-doc-base-zh', num_classes=2) + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()} + + memories = paddle.zeros([12, 1, 128, 768], dtype="float32") + position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128)) + attn_mask = paddle.ones([1, 128, 1]) + + inputs['memories'] = memories + inputs['position_ids'] = position_ids + inputs['attn_mask'] = attn_mask + + outputs = model(**inputs) + + logits = outputs[0] + mem = outputs[1] + + """ + sequence_output, _, mem = self.ernie_doc( + input_ids, memories, token_type_ids, position_ids, attn_mask) + sequence_output = self.dropout(sequence_output) + logits = self.linear(sequence_output) + return logits, mem + + +class ErnieDocForQuestionAnswering(ErnieDocPretrainedModel): + """ + ErnieDoc Model with a linear layer on top of the hidden-states + output to compute `span_start_logits` and `span_end_logits`, + designed for question-answering tasks like SQuAD. + + Args: + ernie_doc (:class:`ErnieDocModel`): + An instance of :class:`ErnieDocModel`. + dropout (float, optional) + The dropout ratio of last output. Default to 0.1. + """ + + def __init__(self, ernie_doc, dropout=0.1): + super(ErnieDocForQuestionAnswering, self).__init__() + self.ernie_doc = ernie_doc # allow ernie_doc to be config + self.dropout = nn.Dropout(dropout, mode="upscale_in_train") + self.linear = nn.Linear(self.ernie_doc.config["hidden_size"], 2) + self.apply(self.init_weights) + + def forward(self, input_ids, memories, token_type_ids, position_ids, + attn_mask): + r""" + The ErnieDocForQuestionAnswering forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + See :class:`ErnieDocModel`. + memories (Tensor): + See :class:`ErnieDocModel`. + token_type_ids (Tensor): + See :class:`ErnieDocModel`. + position_ids (Tensor): + See :class:`ErnieDocModel`. + attn_mask (Tensor): + See :class:`ErnieDocModel`. + + Returns: + tuple : Returns tuple (`start_logits`, `end_logits`, `mem`). + + With the fields: + + - `start_logits` (Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `end_logits` (Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `mem` (List[Tensor]): + A list of pre-computed hidden-states. The length of the list is `n_layers`. + Each element in the list is a Tensor with dtype `float32` and has a shape of + [batch_size, memory_length, hidden_size]. + + Example: + .. 
code-block:: + + import numpy as np + import paddle + from paddlenlp.transformers import ErnieDocForQuestionAnswering + from paddlenlp.transformers import ErnieDocTokenizer + + def get_related_pos(insts, seq_len, memory_len=128): + beg = seq_len + seq_len + memory_len + r_position = [list(range(beg - 1, seq_len - 1, -1)) + \ + list(range(0, seq_len)) for i in range(len(insts))] + return np.array(r_position).astype('int64').reshape([len(insts), beg, 1]) + + tokenizer = ErnieDocTokenizer.from_pretrained('ernie-doc-base-zh') + model = ErnieDocForQuestionAnswering.from_pretrained('ernie-doc-base-zh') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v + [0] * (128-len(v))]).unsqueeze(-1) for (k, v) in inputs.items()} + + memories = paddle.zeros([12, 1, 128, 768], dtype="float32") + position_ids = paddle.to_tensor(get_related_pos(inputs['input_ids'], 128, 128)) + attn_mask = paddle.ones([1, 128, 1]) + + inputs['memories'] = memories + inputs['position_ids'] = position_ids + inputs['attn_mask'] = attn_mask + + outputs = model(**inputs) + + start_logits = outputs[0] + end_logits = outputs[1] + mem = outputs[2] + + """ + sequence_output, _, mem = self.ernie_doc( + input_ids, memories, token_type_ids, position_ids, attn_mask) + sequence_output = self.dropout(sequence_output) + logits = self.linear(sequence_output) + start_logits, end_logits = paddle.transpose(logits, perm=[2, 0, 1]) + return start_logits, end_logits, mem diff --git a/examples/text_classification/ernie_doc/predict.py b/examples/text_classification/ernie_doc/predict.py new file mode 100644 index 000000000000..c32cf5ae5a02 --- /dev/null +++ b/examples/text_classification/ernie_doc/predict.py @@ -0,0 +1,301 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
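+
+# Prediction script for the ERNIE-Doc long-text classifier. It either runs the
+# fine-tuned model directly in dygraph mode or converts it to a static graph
+# with paddle.jit and runs it through paddle.inference; in both cases the
+# predicted labels for the test set are written to a JSON file.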
+ +import os +import paddle +import numpy as np +from paddlenlp.utils.env import PPNLP_HOME +from paddlenlp.utils.log import logger +from paddlenlp.taskflow.utils import dygraph_mode_guard +from modeling import ErnieDocForSequenceClassification +from paddlenlp.transformers import ErnieDocTokenizer, ErnieDocBPETokenizer +from paddlenlp.datasets import load_dataset +from data import ClassifierIterator, ImdbTextPreprocessor, HYPTextPreprocessor, to_json_file +import paddle.nn as nn +from train import init_memory +from functools import partial +import argparse + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", default=16, type=int, + help="Batch size per GPU/CPU for predicting (In static mode, it should be the same as in model training process.)") +parser.add_argument("--model_name_or_path", type=str, default="ernie-doc-base-zh", + help="Pretraining or finetuned model name or path") +parser.add_argument("--max_seq_length", type=int, default=512, + help="The maximum total input sequence length after SentencePiece tokenization.") +parser.add_argument("--memory_length", type=int, default=128, help="Length of the retained previous heads.") +parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"], + help="Select cpu, gpu devices to train model.") +parser.add_argument("--test_results_file", default="./test_restuls.json", type=str, + help="The file path you would like to save the model ouputs on test dataset.") +parser.add_argument("--static_mode", default=False, type=bool, + help="Whether you would like to perform predicting by static model or dynamic model.") +parser.add_argument("--dataset", default="iflytek", choices=["imdb", "iflytek", "thucnews", "hyp"], type=str, + help="The training dataset") +parser.add_argument("--static_path", default=None, type=str, + help="The path which your static model is at or where you want to save after converting.") + +args = parser.parse_args() +# yapf: enable + +DATASET_INFO = { + "imdb": (ErnieDocBPETokenizer, "test", ImdbTextPreprocessor()), + "hyp": (ErnieDocBPETokenizer, "test", HYPTextPreprocessor()), + "iflytek": (ErnieDocTokenizer, "test", None), + "thucnews": (ErnieDocTokenizer, "test", None) +} + + +def predict(model, + test_dataloader, + file_path, + memories, + label_list, + static_mode, + input_handles=None, + output_handles=None): + label_dict = dict() + if not static_mode: + model.eval() + for _, batch in enumerate(test_dataloader, start=1): + input_ids, position_ids, token_type_ids, attn_mask, _, qids, \ + gather_idxs, need_cal_loss = batch + logits, memories = model(input_ids, memories, token_type_ids, + position_ids, attn_mask) + logits, qids = list( + map(lambda x: paddle.gather(x, gather_idxs), [logits, qids])) + probs = nn.functional.softmax(logits, axis=1) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_list[i] for i in idx] + for i, qid in enumerate(qids.numpy().flatten()): + label_dict[str(qid)] = labels[i] + else: + for _, batch in enumerate(test_dataloader, start=1): + input_ids, position_ids, token_type_ids, attn_mask, _, qids, \ + gather_idxs, need_cal_loss = batch + input_handles[0].copy_from_cpu(input_ids.numpy()) + input_handles[1].copy_from_cpu(memories) + input_handles[2].copy_from_cpu(token_type_ids.numpy()) + input_handles[3].copy_from_cpu(position_ids.numpy()) + input_handles[4].copy_from_cpu(attn_mask.numpy()) + model.run() + logits = paddle.to_tensor(output_handles[0].copy_to_cpu()) + memories = 
paddle.to_tensor(output_handles[1].copy_to_cpu()) + logits, qids = list( + map(lambda x: paddle.gather(x, gather_idxs), [logits, qids])) + probs = nn.functional.softmax(logits, axis=1) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_list[i] for i in idx] + for i, qid in enumerate(qids.numpy().flatten()): + label_dict[str(qid)] = labels[i] + to_json_file("iflytek", label_dict, file_path) + + +class LongDocClassifier: + def __init__(self, + model_name_or_path, + trainer_num=1, + rank=0, + batch_size=16, + max_seq_length=512, + memory_len=128, + static_mode=False, + dataset="iflytek", + static_path=None): + self.model_name_or_path = model_name_or_path + self.batch_size = batch_size + self.trainer_num = trainer_num + self.rank = rank + self.max_seq_length = max_seq_length + self.memory_len = memory_len + self.static_mode = static_mode + self.static_path = static_path if static_path else PPNLP_HOME + + tokenizer_class, test_name, preprocess_text_fn = DATASET_INFO[dataset] + self._construct_tokenizer(tokenizer_class) + self._input_preparation(args.dataset, test_name, preprocess_text_fn) + self._construct_model() + if static_mode: + logger.info("Loading the static model from {}".format( + self.static_path)) + self._load_static_model() + + def _input_preparation(self, + dataset="iflytek", + test_name="test", + preprocess_text_fn=None): + test_ds = load_dataset("clue", name=dataset, splits=[test_name]) + self.label_list = test_ds.label_list + self.num_classes = len(test_ds.label_list) + self.test_ds_iter = ClassifierIterator( + test_ds, + self.batch_size, + self._tokenizer, + self.trainer_num, + trainer_id=self.rank, + memory_len=self.memory_len, + max_seq_length=self.max_seq_length, + mode="eval", + preprocess_text_fn=preprocess_text_fn) + self.test_dataloader = paddle.io.DataLoader.from_generator( + capacity=70, return_list=True) + self.test_dataloader.set_batch_generator(self.test_ds_iter, + paddle.get_device()) + + def _construct_tokenizer(self, tokenizer_class): + """ + Construct the tokenizer for the predictor. + :return: + """ + tokenizer_instance = tokenizer_class.from_pretrained( + self.model_name_or_path) + self._tokenizer = tokenizer_instance + + def _construct_model(self): + """ + Construct the inference model for the predictor + :param model_name_or_path: str + :return: model instance + """ + model_instance = ErnieDocForSequenceClassification.from_pretrained( + self.model_name_or_path, num_classes=self.num_classes) + self.model_config = model_instance.ernie_doc.config + self._model = model_instance + + def _load_static_model(self, params_path=None): + """Load static model""" + inference_model_path = os.path.join(self.static_path, "static", + "inference") + with dygraph_mode_guard(): + self._construct_model() + if params_path: + state_dict = paddle.load(params_path) + self._model.set_dict(state_dict) + self._construct_input_spec() + self._convert_dygraph_to_static() + + model_file = inference_model_path + ".pdmodel" + params_file = inference_model_path + ".pdiparams" + self._config = paddle.inference.Config(model_file, params_file) + + def _prepare_static_mode(self): + """ + Construct the input data and predictor in the PaddlePaddele static mode. 
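+
+        The resulting `self.input_handles` are assumed to follow the same order
+        as the InputSpec list built in `_construct_input_spec`: input_ids,
+        memories, token_type_ids, position_ids and attn_mask.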
+ """ + place = paddle.get_device() + if place == 'cpu': + self._config.disable_gpu() + else: + self._config.enable_use_gpu(100) + self._config.switch_use_feed_fetch_ops(False) + self._config.disable_glog_info() + self.predictor = paddle.inference.create_predictor(self._config) + self.input_handles = [ + self.predictor.get_input_handle(name) + for name in self.predictor.get_input_names() + ] + self.output_handle = [ + self.predictor.get_output_handle(name) + for name in self.predictor.get_output_names() + ] + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + B, T, H, M, N = self.batch_size, self.max_seq_length, self.model_config["hidden_size"], self.memory_len, \ + self.model_config["num_hidden_layers"] + self._input_spec = [ + paddle.static.InputSpec( + shape=[B, T, 1], dtype="int64", name="input_ids"), # input_ids + paddle.static.InputSpec( + shape=[N, B, M, H], dtype="float32", + name="memories"), # memories + paddle.static.InputSpec( + shape=[B, T, 1], dtype="int64", + name="token_type_ids"), # token_type_ids + paddle.static.InputSpec( + shape=[B, 2 * T + M, 1], dtype="int64", + name="position_ids"), # position_ids + paddle.static.InputSpec( + shape=[B, T, 1], dtype="float32", + name="attn_mask"), # attn_mask + ] + + def _convert_dygraph_to_static(self): + """ + Convert the dygraph model to static model. + """ + assert self._model is not None, 'The dygraph model must be created before converting the dygraph model to static model.' + assert self._input_spec is not None, 'The input spec must be created before converting the dygraph model to static model.' + logger.info("Converting to the inference model cost a little time.") + static_model = paddle.jit.to_static( + self._model, input_spec=self._input_spec) + save_path = os.path.join(self.static_path, "static", "inference") + paddle.jit.save(static_model, save_path) + logger.info("The inference model save in the path:{}".format(save_path)) + + def run_model(self, saved_path): + if not self.static_mode: + create_memory = partial(init_memory, self.batch_size, + self.memory_len, + self.model_config["hidden_size"], + self.model_config["num_hidden_layers"]) + # Copy the memory + memories = create_memory() + else: + memories = np.zeros( + [ + self.model_config["num_hidden_layers"], self.batch_size, + self.memory_len, self.model_config["hidden_size"] + ], + dtype="float32") + file_path = saved_path + if not self.static_mode: + self.input_handles, self.output_handle, self.predictor = None, None, self._model + else: + self._prepare_static_mode() + predict(self.predictor, self.test_dataloader, file_path, memories, + self.label_list, self.static_mode, self.input_handles, + self.output_handle) + + +def do_predict(args): + # Initialize model + paddle.set_device(args.device) + trainer_num = paddle.distributed.get_world_size() + if trainer_num > 1: + paddle.distributed.init_parallel_env() + rank = paddle.distributed.get_rank() + if rank == 0: + if os.path.exists(args.model_name_or_path): + logger.info("init checkpoint from %s" % args.model_name_or_path) + + predictor = LongDocClassifier( + model_name_or_path=args.model_name_or_path, + rank=rank, + trainer_num=trainer_num, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + memory_len=args.memory_length, + static_mode=args.static_mode, + static_path=args.static_path) + predictor.run_model(saved_path=args.test_results_file) + + +if __name__ == "__main__": + do_predict(args) diff --git 
a/examples/text_classification/ernie_doc/train.py b/examples/text_classification/ernie_doc/train.py new file mode 100644 index 000000000000..3f27dff4bb86 --- /dev/null +++ b/examples/text_classification/ernie_doc/train.py @@ -0,0 +1,346 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from collections import defaultdict +import os +import random +from functools import partial +import time +import numpy as np +import paddle +import paddle.nn as nn +from paddle.metric import Accuracy +from modeling import ErnieDocForSequenceClassification +from paddlenlp.transformers import ErnieDocTokenizer, ErnieDocBPETokenizer +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.utils.log import logger +from paddlenlp.datasets import load_dataset +from paddlenlp.ops.optimizer import AdamWDL +from data import ClassifierIterator, ImdbTextPreprocessor, HYPTextPreprocessor, to_json_file +from metrics import F1 + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.") +parser.add_argument("--model_name_or_path", type=str, default="ernie-doc-base-zh", + help="Pretraining model name or path") +parser.add_argument("--max_seq_length", type=int, default=512, + help="The maximum total input sequence length after SentencePiece tokenization.") +parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate used to train.") +parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint every X updates steps.") +parser.add_argument("--logging_steps", type=int, default=1, help="Log every X updates steps.") +parser.add_argument("--output_dir", type=str, default='checkpoints/', help="Directory to save model checkpoint") +parser.add_argument("--epochs", type=int, default=3, help="Number of epoches for training.") +parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"], + help="Select cpu, gpu devices to train model.") +parser.add_argument("--seed", type=int, default=1, help="Random seed for initialization.") +parser.add_argument("--memory_length", type=int, default=128, help="Length of the retained previous heads.") +parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.") +parser.add_argument("--warmup_proportion", default=0.1, type=float, + help="Linear warmup proption over the training process.") +parser.add_argument("--dataset", default="imdb", choices=["imdb", "iflytek", "thucnews", "hyp"], type=str, + help="The training dataset") +parser.add_argument("--layerwise_decay", default=1.0, type=float, help="Layerwise decay ratio") +parser.add_argument("--max_steps", default=-1, type=int, + help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", ) +parser.add_argument("--test_results_file", default="./test_restuls.json", type=str, + help="The file path you would like to save the model ouputs on test dataset.") + +args = parser.parse_args() +# yapf: enable + +DATASET_INFO = { + "imdb": + (ErnieDocBPETokenizer, "test", "test", ImdbTextPreprocessor(), Accuracy()), + "hyp": (ErnieDocBPETokenizer, "dev", "test", HYPTextPreprocessor(), F1()), + "iflytek": (ErnieDocTokenizer, "dev", "test", None, Accuracy()), + "thucnews": (ErnieDocTokenizer, "dev", "test", None, Accuracy()) +} + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +def init_memory(batch_size, memory_length, d_model, n_layers): + return paddle.zeros( + [n_layers, batch_size, memory_length, d_model], dtype="float32") + + +@paddle.no_grad() +def evaluate(model, metric, data_loader, memories): + model.eval() + losses = [] + # copy the memory + tic_train = time.time() + eval_logging_step = 500 + + probs_dict = defaultdict(list) + label_dict = dict() + global_steps = 0 + for step, batch in enumerate(data_loader, start=1): + input_ids, position_ids, token_type_ids, attn_mask, labels, qids, \ + gather_idxs, need_cal_loss = batch + logits, memories = model(input_ids, memories, token_type_ids, + position_ids, attn_mask) + logits, labels, qids = list( + map(lambda x: paddle.gather(x, gather_idxs), + [logits, labels, qids])) + # Need to collect probs for each qid, so use softmax_with_cross_entropy + loss, probs = nn.functional.softmax_with_cross_entropy( + logits, labels, return_softmax=True) + losses.append(loss.mean().numpy()) + # Shape: [B, NUM_LABELS] + np_probs = probs.numpy() + # Shape: [B, 1] + np_qids = qids.numpy() + np_labels = labels.numpy().flatten() + for i, qid in enumerate(np_qids.flatten()): + probs_dict[qid].append(np_probs[i]) + label_dict[qid] = np_labels[i] # Same qid share same label. 
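+            # A long example is split into several segments that share one qid;
+            # the per-segment probs collected here are averaged after the loop
+            # so that each qid gets a single prediction.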
+ + if step % eval_logging_step == 0: + logger.info("Step %d: loss: %.5f, speed: %.5f steps/s" % + (step, np.mean(losses), + eval_logging_step / (time.time() - tic_train))) + tic_train = time.time() + + # Collect predicted labels + preds = [] + labels = [] + for qid, probs in probs_dict.items(): + mean_prob = np.mean(np.array(probs), axis=0) + preds.append(mean_prob) + labels.append(label_dict[qid]) + + preds = paddle.to_tensor(np.array(preds, dtype='float32')) + labels = paddle.to_tensor(np.array(labels, dtype='int64')) + + metric.update(metric.compute(preds, labels)) + acc_or_f1 = metric.accumulate() + logger.info("Eval loss: %.5f, %s: %.5f" % + (np.mean(losses), metric.__class__.__name__, acc_or_f1)) + metric.reset() + model.train() + return acc_or_f1 + + +def predict(model, test_dataloader, file_path, memories, label_list): + label_dict = dict() + model.eval() + for _, batch in enumerate(test_dataloader, start=1): + input_ids, position_ids, token_type_ids, attn_mask, _, qids, \ + gather_idxs, need_cal_loss = batch + logits, memories = model(input_ids, memories, token_type_ids, + position_ids, attn_mask) + logits, qids = list( + map(lambda x: paddle.gather(x, gather_idxs), [logits, qids])) + probs = nn.functional.softmax(logits, axis=1) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_list[i] for i in idx] + for i, qid in enumerate(qids.numpy().flatten()): + label_dict[str(qid)] = labels[i] + to_json_file("iflytek", label_dict, file_path) + + +def do_train(args): + set_seed(args) + + tokenizer_class, eval_name, test_name, preprocess_text_fn, eval_metric = DATASET_INFO[ + args.dataset] + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + train_ds, eval_ds, test_ds = load_dataset( + "clue", name=args.dataset, splits=["train", eval_name, test_name]) + num_classes = len(train_ds.label_list) + + paddle.set_device(args.device) + trainer_num = paddle.distributed.get_world_size() + if trainer_num > 1: + paddle.distributed.init_parallel_env() + rank = paddle.distributed.get_rank() + if rank == 0: + if os.path.exists(args.model_name_or_path): + logger.info("init checkpoint from %s" % args.model_name_or_path) + model = ErnieDocForSequenceClassification.from_pretrained( + args.model_name_or_path, num_classes=num_classes) + model_config = model.ernie_doc.config + if trainer_num > 1: + model = paddle.DataParallel(model) + + train_ds_iter = ClassifierIterator( + train_ds, + args.batch_size, + tokenizer, + trainer_num, + trainer_id=rank, + memory_len=model_config["memory_len"], + max_seq_length=args.max_seq_length, + random_seed=args.seed, + preprocess_text_fn=preprocess_text_fn) + eval_ds_iter = ClassifierIterator( + eval_ds, + args.batch_size, + tokenizer, + trainer_num, + trainer_id=rank, + memory_len=model_config["memory_len"], + max_seq_length=args.max_seq_length, + mode="eval", + preprocess_text_fn=preprocess_text_fn) + test_ds_iter = ClassifierIterator( + test_ds, + args.batch_size, + tokenizer, + trainer_num, + trainer_id=rank, + memory_len=model_config["memory_len"], + max_seq_length=args.max_seq_length, + mode="test", + preprocess_text_fn=preprocess_text_fn) + + train_dataloader = paddle.io.DataLoader.from_generator( + capacity=70, return_list=True) + train_dataloader.set_batch_generator(train_ds_iter, paddle.get_device()) + eval_dataloader = paddle.io.DataLoader.from_generator( + capacity=70, return_list=True) + eval_dataloader.set_batch_generator(eval_ds_iter, paddle.get_device()) + test_dataloader = paddle.io.DataLoader.from_generator( 
+ capacity=70, return_list=True) + test_dataloader.set_batch_generator(test_ds_iter, paddle.get_device()) + + num_training_examples = train_ds_iter.get_num_examples() + num_training_steps = args.epochs * num_training_examples // args.batch_size // trainer_num + logger.info("Device count: %d, trainer_id: %d" % (trainer_num, rank)) + logger.info("Num train examples: %d" % num_training_examples) + logger.info("Max train steps: %d" % num_training_steps) + logger.info("Num warmup steps: %d" % int(num_training_steps * + args.warmup_proportion)) + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + args.warmup_proportion) + + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + # Construct dict + name_dict = dict() + for n, p in model.named_parameters(): + name_dict[p.name] = n + + optimizer = AdamWDL( + learning_rate=lr_scheduler, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + n_layers=model_config["num_hidden_layers"], + layerwise_decay=args.layerwise_decay, + name_dict=name_dict) + + criterion = paddle.nn.loss.CrossEntropyLoss() + metric = paddle.metric.Accuracy() + + global_steps = 0 + best_acc = -1 + create_memory = partial(init_memory, args.batch_size, args.memory_length, + model_config["hidden_size"], + model_config["num_hidden_layers"]) + # Copy the memory + memories = create_memory() + tic_train = time.time() + stop_training = False + for epoch in range(args.epochs): + train_ds_iter.shuffle_sample() + train_dataloader.set_batch_generator(train_ds_iter, paddle.get_device()) + for step, batch in enumerate(train_dataloader, start=1): + global_steps += 1 + input_ids, position_ids, token_type_ids, attn_mask, labels, qids, \ + gather_idx, need_cal_loss = batch + logits, memories = model(input_ids, memories, token_type_ids, + position_ids, attn_mask) + + logits, labels = list( + map(lambda x: paddle.gather(x, gather_idx), [logits, labels])) + loss = criterion(logits, labels) * need_cal_loss + mean_loss = loss.mean() + mean_loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + # Rough acc result, not a precise acc + acc = metric.compute(logits, labels) * need_cal_loss + metric.update(acc) + + if global_steps % args.logging_steps == 0: + logger.info( + "train: global step %d, epoch: %d, loss: %f, acc:%f, lr: %f, speed: %.2f step/s" + % (global_steps, epoch, mean_loss, metric.accumulate(), + lr_scheduler.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + + if global_steps % args.save_steps == 0: + # Evaluate + logger.info("Eval:") + eval_acc = evaluate(model, eval_metric, eval_dataloader, + create_memory()) + # Save + if rank == 0: + output_dir = os.path.join(args.output_dir, + "model_%d" % (global_steps)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + if eval_acc > best_acc: + logger.info("Save best model......") + best_acc = eval_acc + best_model_dir = os.path.join(output_dir, "best_model") + if not os.path.exists(best_model_dir): + os.makedirs(best_model_dir) + model_to_save.save_pretrained(best_model_dir) + tokenizer.save_pretrained(best_model_dir) + + if args.max_steps > 0 and global_steps >= args.max_steps: + stop_training = True + break + if stop_training: + break + 
logger.info("Final test result:") + eval_acc = evaluate(model, eval_metric, eval_dataloader, create_memory()) + logger.info("start predict the test data") + + create_memory = partial(init_memory, args.batch_size, args.memory_length, + model_config["hidden_size"], + model_config["num_hidden_layers"]) + # Copy the memory + memories = create_memory() + predict(model, test_dataloader, args.file_path, memories, + test_ds.label_list) + logger.info("Done Predicting the results has been saved in file: {}".format( + args.file_path)) + + +if __name__ == "__main__": + do_train(args)