-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathattack_utils.py
129 lines (108 loc) · 5.07 KB
/
attack_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import time
import warnings
from distutils.version import LooseVersion
from typing import Callable, Dict, Optional, Union

import torch
from adv_lib.utils import BackwardCounter, ForwardCounter
from adv_lib.utils.attack_utils import _default_metrics
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import ConfusionMatrix
def _masked_rate(hits: Tensor, mask: Tensor, mask_sum: Tensor) -> list:
    """Return, per sample, the fraction of ``mask``-selected positions where ``hits`` is True.

    ``hits`` and ``mask`` are boolean tensors with the batch on dim 0; all other dims are
    flattened. ``mask_sum`` is the per-sample count of True entries in ``mask``. A sample whose
    mask is empty yields NaN (0 / 0).
    """
    return (hits & mask).flatten(1).sum(dim=1).div(mask_sum).cpu().tolist()


def _register_counters(model: nn.Module, forward_counter: Callable, backward_counter: Callable) -> list:
    """Attach the call counters as hooks on ``model`` and return their removable handles."""
    handles = [model.register_forward_pre_hook(forward_counter)]
    # Feature-detect rather than parse torch.__version__: the full backward hook is the
    # non-deprecated API and exists from PyTorch 1.8 onwards.
    if hasattr(model, 'register_full_backward_hook'):
        handles.append(model.register_full_backward_hook(backward_counter))
    else:
        handles.append(model.register_backward_hook(backward_counter))
    return handles


def _timed_call(fn: Callable, use_cuda_events: bool, **kwargs) -> tuple:
    """Call ``fn(**kwargs)`` and return ``(result, elapsed_seconds)``.

    With ``use_cuda_events`` the time is measured between two CUDA events, followed by
    ``torch.cuda.synchronize()`` so queued kernels have finished before the timing is read.
    Otherwise wall-clock time from ``time.perf_counter`` is used, which also works on
    CPU-only machines.
    """
    if use_cuda_events:
        start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn(**kwargs)
        end.record()
        torch.cuda.synchronize()
        return result, start.elapsed_time(end) / 1000  # CUDA events report milliseconds
    start_time = time.perf_counter()
    result = fn(**kwargs)
    return result, time.perf_counter() - start_time


def run_attack(model: nn.Module,
               loader: DataLoader,
               attack: Callable,
               target: Optional[Union[int, Tensor]] = None,
               metrics: Dict[str, Callable] = _default_metrics,
               return_adv: bool = False) -> dict:
    """Run ``attack`` on every batch of ``loader`` and collect segmentation robustness statistics.

    Args:
        model: model whose output is treated as class logits with the class dimension at dim 1
            (predictions are ``argmax(dim=1)``). Its device is taken from its first parameter.
        loader: yields ``(image, label)`` batches. Labels are squeezed on dim 1 and cast to
            long. Pixels whose label is >= the number of classes are excluded from the per-sample
            accuracy / APSR (the confusion matrices receive the raw labels). ``image_names`` is
            read from ``loader.sampler.data_source`` (or its ``.dataset``), which must have an
            ``images`` attribute.
        attack: called as ``attack(model=..., inputs=..., labels=..., targeted=...)`` and must
            return the adversarial images.
        target: ``None`` for an untargeted attack; an ``int`` class index used for every pixel;
            or a label-map tensor broadcast over the batch with ``expand(batch, -1, -1)``.
        metrics: mapping ``name -> metric(adv_image, image)`` returning per-sample distances.
        return_adv: if True, also return the clean and adversarial images (on CPU).

    Returns:
        dict with per-sample lists (``accuracy``, ``apsr_orig``, ``apsr``, ``times``,
        ``num_forwards``, ``num_backwards``, ``distances``), global confusion-matrix results
        (``acc_global``, ``adv_acc_global``, ``ious``, ``adv_ious``), ``image_names`` and
        ``targeted``. When ``return_adv`` is True it also has ``images`` / ``adv_images``:
        concatenated tensors if every batch has the same shape, otherwise per-batch lists.

    Raises:
        TypeError: if ``target`` is not ``None``, an ``int`` or a ``Tensor``.
        ValueError: if ``loader`` yields no batches.
    """
    if target is not None and not isinstance(target, (int, Tensor)):
        raise TypeError(f'target must be None, an int or a Tensor, got {type(target).__name__}')

    device = next(model.parameters()).device
    targeted = target is not None
    data_source = loader.sampler.data_source
    image_list = getattr(data_source, 'dataset', data_source).images
    use_cuda_events = device.type == 'cuda'

    # NOTE(review): num_samples_called presumably counts samples passed through the model's
    # forward / backward — semantics come from adv_lib's counters.
    forward_counter, backward_counter = ForwardCounter(), BackwardCounter()
    hook_handles = _register_counters(model, forward_counter, backward_counter)

    forwards, backwards = [], []  # number of forward and backward calls per sample
    times, accuracies, apsrs, apsrs_orig = [], [], [], []
    distances = {name: [] for name in metrics}
    images, adv_images = [], []
    confmat_orig = confmat_adv = None

    try:
        for image, label in tqdm(loader, ncols=80, total=len(loader)):
            if return_adv:
                images.append(image.clone())
            image, label = image.to(device), label.to(device).squeeze(1).long()
            if not targeted:
                attack_label = label
            elif isinstance(target, Tensor):
                attack_label = target.to(device).expand(image.shape[0], -1, -1)
            else:
                attack_label = torch.full_like(label, fill_value=target)

            # Clean evaluation: no autograd graph is needed.
            with torch.no_grad():
                logits = model(image)
            if confmat_orig is None:  # first batch: infer the number of classes from the logits
                num_classes = logits.size(1)
                confmat_orig = ConfusionMatrix(num_classes=num_classes)
                confmat_adv = ConfusionMatrix(num_classes=num_classes)
            mask = label < num_classes
            mask_sum = mask.flatten(1).sum(dim=1)
            pred = logits.argmax(dim=1)
            accuracies.extend(_masked_rate(pred == label, mask, mask_sum))
            confmat_orig.update(label, pred)
            if targeted:
                target_mask = attack_label < logits.size(1)
                target_sum = target_mask.flatten(1).sum(dim=1)
                apsrs_orig.extend(_masked_rate(pred == attack_label, target_mask, target_sum))
            else:
                apsrs_orig.extend(_masked_rate(pred != label, mask, mask_sum))

            # Reset so the counters only reflect calls made by the attack itself.
            forward_counter.reset()
            backward_counter.reset()
            adv_image, elapsed = _timed_call(attack, use_cuda_events, model=model, inputs=image,
                                             labels=attack_label, targeted=targeted)
            times.append(elapsed)
            forwards.append(forward_counter.num_samples_called)
            backwards.append(backward_counter.num_samples_called)
            forward_counter.reset()
            backward_counter.reset()

            if adv_image.min() < 0 or adv_image.max() > 1:
                warnings.warn('Values of produced adversarials are not in the [0, 1] range -> Clipping to [0, 1].')
                # Out-of-place: an in-place clamp_ raises on a leaf tensor that requires grad.
                adv_image = adv_image.clamp(min=0, max=1)
            if return_adv:
                adv_images.append(adv_image.cpu().clone())

            with torch.no_grad():
                adv_logits = model(adv_image)
            adv_pred = adv_logits.argmax(dim=1)
            confmat_adv.update(label, adv_pred)
            if targeted:
                apsrs.extend(_masked_rate(adv_pred == attack_label, target_mask, target_sum))
            else:
                apsrs.extend(_masked_rate(adv_pred != label, mask, mask_sum))

            for metric, metric_func in metrics.items():
                distances[metric].extend(metric_func(adv_image, image).detach().cpu().tolist())
    finally:
        # Detach the counters so repeated calls on the same model do not accumulate hooks.
        for handle in hook_handles:
            handle.remove()

    if confmat_orig is None:
        raise ValueError('loader yielded no batches; nothing to evaluate')

    acc_global, accs, ious = confmat_orig.compute()
    adv_acc_global, adv_accs, adv_ious = confmat_adv.compute()

    data = {
        'image_names': image_list[:len(apsrs)],
        'targeted': targeted,
        'accuracy': accuracies,
        'acc_global': acc_global.item(),
        'adv_acc_global': adv_acc_global.item(),
        'ious': ious.cpu().tolist(),
        'adv_ious': adv_ious.cpu().tolist(),
        'apsr_orig': apsrs_orig,
        'apsr': apsrs,
        'times': times,
        'num_forwards': forwards,
        'num_backwards': backwards,
        'distances': distances,
    }

    if return_adv:
        shapes = [img.shape for img in images]
        if len(set(shapes)) == 1:  # only concatenate when every batch has the same shape
            images = torch.cat(images, dim=0)
            adv_images = torch.cat(adv_images, dim=0)
        data['images'] = images
        data['adv_images'] = adv_images

    return data