-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
131 lines (111 loc) · 4.99 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from copy import deepcopy
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *
# from utils import data_loader
from sklearn.preprocessing import OneHotEncoder
class MultiViewDataWithoutLeak(Dataset):
    """Multi-view dataset that can be corrupted and split into train/test subsets.

    Views are stored in ``self.x`` and addressed by consecutive integer ids
    ``0 .. n_views - 1``; ``self.y`` and ``self.label`` are indexed per sample.
    When ``dataname`` is ``None`` an empty container is created and filled in
    by the caller (this is how ``Split_Training_Test_Dataset`` builds its
    subsets).

    Args:
        dataname: dataset identifier forwarded to ``data_loader_without_leak``
            (provided by ``utils``), or ``None`` for an empty dataset.
        multi_view: stored and forwarded to ``data_loader_without_leak``.
        std: stored and forwarded to ``data_loader_without_leak``.
        train: whether this instance represents a training split.
    """

    def __init__(self, dataname, multi_view, std, train=True):
        super(MultiViewDataWithoutLeak, self).__init__()
        self.multi_view = multi_view
        self.std = std
        self.train = train
        if dataname is not None:
            x, y, label = data_loader_without_leak(dataname, self.multi_view, std)
            self.x = x
            self.y = y
            self.label = label
        else:
            self.x = dict()
            self.y = None
            self.label = None

    def __getitem__(self, index):
        """Return ``(views, y, label)`` for one sample; ``views`` maps view id -> features."""
        data = {v: self.x[v][index] for v in range(len(self.x))}
        return data, self.y[index], self.label[index]

    def __len__(self):
        """Number of samples (length of view 0); 0 for an empty container."""
        if len(self.x) == 0:
            return 0
        return len(self.x[0])

    def postprocessing(self, addNoise=False, ratio_noise=0.5, addConflict=False, ratio_conflict=0.5):
        """Optionally corrupt the stored views in place.

        Conflict injection runs before noise injection, so noise is applied on
        top of already-conflicted samples.
        """
        if addConflict:
            self.addConflict(ratio_conflict)
        if addNoise:
            self.addNoise(ratio_noise)

    @staticmethod
    def _select_views(num_views):
        """Pick which views of one sample get corrupted.

        With exactly 6 views, 3 distinct views are drawn; otherwise a single
        view is drawn uniformly. Uses the global ``np.random`` state.
        """
        if num_views == 6:
            return np.random.choice(num_views, size=3, replace=False)
        return [np.random.randint(num_views)]

    def addNoise(self, ratio):
        """Blend Gaussian noise into every sample, in place.

        For each sample, every view picked by ``_select_views`` becomes
        ``x * (1 - ratio) + ratio * eps`` where ``eps`` is ONE scalar drawn from
        a normal distribution with mean 0 and standard deviation 10000, shared
        by all features of that (sample, view).
        """
        # NOTE(review): eps is a scalar, so all features shift by the same
        # amount; pass size=self.x[v][i].shape if per-feature noise is intended.
        num_views = len(self.x)
        for i in range(len(self.label)):
            for v in self._select_views(num_views):
                self.x[v][i] = self.x[v][i] * (1 - ratio) + ratio * np.random.normal(0, 10000)

    def addConflict(self, ratio):
        """Blend each sample with a sample of a different class, in place.

        For every sample, each view picked by ``_select_views`` becomes
        ``x * (1 - ratio) + ratio * x_other`` where ``x_other`` is the same view
        of a random sample with a different label. Partners are read from a
        snapshot of the unmodified views, so the result does not depend on the
        order in which samples are processed.
        """
        # A NumPy copy of the labels lets the same code handle tensor and array
        # labels; iterating np.unique covers labels that are not 0 .. K-1.
        if torch.is_tensor(self.label):
            labels = self.label.detach().cpu().numpy()
        else:
            labels = np.asarray(self.label)
        num_views = len(self.x)
        original = deepcopy(self.x)  # unmodified features used as blend partners
        for c in np.unique(labels):
            other_class_indices = np.flatnonzero(labels != c)
            if len(other_class_indices) == 0:
                continue  # only one class present: no partner to blend with
            for sample in np.flatnonzero(labels == c).tolist():
                # torch RNG picks the partner, np.random picks the views.
                pick = torch.randint(0, len(other_class_indices), (1,)).item()
                other_sample_idx = int(other_class_indices[pick])
                for v in self._select_views(num_views):
                    self.x[v][sample] = (self.x[v][sample] * (1 - ratio)
                                         + ratio * original[v][other_sample_idx])

    def _subset(self, indices, std, train):
        """Build a new dataset holding the samples at ``indices`` (all views, y, label)."""
        subset = MultiViewDataWithoutLeak(None, self.multi_view, std, train=train)
        for view in range(len(self.x)):
            subset.x[view] = self.x[view][indices]
        subset.y = self.y[indices]
        subset.label = self.label[indices]
        return subset

    def Split_Training_Test_Dataset(self, train_sam_ratio=0.75, std=0.0):
        """Randomly split into ``(train_set, test_set)``.

        The training subset receives ``ceil(len(self) * train_sam_ratio)``
        samples chosen by ``torch.randperm``; the remaining samples form the test
        subset (created with ``train=False``). Both subsets use the given ``std``.
        """
        sam_num = len(self)
        # Rounded through a default-dtype (float32 unless changed) tensor, as
        # before: this keeps e.g. ceil(100 * 0.7) at 70, whereas float64
        # rounding error would give 71.
        train_sam_num = int(torch.ceil(torch.tensor(sam_num * train_sam_ratio)).item())
        rand_sam_ind = torch.randperm(sam_num)
        train_data_set = self._subset(rand_sam_ind[:train_sam_num], std, train=True)
        test_data_set = self._subset(rand_sam_ind[train_sam_num:], std, train=False)
        return train_data_set, test_data_set
class MultiViewData(Dataset):
    """Multi-view dataset exposing one pre-split partition from ``data_loader``.

    ``data_loader(dataname)`` (provided by ``utils``) yields view data, labels
    and targets as mappings keyed by ``'train'`` / ``'val'``, followed by the
    train and validation index collections. ``train`` selects which partition
    this instance serves.
    """

    def __init__(self, dataname, train=True):
        super().__init__()
        (self.data_loaders, self.label_loaders, self.y_loaders,
         idx_train, idx_val) = data_loader(dataname)
        split = 'train' if train else 'val'
        self.x = self.data_loaders[split]
        self.label = self.label_loaders[split]
        self.y = self.y_loaders[split]
        self.idx = idx_train if train else idx_val

    def __getitem__(self, index):
        """Return ``(views, y, label)``; ``views`` maps view id -> that view's sample."""
        views = {view_id: self.x[view_id][index] for view_id in range(len(self.x))}
        return views, self.y[index], self.label[index]

    def __len__(self):
        """Number of samples, measured on view 0."""
        first_view = self.x[0]
        return len(first_view)