Skip to content

Commit

Permalink
fix bug of label dimension smaller than 1 (#3238)
Browse files Browse the repository at this point in the history
  • Loading branch information
linjieccc authored Sep 9, 2022
1 parent 3e66b0c commit 4ce8fd9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions model_zoo/uie/data_distill/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __call__(

bs = batch[0].shape[0]
if self.task_type == "entity_extraction":
max_ent_num = max([len(lb["ent_labels"]) for lb in labels])
# Ensure the dimension is greater or equal to 1
max_ent_num = max(max([len(lb["ent_labels"]) for lb in labels]), 1)
num_ents = len(self.label_maps["entity2id"])
batch_entity_labels = paddle.zeros(
shape=[bs, num_ents, max_ent_num, 2], dtype="int64")
Expand All @@ -67,8 +68,9 @@ def __call__(

batch.append([batch_entity_labels])
else:
max_ent_num = max([len(lb["ent_labels"]) for lb in labels])
max_spo_num = max([len(lb["rel_labels"]) for lb in labels])
# Ensure the dimension is greater or equal to 1
max_ent_num = max(max([len(lb["ent_labels"]) for lb in labels]), 1)
max_spo_num = max(max([len(lb["rel_labels"]) for lb in labels]), 1)
num_ents = len(self.label_maps["entity2id"])
if "relation2id" in self.label_maps.keys():
num_rels = len(self.label_maps["relation2id"])
Expand Down

0 comments on commit 4ce8fd9

Please sign in to comment.