
Commit 39f8322 ("Add code")
1 parent 1c0062b

12 files changed: +749 −0 lines

README.md

Lines changed: 32 additions & 0 deletions
# DIG Code

Code for "Discretized Integrated Gradients for Explaining Language Models"

<h2 align="center">
Overview of variants of DIG
<img align="center" src="./overview.png" alt="...">
</h2>

**Overview of paths used in DIG and IG.** w is the word being attributed and the gray region is its neighborhood. The green line depicts the straight-line path from w to w' used by IG, and the green squares are the corresponding interpolation points. **Left**: In DIG-Greedy, we first monotonize each word in the neighborhood (red arrows). The word closest to its corresponding monotonic point is then selected as the anchor (the blue line to w_5, since the red arrow of w_5 has the smallest magnitude). **Right**: In DIG-MaxCount, we first count the number of monotonic dimensions of each word in the neighborhood (shown in [.] above), select the word with the highest count as the anchor (the blue line to w_4), and then change the non-monotonic dimensions of w_4 (the red line to c). Repeating this step yields the zigzag blue path. The red stars are the interpolation points used by our method. Please refer to the paper for more details.
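As a rough illustration of the MaxCount criterion, the sketch below paraphrases the figure's description. It is not code from this repository: the monotonicity test and the choice of the corrected point c are our reading of the paper, and all names are hypothetical.

```python
import numpy as np

def monotonic_dims(neighbor, word, baseline):
    # a dimension counts as monotonic if the neighbor's coordinate lies
    # between the current word's coordinate and the baseline's coordinate
    lo, hi = np.minimum(word, baseline), np.maximum(word, baseline)
    return (neighbor >= lo) & (neighbor <= hi)

def maxcount_step(word, baseline, neighbors):
    # pick the neighbor with the most monotonic dimensions as the anchor,
    # then overwrite its non-monotonic dimensions (the "red line to c");
    # the midpoint projection is an illustrative choice, not the paper's
    counts = [monotonic_dims(n, word, baseline).sum() for n in neighbors]
    anchor = neighbors[int(np.argmax(counts))]
    mono = monotonic_dims(anchor, word, baseline)
    return np.where(mono, anchor, (word + baseline) / 2.0)
```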
### Dependencies

- Dependencies can be installed using `requirements.txt`.

### Evaluating DIG:

- Install all the requirements from `requirements.txt`.
- Execute `./setup.sh` to set up the folder hierarchy for the experiments.
- Commands for reproducing the reported results on DistilBERT fine-tuned on SST2:

```shell
# Generate the KNN graph
python knn.py -dataset sst2 -nn distilbert

# DIG (strategy: Greedy)
python main.py -dataset sst2 -nn distilbert -strategy greedy

# DIG (strategy: MaxCount)
python main.py -dataset sst2 -nn distilbert -strategy maxcount
```

The same commands apply to the other models and datasets; change the `-nn` and `-dataset` flags accordingly.

attributions.py

Lines changed: 13 additions & 0 deletions
import torch
from dig import DiscretetizedIntegratedGradients


def summarize_attributions(attributions):
    # collapse the embedding dimension into one score per token, then L2-normalize
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

def run_dig_explanation(dig_func, all_input_embed, position_embed, type_embed, attention_mask, steps):
    # all_input_embed holds the pre-computed interpolation points of the discretized path
    attributions = dig_func.attribute(scaled_features=all_input_embed, additional_forward_args=(attention_mask, position_embed, type_embed), n_steps=steps)
    attributions_word = summarize_attributions(attributions)

    return attributions_word
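A hedged sketch of how these two functions are meant to be driven end to end. The interpolation points must come from the path-construction step in `main.py`, which is not among the files shown here; `build_scaled_features` below is a hypothetical stand-in for it.

```python
import torch
from dig import DiscretetizedIntegratedGradients
from attributions import run_dig_explanation
import distilbert_helper as helper

device = torch.device('cpu')
helper.nn_init(device, 'sst2')
(input_ids, _, input_embed, ref_input_embed, position_embed, _,
 type_embed, _, attention_mask) = helper.get_inputs('a quietly moving film', device)

steps = 30
# hypothetical: returns the stacked interpolation points of the discretized path
all_input_embed = build_scaled_features(input_embed, ref_input_embed, steps)

dig_func = DiscretetizedIntegratedGradients(helper.nn_forward_func)
attrs = run_dig_explanation(dig_func, all_input_embed, position_embed,
                            type_embed, attention_mask, steps)
print(list(zip(helper.get_tokens(input_ids), attrs.tolist())))
```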

bert_helper.py

Lines changed: 121 additions & 0 deletions
import torch, sys, pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification


model, tokenizer = None, None

def nn_init(device, dataset, returns=False):
    global model, tokenizer
    if dataset == 'sst2':
        tokenizer = AutoTokenizer.from_pretrained('textattack/bert-base-uncased-SST-2')
        model = AutoModelForSequenceClassification.from_pretrained('textattack/bert-base-uncased-SST-2', return_dict=False)
    elif dataset == 'imdb':
        tokenizer = AutoTokenizer.from_pretrained('textattack/bert-base-uncased-imdb')
        model = AutoModelForSequenceClassification.from_pretrained('textattack/bert-base-uncased-imdb', return_dict=False)
    elif dataset == 'rotten':
        tokenizer = AutoTokenizer.from_pretrained('textattack/bert-base-uncased-rotten-tomatoes')
        model = AutoModelForSequenceClassification.from_pretrained('textattack/bert-base-uncased-rotten-tomatoes', return_dict=False)

    model.to(device)
    model.eval()
    model.zero_grad()

    if returns:
        return model, tokenizer

def move_to_device(device):
    global model
    model.to(device)

def predict(model, inputs_embeds, attention_mask=None):
    return model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)[0]

def nn_forward_func(input_embed, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    global model
    # recombine the three embedding components, then apply BERT's LayerNorm + dropout
    embeds = input_embed + position_embed + type_embed
    embeds = model.bert.embeddings.dropout(model.bert.embeddings.LayerNorm(embeds))
    pred = predict(model, embeds, attention_mask=attention_mask)
    if return_all_logits:
        return pred
    else:
        return pred.max(1).values

def load_mappings(dataset, knn_nbrs=500):
    with open(f'processed/knns/bert_{dataset}_{knn_nbrs}.pkl', 'rb') as f:
        [word_idx_map, word_features, adj] = pickle.load(f)
    word_idx_map = dict(word_idx_map)

    return word_idx_map, word_features, adj

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
    text_ids = tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=tokenizer.max_len_single_sentence)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]  # construct input token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]  # construct reference token ids

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device)

def construct_input_ref_pos_id_pair(input_ids, device):
    global model
    seq_length = input_ids.size(1)
    position_ids = model.bert.embeddings.position_ids[:, 0:seq_length].to(device)
    ref_position_ids = model.bert.embeddings.position_ids[:, 0:seq_length].to(device)

    return position_ids, ref_position_ids

def construct_input_ref_token_type_pair(input_ids, device):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0] * seq_len], dtype=torch.long, device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, dtype=torch.long, device=device)
    return token_type_ids, ref_token_type_ids

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def get_word_embeddings():
    global model
    return model.bert.embeddings.word_embeddings.weight

def construct_word_embedding(model, input_ids):
    return model.bert.embeddings.word_embeddings(input_ids)

def construct_position_embedding(model, position_ids):
    return model.bert.embeddings.position_embeddings(position_ids)

def construct_type_embedding(model, type_ids):
    return model.bert.embeddings.token_type_embeddings(type_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids):
    input_embeddings = construct_word_embedding(model, input_ids)
    ref_input_embeddings = construct_word_embedding(model, ref_input_ids)
    input_position_embeddings = construct_position_embedding(model, position_ids)
    ref_input_position_embeddings = construct_position_embedding(model, ref_position_ids)
    input_type_embeddings = construct_type_embedding(model, type_ids)
    ref_input_type_embeddings = construct_type_embedding(model, ref_type_ids)

    return (input_embeddings, ref_input_embeddings), \
        (input_position_embeddings, ref_input_position_embeddings), \
        (input_type_embeddings, ref_input_type_embeddings)

def get_base_token_emb(device):
    global model
    # embedding of the [PAD] token, used as the attribution baseline
    return construct_word_embedding(model, torch.tensor([tokenizer.pad_token_id], device=device))

def get_tokens(text_ids):
    global tokenizer
    return tokenizer.convert_ids_to_tokens(text_ids.squeeze())

def get_inputs(text, device):
    global model, tokenizer
    ref_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id

    input_ids, ref_input_ids = construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
    type_ids, ref_type_ids = construct_input_ref_token_type_pair(input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    (input_embed, ref_input_embed), (position_embed, ref_position_embed), (type_embed, ref_type_embed) = \
        construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids, type_ids, ref_type_ids)

    return [input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, type_embed, ref_type_embed, attention_mask]
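A minimal smoke test of the plumbing above, assuming the `textattack` checkpoint downloads cleanly; the sentence is made up. It loads the SST-2 model, rebuilds the embedding components for one sentence, and pushes them through `nn_forward_func` to get logits.

```python
import torch
import bert_helper as helper

device = torch.device('cpu')
helper.nn_init(device, 'sst2')

(input_ids, ref_input_ids, input_embed, ref_input_embed,
 position_embed, ref_position_embed, type_embed, ref_type_embed,
 attention_mask) = helper.get_inputs('a gripping, well-acted film', device)

logits = helper.nn_forward_func(input_embed, attention_mask,
                                position_embed, type_embed,
                                return_all_logits=True)
print(helper.get_tokens(input_ids), logits)
```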

dig.py

Lines changed: 78 additions & 0 deletions
import typing, sys
from typing import Any, Callable, List, Tuple, Union

import torch
from torch import Tensor

from captum.log import log_usage
from captum._utils.common import (_expand_additional_forward_args, _expand_target, _format_additional_forward_args, _format_output, _is_tuple)
from captum._utils.typing import (BaselineType, Literal, TargetType, TensorOrTupleOfTensorsGeneric)
from captum.attr._utils.approximation_methods import approximation_parameters
from captum.attr._utils.attribution import GradientAttribution
from captum.attr._utils.batching import _batch_attribution
from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input, _format_input


class DiscretetizedIntegratedGradients(GradientAttribution):

    def __init__(self, forward_func: Callable, multiply_by_inputs: bool = True) -> None:
        GradientAttribution.__init__(self, forward_func)
        self._multiply_by_inputs = multiply_by_inputs

    @log_usage()
    def attribute(
        self,
        scaled_features: Tuple[Tensor, ...],
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        return_convergence_delta: bool = False,
    ) -> Union[
        TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
    ]:
        is_inputs_tuple = _is_tuple(scaled_features)
        scaled_features_tpl = _format_input(scaled_features)
        attributions = self.calculate_dig_attributions(scaled_features_tpl=scaled_features_tpl, target=target, additional_forward_args=additional_forward_args, n_steps=n_steps)
        if return_convergence_delta:
            assert len(scaled_features_tpl) == 1, 'More than one tuple not supported in this code!'
            start_point, end_point = _format_input(scaled_features_tpl[0][0].unsqueeze(0)), _format_input(scaled_features_tpl[0][-1].unsqueeze(0))  # baselines, inputs (only works for one input, i.e. len(tuple) == 1)
            # computes approximation error based on the completeness axiom
            delta = self.compute_convergence_delta(
                attributions,
                start_point,
                end_point,
                additional_forward_args=additional_forward_args,
                target=target,
            )
            return _format_output(is_inputs_tuple, attributions), delta

        return _format_output(is_inputs_tuple, attributions)

    def calculate_dig_attributions(
        self,
        scaled_features_tpl: Tuple[Tensor, ...],
        target: TargetType = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
    ) -> Tuple[Tensor, ...]:
        additional_forward_args = _format_additional_forward_args(additional_forward_args)
        input_additional_args = (_expand_additional_forward_args(additional_forward_args, n_steps) if additional_forward_args is not None else None)
        expanded_target = _expand_target(target, n_steps)

        # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...)
        grads = self.gradient_func(
            forward_fn=self.forward_func,
            inputs=scaled_features_tpl,
            target_ind=expanded_target,
            additional_forward_args=input_additional_args,
        )

        # calculate (x_{i+1} - x_i) for each interpolated point; the last point is paired with itself, giving a zero step
        shifted_inputs_tpl = tuple(torch.cat([scaled_features[1:], scaled_features[-1].unsqueeze(0)]) for scaled_features in scaled_features_tpl)
        steps = tuple(shifted_inputs_tpl[i] - scaled_features_tpl[i] for i in range(len(shifted_inputs_tpl)))
        scaled_grads = tuple(grads[i] * steps[i] for i in range(len(grads)))

        # aggregates across all steps for each tensor in the input tuple
        attributions = tuple(_reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]) for (scaled_grad, grad) in zip(scaled_grads, grads))

        return attributions
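`calculate_dig_attributions` is a left-Riemann sum along an arbitrary discretized path: each interpolation point contributes its gradient times the offset to the next point. Below is a small self-contained check of that rule, independent of Captum; the toy linear model and the zigzag path are invented for illustration. For a linear f, the telescoping sum recovers f(input) − f(baseline) exactly, which is the completeness property that `compute_convergence_delta` measures.

```python
import torch

w = torch.tensor([2.0, -1.0, 0.5])
f = lambda x: (w * x).sum()          # toy linear model

baseline = torch.zeros(3)
x = torch.tensor([1.0, 2.0, 3.0])
# an arbitrary zigzag path from baseline to input, as DIG would produce
path = torch.stack([baseline,
                    torch.tensor([1.0, 0.0, 0.0]),
                    torch.tensor([1.0, 2.0, 0.0]),
                    x])

attr = torch.zeros(3)
for i in range(len(path) - 1):
    p = path[i].clone().requires_grad_(True)
    f(p).backward()
    attr += p.grad * (path[i + 1] - path[i])   # grad * (x_{i+1} - x_i)

# completeness: attributions sum to the change in model output
assert torch.allclose(attr.sum(), f(x) - f(baseline))
```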

distilbert_helper.py

Lines changed: 107 additions & 0 deletions
import torch, sys, pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification


model, tokenizer = None, None

def nn_init(device, dataset, returns=False):
    global model, tokenizer
    if dataset == 'sst2':
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english", return_dict=False)
    elif dataset == 'imdb':
        tokenizer = AutoTokenizer.from_pretrained("textattack/distilbert-base-uncased-imdb")
        model = AutoModelForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-imdb", return_dict=False)
    elif dataset == 'rotten':
        tokenizer = AutoTokenizer.from_pretrained("textattack/distilbert-base-uncased-rotten-tomatoes")
        model = AutoModelForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-rotten-tomatoes", return_dict=False)

    model.to(device)
    model.eval()
    model.zero_grad()

    if returns:
        return model, tokenizer

def move_to_device(device):
    global model
    model.to(device)

def predict(model, inputs_embeds, attention_mask=None):
    return model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)[0]

def nn_forward_func(input_embed, attention_mask=None, position_embed=None, type_embed=None, return_all_logits=False):
    global model
    # DistilBERT has no token-type embeddings, so only word + position embeddings are combined
    embeds = input_embed + position_embed
    embeds = model.distilbert.embeddings.dropout(model.distilbert.embeddings.LayerNorm(embeds))
    pred = predict(model, embeds, attention_mask=attention_mask)
    if return_all_logits:
        return pred
    else:
        return pred.max(1).values

def load_mappings(dataset, knn_nbrs=500):
    with open(f'processed/knns/distilbert_{dataset}_{knn_nbrs}.pkl', 'rb') as f:
        [word_idx_map, word_features, adj] = pickle.load(f)
    word_idx_map = dict(word_idx_map)

    return word_idx_map, word_features, adj

def construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device):
    text_ids = tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=tokenizer.max_len_single_sentence)
    input_ids = [cls_token_id] + text_ids + [sep_token_id]  # construct input token ids
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]  # construct reference token ids

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device)

def construct_input_ref_pos_id_pair(input_ids, device):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids

def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

def get_word_embeddings():
    global model
    return model.distilbert.embeddings.word_embeddings.weight

def construct_word_embedding(model, input_ids):
    return model.distilbert.embeddings.word_embeddings(input_ids)

def construct_position_embedding(model, position_ids):
    return model.distilbert.embeddings.position_embeddings(position_ids)

def construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids):
    input_embeddings = construct_word_embedding(model, input_ids)
    ref_input_embeddings = construct_word_embedding(model, ref_input_ids)
    input_position_embeddings = construct_position_embedding(model, position_ids)
    ref_input_position_embeddings = construct_position_embedding(model, ref_position_ids)

    return (input_embeddings, ref_input_embeddings), (input_position_embeddings, ref_input_position_embeddings)

def get_base_token_emb(device):
    global model
    # embedding of the [PAD] token, used as the attribution baseline
    return construct_word_embedding(model, torch.tensor([tokenizer.pad_token_id], device=device))

def get_tokens(text_ids):
    global tokenizer
    return tokenizer.convert_ids_to_tokens(text_ids.squeeze())

def get_inputs(text, device):
    global model, tokenizer
    ref_token_id = tokenizer.pad_token_id
    sep_token_id = tokenizer.sep_token_id
    cls_token_id = tokenizer.cls_token_id

    input_ids, ref_input_ids = construct_input_ref_pair(tokenizer, text, ref_token_id, sep_token_id, cls_token_id, device)
    position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids, device)
    attention_mask = construct_attention_mask(input_ids)

    (input_embed, ref_input_embed), (position_embed, ref_position_embed) = construct_sub_embedding(model, input_ids, ref_input_ids, position_ids, ref_position_ids)

    # no token-type embeddings in DistilBERT; None placeholders keep the return layout aligned with bert_helper
    return [input_ids, ref_input_ids, input_embed, ref_input_embed, position_embed, ref_position_embed, None, None, attention_mask]
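Unlike the BERT helper, `get_inputs` here returns `None` in the two token-type slots, and `nn_forward_func` simply ignores its `type_embed` argument. A quick check of that layout (the sentence is invented):

```python
import torch
import distilbert_helper as helper

device = torch.device('cpu')
helper.nn_init(device, 'sst2')

(input_ids, ref_input_ids, input_embed, ref_input_embed,
 position_embed, ref_position_embed, type_embed, ref_type_embed,
 attention_mask) = helper.get_inputs('an unexpectedly tender film', device)

assert type_embed is None and ref_type_embed is None
logits = helper.nn_forward_func(input_embed, attention_mask, position_embed,
                                return_all_logits=True)
print(helper.get_tokens(input_ids), logits)
```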

knn.py

Lines changed: 40 additions & 0 deletions
import os, sys, numpy as np, pickle, argparse
from sklearn.neighbors import kneighbors_graph

import torch


def main(args):
    device = torch.device("cpu")

    if args.nn == 'distilbert':
        from distilbert_helper import nn_init, get_word_embeddings, get_base_token_emb
    elif args.nn == 'roberta':
        from roberta_helper import nn_init, get_word_embeddings, get_base_token_emb
    elif args.nn == 'bert':
        from bert_helper import nn_init, get_word_embeddings, get_base_token_emb

    print(f'Starting KNN computation..')

    model, tokenizer = nn_init(device, args.dataset, returns=True)
    word_features = get_word_embeddings().cpu().detach().numpy()
    word_idx_map = tokenizer.get_vocab()
    # sparse distance-weighted nearest-neighbor graph over the vocabulary embeddings
    A = kneighbors_graph(word_features, args.nbrs, mode='distance', n_jobs=args.procs)

    knn_fname = f'processed/knns/{args.nn}_{args.dataset}_{args.nbrs}.pkl'
    with open(knn_fname, 'wb') as f:
        pickle.dump([word_idx_map, word_features, A], f)

    print(f'Written KNN data at {knn_fname}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='knn')
    parser.add_argument('-nn', default='distilbert', choices=['distilbert', 'roberta', 'bert'])
    parser.add_argument('-dataset', default='sst2', choices=['sst2', 'imdb', 'rotten'])
    parser.add_argument('-procs', default=40, type=int)
    parser.add_argument('-nbrs', default=500, type=int)

    args = parser.parse_args()

    main(args)
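The pickle written above is what `load_mappings` in the helper files reads back. A sketch of inspecting it directly; the token `'good'` is only an example, and `adj` is whatever sparse format `kneighbors_graph` returns (CSR in current scikit-learn).

```python
import pickle

with open('processed/knns/distilbert_sst2_500.pkl', 'rb') as f:
    word_idx_map, word_features, adj = pickle.load(f)

word_idx_map = dict(word_idx_map)          # token -> vocabulary index
idx = word_idx_map['good']

row = adj[idx]            # sparse row: distances to the 500 nearest embeddings
print(row.indices[:10])   # vocabulary indices of the first few neighbours
print(row.data[:10])      # corresponding embedding-space distances
```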
