2
2
3
3
import weka .classifiers .*;
4
4
import weka .classifiers .bayes .NaiveBayes ;
5
+ import weka .classifiers .functions .MultilayerPerceptron ;
5
6
import weka .classifiers .meta .FilteredClassifier ;
7
+ import weka .classifiers .pmml .consumer .SupportVectorMachineModel ;
6
8
import weka .classifiers .trees .J48 ;
9
+ import weka .classifiers .trees .RandomForest ;
7
10
import weka .core .Instance ;
8
11
import weka .core .Instances ;
9
12
import weka .filters .unsupervised .attribute .Remove ;
14
17
import java .io .IOException ;
15
18
import java .util .ArrayList ;
16
19
import java .util .List ;
20
+ import java .util .Random ;
17
21
import java .util .Set ;
22
+ import java .util .stream .Collectors ;
23
+ import java .util .stream .IntStream ;
18
24
19
25
/**
20
26
* Created by ben on 15/05/17.
21
27
*/
22
28
public class WekaClassifier {
23
29
30
+ // Data to learn model from
24
31
Instances training ;
25
- Instances testing ;
26
32
27
- public final int CLASS_INDEX = 0 ;
33
+ // Used to check performance of learnt model
34
+ Instances validation ;
35
+
36
+ // Only used for final evaluation
37
+ Instances testing ;
28
38
29
- // classifier
30
- weka . classifiers . Classifier classifier = new NaiveBayes () ;
39
+ private int CLASS_INDEX ;
40
+ private final int NUM_FOLDS = 3 ;
31
41
32
42
public WekaClassifier (String fileName ) throws Exception {
33
43
Instances instances = readArffFile (fileName );
34
44
35
- int trainSize = (int ) Math .round (instances .numInstances () * 0.7 );
36
- int testSize = instances .numInstances () - trainSize ;
45
+ instances .randomize (new java .util .Random (0 ));
46
+
47
+ int trainSize = (int ) Math .round (instances .numInstances () * 0.6 );
48
+ int validationSize = (int ) Math .round (instances .numInstances () * 0.2 );
49
+ int testSize = instances .numInstances () - trainSize - validationSize ;
37
50
38
51
this .training = new Instances (instances , 0 , trainSize );
39
- this .testing = new Instances (instances , trainSize , testSize );
52
+ this .validation = new Instances (instances , trainSize , validationSize );
53
+ this .testing = new Instances (instances , trainSize + validationSize , testSize );
40
54
41
- System .out .println ("Training instances: " + training .size ());
42
- System .out .println ("Testing instances: " + testing .size ());
55
+ setClassIndex ();
43
56
44
57
classify ();
45
58
}
46
59
47
60
public WekaClassifier (String trainingFileName , String testingFileName ) throws Exception {
61
+
62
+ Instances instances = readArffFile (trainingFileName );
63
+
64
+ // Split into training and validation
48
65
this .training = readArffFile (trainingFileName );
66
+
67
+ instances .randomize (new java .util .Random (0 ));
68
+
69
+ int trainSize = (int ) Math .round (instances .numInstances () * 0.8 );
70
+ int validationSize = instances .numInstances () - trainSize ;
71
+
72
+ this .training = new Instances (instances , 0 , trainSize );
73
+ this .validation = new Instances (instances , trainSize , validationSize );
74
+
75
+
49
76
this .testing = readArffFile (testingFileName );
50
77
78
+ setClassIndex ();
79
+
51
80
classify ();
52
81
}
53
82
83
+ public void setClassIndex (){
84
+ // this.CLASS_INDEX = training.numAttributes() - 1;
85
+ this .CLASS_INDEX = 0 ;
86
+ training .setClassIndex (CLASS_INDEX );
87
+ testing .setClassIndex (CLASS_INDEX );
88
+
89
+ }
90
+
54
91
public static void main (String [] args ) throws Exception {
55
92
new WekaClassifier ("wine.arff" );
56
93
}
@@ -67,12 +104,60 @@ private Instances readArffFile (String fileName) throws IOException{
67
104
}
68
105
69
106
public double classify () throws Exception {
107
+
108
+ // classifier
109
+ weka .classifiers .Classifier classifier = createClassifier ();
110
+
70
111
classifier .buildClassifier (training );
71
112
Evaluation eval = new Evaluation (training );
72
113
73
- eval .evaluateModel (classifier , training );
114
+ return evaluate (classifier );
115
+ }
116
+
117
+ public double testAccuracy () throws Exception {
118
+ // classifier
119
+ weka .classifiers .Classifier classifier = createClassifier ();
120
+
121
+ classifier .buildClassifier (training );
122
+ Evaluation eval = new Evaluation (training );
123
+
124
+ eval .evaluateModel (classifier , testing );
125
+
126
+ System .out .println (eval .toSummaryString ());
127
+
128
+ return eval .pctCorrect ();
129
+ }
130
+
131
+ private weka .classifiers .Classifier createClassifier () throws Exception {
132
+ return new NaiveBayes ();
133
+ }
134
+
135
+ public double testAccuracy (Set <Integer > indices ) throws Exception {
136
+ weka .classifiers .Classifier classifier = createClassifier ();
137
+
138
+ Remove rm = new Remove ();
139
+
140
+ int [] remove = remove (indices );
141
+ rm .setAttributeIndicesArray (remove );
142
+
143
+ FilteredClassifier fc = new FilteredClassifier ();
144
+ fc .setFilter (rm );
145
+
146
+ fc .setClassifier (classifier );
147
+ fc .buildClassifier (training );
148
+
149
+ Evaluation eval = new Evaluation (training );
150
+ eval .evaluateModel (fc , testing );
151
+
152
+ return eval .pctCorrect ();
153
+ }
154
+
155
+ private double evaluate (weka .classifiers .Classifier classifier ) throws Exception {
156
+ Evaluation eval = new Evaluation (training );
157
+
158
+ eval .evaluateModel (classifier , validation );
159
+ // eval.crossValidateModel(classifier, training, NUM_FOLDS, new Random(1));
74
160
75
- System .out .println (eval .pctCorrect ());
76
161
return eval .pctCorrect ();
77
162
}
78
163
@@ -81,18 +166,18 @@ private int[] remove(Set<Integer> toKeep){
81
166
List <Integer > toRemove = new ArrayList <Integer >();
82
167
83
168
for (int i =0 ; i <training .numAttributes (); i ++){
84
- if (!toKeep .contains (i )){
169
+ if (!toKeep .contains (i ) && i != CLASS_INDEX ){
85
170
toRemove .add (i );
86
171
}
87
172
}
88
173
89
- System .out .println ("Removing features: " + toRemove );
90
-
91
174
// Convert list to int[]
92
175
return toRemove .stream ().mapToInt (i ->i ).toArray ();
93
176
}
94
177
95
178
public double classify (Set <Integer > indices ) throws Exception {
179
+ weka .classifiers .Classifier classifier = createClassifier ();
180
+
96
181
Remove rm = new Remove ();
97
182
98
183
int [] remove = remove (indices );
@@ -105,12 +190,24 @@ public double classify(Set<Integer> indices) throws Exception {
105
190
// train and make predictions
106
191
fc .buildClassifier (training );
107
192
108
- Evaluation eval = new Evaluation ( training );
109
- eval . evaluateModel ( fc , training );
193
+ return evaluate ( fc );
194
+ }
110
195
111
- System .out .println (eval .pctCorrect ());
196
+ public int getNumFeatures (){
197
+ return training .numAttributes ();
198
+ }
112
199
113
- return eval .pctCorrect ();
200
+ public Set <Integer > getAllFeatureIndices () {
201
+ int totalFeatures = training .numAttributes ();
202
+
203
+ Set <Integer > features = IntStream .rangeClosed (0 , totalFeatures - 1 )
204
+ .boxed ().collect (Collectors .toSet ());
205
+
206
+ // Class shouldnt be considered a feature
207
+ features .remove (CLASS_INDEX );
208
+
209
+ // Return a set from 0..totalFeatures
210
+ return features ;
114
211
}
115
212
116
213
}
0 commit comments