# record.py: accuracy evaluation and result-recording helpers.
import numpy as np
import torch
from operator import truediv
def evaluate_accuracy(data_iter, net, loss, device):
    """Evaluate ``net`` on every batch of ``data_iter`` with gradients disabled.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches of tensors.
        net: Model invoked as ``net(X)``; predictions are the arg-max of its
            output along dim 1.
        loss: Callable invoked as ``loss(y_hat, y.long())`` for each batch.
        device: Target handed to ``Tensor.to`` for both inputs and labels.

    Returns:
        A two-element list ``[accuracy, loss_sum]``. ``accuracy`` is the
        number of samples whose arg-max prediction equals the label divided
        by the total number of samples. ``loss_sum`` is the sum of the values
        returned by ``loss`` over all batches (it is not divided by the
        number of batches).

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.

    Note:
        The model is switched to eval mode for the pass and, on normal
        completion of the loop, switched to training mode afterwards,
        whatever mode it was in on entry.
    """
    acc_sum, n = 0.0, 0
    # The loss accumulator must be created once, outside the loop: resetting
    # it per batch would make the returned "sum" the last batch's loss only.
    test_l_sum = 0

    # Switching modes once per pass is enough; nothing inside the loop
    # changes the model's mode.
    net.eval()
    with torch.no_grad():
        for X, y in data_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            test_l_sum += loss(y_hat, y.long())
            acc_sum += (y_hat.argmax(dim=1) == y).float().sum().cpu().item()
            n += y.shape[0]
    net.train()

    return [acc_sum / n, test_l_sum]
def aa_and_each_accuracy(confusion_matrix):
    """Compute per-class accuracy and its mean from a confusion matrix.

    Each class accuracy is the matrix's diagonal entry divided by the sum of
    the corresponding row (axis 1). NaN quotients, such as the 0/0 produced
    by an all-zero row, are replaced through ``np.nan_to_num``.

    Args:
        confusion_matrix: Square 2-D array-like of counts.

    Returns:
        Tuple ``(each_acc, average_acc)``: the array of per-class accuracies
        and their arithmetic mean.
    """
    correct_per_class = np.diag(confusion_matrix)
    total_per_class = np.sum(confusion_matrix, axis=1)
    per_class_acc = np.nan_to_num(correct_per_class / total_per_class)
    return per_class_acc, np.mean(per_class_acc)
def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, path):
    """Write a plain-text summary of per-iteration metrics to ``path``.

    The report lists the raw OA, AA and KAPPA sequences, then the mean and
    standard deviation of each, then the mean and standard deviation of
    ``element_acc_ae`` taken along axis 0.

    Args:
        oa_ae: Sequence of OA values, one per iteration.
        aa_ae: Sequence of AA values, one per iteration.
        kappa_ae: Sequence of KAPPA values, one per iteration.
        element_acc_ae: Array-like reduced with ``np.mean``/``np.std`` along
            axis 0.
        path: Destination file; it is created or overwritten.

    Note:
        All statistics are computed before the file is opened, so a failure
        while computing them does not leave a truncated report behind. The
        file is written as UTF-8 because the text contains the non-ASCII
        character '±', and is closed even if writing raises.
    """
    element_mean = np.mean(element_acc_ae, axis=0)
    element_std = np.std(element_acc_ae, axis=0)

    report_lines = [
        'OAs for each iteration are:' + str(oa_ae) + '\n',
        'AAs for each iteration are:' + str(aa_ae) + '\n',
        'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n',
        'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)) + ' ± ' + str(np.std(oa_ae)) + '\n',
        'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)) + ' ± ' + str(np.std(aa_ae)) + '\n',
        'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n',
        "Mean of all elements in confusion matrix: " + str(element_mean) + '\n',
        "Standard deviation of all elements in confusion matrix: " + str(element_std) + '\n',
    ]

    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(report_lines)