Add NLP model interpretation #1752

Merged: 32 commits, Mar 25, 2022

Commits (32):
e786867  upload NLP interpretation (binlinquge, Mar 10, 2022)
63afacb  fix problems and relocate project (binlinquge, Mar 14, 2022)
4d715fc  remove abandoned picture (binlinquge, Mar 14, 2022)
ef6d8f6  remove abandoned picture (binlinquge, Mar 14, 2022)
d4e9e68  fix dead link in README (binlinquge, Mar 14, 2022)
04b7cb5  fix dead link in README (binlinquge, Mar 14, 2022)
3b0e081  Merge branch 'PaddlePaddle:develop' into develop (binlinquge, Mar 15, 2022)
b3337c3  fix code style problems (binlinquge, Mar 15, 2022)
4a54362  Merge branch 'develop' of https://github.com/binlinquge/PaddleNLP int… (binlinquge, Mar 15, 2022)
18bf999  Merge branch 'develop' into develop (ZeyuChen, Mar 15, 2022)
d766d94  fix CR round 1 (binlinquge, Mar 16, 2022)
734911b  remove .gitkeep files (binlinquge, Mar 16, 2022)
bd2f2d2  fix code style (binlinquge, Mar 16, 2022)
916cff3  fix file encoding problem (binlinquge, Mar 16, 2022)
5111301  fix code style (binlinquge, Mar 16, 2022)
0145072  delete duplicated files due to directory rebuild (binlinquge, Mar 16, 2022)
6724d26  fix CR round 2 (binlinquge, Mar 18, 2022)
1ed3581  fix code style (binlinquge, Mar 18, 2022)
8493aec  fix ernie tokenizer (Mar 21, 2022)
3555ff9  fix code style (binlinquge, Mar 21, 2022)
1ff7953  fix problem from CR round 1 (binlinquge, Mar 22, 2022)
f3112b9  fix bugs (binlinquge, Mar 23, 2022)
7bf71fc  fix README (binlinquge, Mar 23, 2022)
36357ed  remove duplicated files (binlinquge, Mar 23, 2022)
ee07f9d  deal with diff of old and new tokenizer results (binlinquge, Mar 24, 2022)
b180c59  fix CR round 4 (binlinquge, Mar 24, 2022)
c75353b  fix code style (binlinquge, Mar 24, 2022)
ff71607  add missing dependence (binlinquge, Mar 24, 2022)
3270158  fix broken import path (binlinquge, Mar 25, 2022)
f06d261  move some data file to cloud (binlinquge, Mar 25, 2022)
ef88b5e  MRC upper case to lower case (binlinquge, Mar 25, 2022)
7bd929e  Merge branch 'develop' into develop (guoshengCS, Mar 25, 2022)
Changes from 1 commit: 3555ff9386f2c1368db03b8bff6bbb99046dca42, "fix code style"
binlinquge committed Mar 21, 2022. Verified: created on GitHub.com and signed with GitHub's verified signature (the key has since expired).
@@ -179,7 +179,11 @@ def init_lstm_var(args):
     if args.language == "ch":
         tokenizer = ErnieTokenizer.from_pretrained(args.vocab_path)
         padding_idx = tokenizer.vocab.get('[PAD]')
-        tokenizer.inverse_vocab = [item[0] for item in sorted(tokenizer.vocab.items(), key=lambda x: x[1])]
+        tokenizer.inverse_vocab = [
+            item[0]
+            for item in sorted(
+                tokenizer.vocab.items(), key=lambda x: x[1])
+        ]
     else:
         vocab = Vocab.load_vocabulary(
             args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
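The reformatted comprehension builds an id-to-token lookup: sorting the vocab's (token, id) pairs by id and keeping only the tokens yields a list in which position i holds the token whose id is i. A minimal standalone sketch of that construction, with a toy vocab standing in for ErnieTokenizer's real mapping:

# Toy stand-in for tokenizer.vocab; the real {token: id} dict comes from ErnieTokenizer.
vocab = {'[PAD]': 0, '[UNK]': 1, 'good': 2, 'movie': 3}

# Same construction as the diff: sort (token, id) pairs by id, keep the tokens.
inverse_vocab = [item[0] for item in sorted(vocab.items(), key=lambda x: x[1])]

assert inverse_vocab == ['[PAD]', '[UNK]', 'good', 'movie']
assert inverse_vocab[vocab['good']] == 'good'  # ids map back to tokens by index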
@@ -164,7 +164,7 @@ def init_roberta_var(args):
         collate_fn=batchify_fn,
         return_list=True)
 
-    return model, tokenizer, dataloader
+    return model, tokenizer, dataloader, dev_ds
 
 
 def init_lstm_var(args):
@@ -197,16 +197,17 @@ def init_lstm_var(args):
         Stack(dtype="int64"),  # title_seq_lens
     ): [data for data in fn(samples)]
 
-    return model, tokenizer, batches, batchify_fn, vocab
+    return model, tokenizer, batches, batchify_fn, vocab, dev_ds
 
 
 if __name__ == "__main__":
     args = get_args()
     if args.base_model.startswith('roberta'):
-        model, tokenizer, dataloader = init_roberta_var(args)
+        model, tokenizer, dataloader, dev_ds = init_roberta_var(args)
 
     elif args.base_model == 'lstm':
-        model, tokenizer, dataloader, batchify_fn, vocab = init_lstm_var(args)
+        model, tokenizer, dataloader, batchify_fn, vocab, dev_ds = init_lstm_var(
+            args)
     else:
         raise ValueError('unsupported base model name.')
@@ -255,7 +256,7 @@ def init_lstm_var(args):
         vocab._idx_to_token[idx] for idx in title_ids.tolist()[0]
     ]
 
-    result['id'] = dataloader.dataset.data[step]['id']
+    result['id'] = dev_ds.data[step]['id']
 
     probs, atts, embedded = model.forward_interpret(*fwd_args,
                                                     **fwd_kwargs)
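The substantive change in these hunks is threading the raw dataset (dev_ds) out of the init functions so the interpretation loop can read example ids from dev_ds.data[step] instead of reaching through dataloader.dataset. A hedged sketch of the pattern; TinyDataset and init_var below are hypothetical stand-ins, not PaddleNLP API:

class TinyDataset:
    # Hypothetical stand-in for a dataset that exposes raw examples via .data.
    def __init__(self, data):
        self.data = data


def init_var():
    dev_ds = TinyDataset([{'id': 'dev_0', 'text': 'good movie'},
                          {'id': 'dev_1', 'text': 'weak plot'}])
    # In the real scripts the examples are converted to tensors and wrapped in
    # a loader at this point; the raw ids then survive only through dev_ds
    # itself, which is why the PR now returns it alongside everything else.
    batches = [ex['text'].split() for ex in dev_ds.data]
    return batches, dev_ds


batches, dev_ds = init_var()
for step, batch in enumerate(batches):
    result = {'id': dev_ds.data[step]['id']}  # same access pattern as the diff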
examples/model_interpretation/task/senti/run_inter_all.sh (2 changes: 1 addition & 1 deletion)

@@ -13,7 +13,7 @@ for BASE_MODEL in "lstm" "roberta_base" "roberta_large";
 do
     for INTER_MODE in "attention" "integrated_gradient" "lime";
     do
-        for LANGUAGE in "ch";
+        for LANGUAGE in "ch" "en";
         do
             TASK=senti_${LANGUAGE}
             DATA=../../data/${TASK}
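Adding "en" to the innermost loop doubles the sweep: the script now covers every combination of three base models, three interpretation modes, and two languages, 18 runs instead of 9. A quick Python sketch of the expanded run matrix, mirroring the shell loops:

from itertools import product

base_models = ["lstm", "roberta_base", "roberta_large"]
inter_modes = ["attention", "integrated_gradient", "lime"]
languages = ["ch", "en"]  # "en" is what this diff adds

runs = list(product(base_models, inter_modes, languages))
assert len(runs) == 18  # previously 9 with "ch" alone

for base_model, inter_mode, language in runs:
    task = f"senti_{language}"   # mirrors TASK=senti_${LANGUAGE}
    data = f"../../data/{task}"  # mirrors DATA=../../data/${TASK}
    print(base_model, inter_mode, task, data)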
@@ -171,7 +171,11 @@ def init_lstm_var(args):
     if args.language == "ch":
         tokenizer = ErnieTokenizer.from_pretrained(args.vocab_path)
         padding_idx = tokenizer.vocab.get('[PAD]')
-        tokenizer.inverse_vocab = [item[0] for item in sorted(tokenizer.vocab.items(), key=lambda x: x[1])]
+        tokenizer.inverse_vocab = [
+            item[0]
+            for item in sorted(
+                tokenizer.vocab.items(), key=lambda x: x[1])
+        ]
     else:
         vocab = Vocab.load_vocabulary(
             args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
@@ -179,7 +179,7 @@ def init_roberta_var(args):
         collate_fn=batchify_fn,
         return_list=True)
 
-    return model, tokenizer, dataloader
+    return model, tokenizer, dataloader, dev_ds
 
 
 def init_lstm_var(args):
@@ -207,7 +207,7 @@ def init_lstm_var(args):
         Stack(dtype="int64"),  # title_seq_lens
     ): [data for data in fn(samples)]
 
-    return model, tokenizer, batches, batchify_fn, vocab
+    return model, tokenizer, batches, batchify_fn, vocab, dev_ds
 
 
 def get_seq_token_num(language):
@@ -550,9 +550,10 @@ def LIME_error_evaluation(exp_q, pred_label, probs, lime_score_total,
 if __name__ == "__main__":
     args = get_args()
     if args.base_model.startswith('roberta'):
-        model, tokenizer, dataloader = init_roberta_var(args)
+        model, tokenizer, dataloader, dev_ds = init_roberta_var(args)
     elif args.base_model == 'lstm':
-        model, tokenizer, dataloader, batchify_fn, vocab = init_lstm_var(args)
+        model, tokenizer, dataloader, batchify_fn, vocab, dev_ds = init_lstm_var(
+            args)
     else:
         raise ValueError('unsupported base model name.')
@@ -598,7 +599,7 @@ def LIME_error_evaluation(exp_q, pred_label, probs, lime_score_total,
         batchify_fn=batchify_fn,
         vocab=vocab)
 
-    result['id'] = dataloader.dataset.data[step]['id']
+    result['id'] = dev_ds.data[step]['id']
 
     probs, atts, embedded = model.forward_interpret(*fwd_args,
                                                     **fwd_kwargs)
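For orientation, the loop these hunks touch consumes forward_interpret's outputs (probabilities, attention weights, embeddings) to score token importance. The sketch below only fakes those outputs to show the shape of that consumption; forward_interpret_stub and all shapes are hypothetical, the real signature lives in this PR's model code:

import numpy as np

def forward_interpret_stub(batch_size=1, seq_len=5, num_classes=2):
    # Fabricated outputs with plausible shapes; not the PR's actual model.
    probs = np.random.dirichlet(np.ones(num_classes), size=batch_size)
    atts = np.random.rand(batch_size, seq_len)         # attention over tokens
    embedded = np.random.rand(batch_size, seq_len, 8)  # token embeddings
    return probs, atts, embedded

probs, atts, embedded = forward_interpret_stub()
pred_label = int(probs[0].argmax())
token_importance = atts[0] / atts[0].sum()  # "attention" mode: weights as saliency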