Skip to content

Commit e715320

Browse files
Ben EvansBen Evans
Ben Evans
authored and
Ben Evans
committed
Changed classifier to use Weka classifiers, as K-NN will be too slow for larger datasets. Added check to prevent loops in floating methods
1 parent 3ad7208 commit e715320

7 files changed

+229
-101
lines changed

TestAll.java

+18-23
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,43 @@ public TestAll() throws FileNotFoundException {
2525
}
2626

2727
@org.junit.Test
28-
public void testSequentialForwardSelection() {
28+
public void testSequentialForwardSelection() throws Exception {
2929
System.out.println("-------------------");
3030
System.out.println("Sequential forward selection");
31-
FeatureSelection selector = new SequentialForwardSelection(training, testing);
31+
FeatureSelection selector = new SequentialForwardSelection("wine.arff");
3232
Set<Integer> selectedIndices = selector.select();
3333
selector.compareTestingAccuracy(selectedIndices);
3434
System.out.println("-------------------");
3535
}
3636

3737
@org.junit.Test
38-
public void testSequentialForwardSelectionNumfeatures() {
38+
public void testSequentialForwardSelectionNumfeatures() throws Exception {
3939
int maxFeatures = 10;
4040
System.out.println("-------------------");
4141
System.out.println("Sequential forward selection for max " + maxFeatures + " features");
42-
FeatureSelection selector = new SequentialForwardSelection(training, testing);
42+
FeatureSelection selector = new SequentialForwardSelection("wine.arff");
4343
Set<Integer> selectedIndices = selector.select(maxFeatures);
4444
selector.compareTestingAccuracy(selectedIndices);
4545
System.out.println("-------------------");
4646
assertTrue(selectedIndices.size() <= maxFeatures);
4747
}
4848

4949
@org.junit.Test
50-
public void testSequentialBackwardSelection() {
50+
public void testSequentialBackwardSelection() throws Exception {
5151
System.out.println("-------------------");
5252
System.out.println("Sequential backward selection");
53-
FeatureSelection selector = new SequentialBackwardSelection(training, testing);
53+
FeatureSelection selector = new SequentialBackwardSelection("wine.arff");
5454
Set<Integer> selectedIndices = selector.select();
5555
selector.compareTestingAccuracy(selectedIndices);
5656
System.out.println("-------------------");
5757
}
5858

5959
@org.junit.Test
60-
public void testSequentialBackwardSelectionNumfeatures() {
60+
public void testSequentialBackwardSelectionNumfeatures() throws Exception {
6161
int maxFeatures = 10;
6262
System.out.println("-------------------");
6363
System.out.println("Sequential backward selection for max " + maxFeatures + " Features");
64-
FeatureSelection selector = new SequentialBackwardSelection(training, testing);
64+
FeatureSelection selector = new SequentialBackwardSelection("wine.arff");
6565
Set<Integer> selectedIndices = selector.select(maxFeatures);
6666
selector.compareTestingAccuracy(selectedIndices);
6767
System.out.println("-------------------");
@@ -75,21 +75,21 @@ public void testSequentialBackwardSelectionNumfeatures() {
7575
*/
7676

7777
@org.junit.Test
78-
public void testSequentialFloatingForwardSelection() {
78+
public void testSequentialFloatingForwardSelection() throws Exception {
7979
System.out.println("-------------------");
8080
System.out.println("Sequential floating forward selection");
81-
FeatureSelection selector = new SequentialFloatingForwardSelection(training, testing);
81+
FeatureSelection selector = new SequentialFloatingForwardSelection("wine.arff");
8282
Set<Integer> selectedIndices = selector.select();
8383
selector.compareTestingAccuracy(selectedIndices);
8484
System.out.println("-------------------");
8585
}
8686

8787
@org.junit.Test
88-
public void testSequentialFloatingForwardSelectionNumFeatures() {
88+
public void testSequentialFloatingForwardSelectionNumFeatures() throws Exception {
8989
int maxFeatures = 5;
9090
System.out.println("-------------------");
9191
System.out.println("Sequential floating forward selection for " + maxFeatures + " features");
92-
FeatureSelection selector = new SequentialFloatingForwardSelection(training, testing);
92+
FeatureSelection selector = new SequentialFloatingForwardSelection("wine.arff");
9393
Set<Integer> selectedIndices = selector.select(maxFeatures);
9494
selector.compareTestingAccuracy(selectedIndices);
9595
System.out.println("-------------------");
@@ -98,27 +98,28 @@ public void testSequentialFloatingForwardSelectionNumFeatures() {
9898
}
9999

100100
@org.junit.Test
101-
public void testSequentialBackwardFloatingSelection() {
101+
public void testSequentialBackwardFloatingSelection() throws Exception {
102102
System.out.println("-------------------");
103103
System.out.println("Sequential backward floating selection");
104-
FeatureSelection selector = new SequentialFloatingBackwardSelection(training, testing);
104+
FeatureSelection selector = new SequentialFloatingBackwardSelection("wine.arff");
105105
Set<Integer> selectedIndices = selector.select();
106106
selector.compareTestingAccuracy(selectedIndices);
107107
System.out.println("-------------------");
108108
}
109109

110110
@org.junit.Test
111-
public void testSequentialBackwardFloatingSelectionNumFeatures() {
111+
public void testSequentialBackwardFloatingSelectionNumFeatures() throws Exception {
112112
int maxFeatures = 10;
113113
System.out.println("-------------------");
114114
System.out.println("Sequential backward floating selection for " + maxFeatures + " features");
115-
FeatureSelection selector = new SequentialFloatingBackwardSelection(training, testing);
115+
FeatureSelection selector = new SequentialFloatingBackwardSelection("wine.arff");
116116
Set<Integer> selectedIndices = selector.select(maxFeatures);
117117
selector.compareTestingAccuracy(selectedIndices);
118118
System.out.println("-------------------");
119119
assertTrue(selectedIndices.size() <= maxFeatures);
120120
}
121121

122+
122123
private void loadWineSet() throws FileNotFoundException {
123124
Scanner scanner = new Scanner(new File("src/res/wine.data"));
124125
List<Instance> instances = new ArrayList<>();
@@ -144,10 +145,6 @@ private void loadWineSet() throws FileNotFoundException {
144145
private void loadIsoletSet() throws FileNotFoundException {
145146
this.training = loadIsoletSet(true);
146147
this.testing = loadIsoletSet(false);
147-
148-
System.out.println("Training size: " + training.size());
149-
System.out.println("Testing size: " + testing.size());
150-
151148
}
152149

153150
private List<Instance> loadIsoletSet(boolean training) throws FileNotFoundException {
@@ -156,9 +153,7 @@ private List<Instance> loadIsoletSet(boolean training) throws FileNotFoundExcept
156153

157154
Scanner scanner = new Scanner(new File("src/res/isolet-" + file + ".data"));
158155

159-
int occurences = 0;
160-
161-
while (scanner.hasNext() && occurences++ < 500) {
156+
while (scanner.hasNext()) {
162157
String line = scanner.nextLine();
163158
Instance isolet = createIsolet(line);
164159
instances.add(isolet);

clasification/WekaClassifier.java

+115-18
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
import weka.classifiers.*;
44
import weka.classifiers.bayes.NaiveBayes;
5+
import weka.classifiers.functions.MultilayerPerceptron;
56
import weka.classifiers.meta.FilteredClassifier;
7+
import weka.classifiers.pmml.consumer.SupportVectorMachineModel;
68
import weka.classifiers.trees.J48;
9+
import weka.classifiers.trees.RandomForest;
710
import weka.core.Instance;
811
import weka.core.Instances;
912
import weka.filters.unsupervised.attribute.Remove;
@@ -14,43 +17,77 @@
1417
import java.io.IOException;
1518
import java.util.ArrayList;
1619
import java.util.List;
20+
import java.util.Random;
1721
import java.util.Set;
22+
import java.util.stream.Collectors;
23+
import java.util.stream.IntStream;
1824

1925
/**
2026
* Created by ben on 15/05/17.
2127
*/
2228
public class WekaClassifier {
2329

30+
// Data to learn model from
2431
Instances training;
25-
Instances testing;
2632

27-
public final int CLASS_INDEX = 0;
33+
// Used to check performance of learnt model
34+
Instances validation;
35+
36+
// Only used for final evaluation
37+
Instances testing;
2838

29-
// classifier
30-
weka.classifiers.Classifier classifier = new NaiveBayes();
39+
private int CLASS_INDEX;
40+
private final int NUM_FOLDS = 3;
3141

3242
public WekaClassifier (String fileName) throws Exception {
3343
Instances instances = readArffFile(fileName);
3444

35-
int trainSize = (int) Math.round(instances.numInstances() * 0.7);
36-
int testSize = instances.numInstances() - trainSize;
45+
instances.randomize(new java.util.Random(0));
46+
47+
int trainSize = (int) Math.round(instances.numInstances() * 0.6);
48+
int validationSize = (int) Math.round(instances.numInstances() * 0.2);
49+
int testSize = instances.numInstances() - trainSize - validationSize;
3750

3851
this.training = new Instances(instances, 0, trainSize);
39-
this.testing = new Instances(instances, trainSize, testSize);
52+
this.validation = new Instances(instances, trainSize, validationSize);
53+
this.testing = new Instances(instances, trainSize + validationSize, testSize);
4054

41-
System.out.println("Training instances: " + training.size());
42-
System.out.println("Testing instances: " + testing.size());
55+
setClassIndex();
4356

4457
classify();
4558
}
4659

4760
public WekaClassifier (String trainingFileName, String testingFileName) throws Exception {
61+
62+
Instances instances = readArffFile(trainingFileName);
63+
64+
// Split into training and validation
4865
this.training = readArffFile(trainingFileName);
66+
67+
instances.randomize(new java.util.Random(0));
68+
69+
int trainSize = (int) Math.round(instances.numInstances() * 0.8);
70+
int validationSize = instances.numInstances() - trainSize;
71+
72+
this.training = new Instances(instances, 0, trainSize);
73+
this.validation = new Instances(instances, trainSize, validationSize);
74+
75+
4976
this.testing = readArffFile(testingFileName);
5077

78+
setClassIndex();
79+
5180
classify();
5281
}
5382

83+
public void setClassIndex(){
84+
// this.CLASS_INDEX = training.numAttributes() - 1;
85+
this.CLASS_INDEX = 0;
86+
training.setClassIndex(CLASS_INDEX);
87+
testing.setClassIndex(CLASS_INDEX);
88+
89+
}
90+
5491
public static void main(String[] args) throws Exception {
5592
new WekaClassifier("wine.arff");
5693
}
@@ -67,12 +104,60 @@ private Instances readArffFile (String fileName) throws IOException{
67104
}
68105

69106
public double classify() throws Exception {
107+
108+
// classifier
109+
weka.classifiers.Classifier classifier = createClassifier();
110+
70111
classifier.buildClassifier(training);
71112
Evaluation eval = new Evaluation(training);
72113

73-
eval.evaluateModel(classifier, training);
114+
return evaluate(classifier);
115+
}
116+
117+
public double testAccuracy() throws Exception {
118+
// classifier
119+
weka.classifiers.Classifier classifier = createClassifier();
120+
121+
classifier.buildClassifier(training);
122+
Evaluation eval = new Evaluation(training);
123+
124+
eval.evaluateModel(classifier, testing);
125+
126+
System.out.println(eval.toSummaryString());
127+
128+
return eval.pctCorrect();
129+
}
130+
131+
private weka.classifiers.Classifier createClassifier() throws Exception {
132+
return new NaiveBayes();
133+
}
134+
135+
public double testAccuracy(Set<Integer> indices) throws Exception {
136+
weka.classifiers.Classifier classifier = createClassifier();
137+
138+
Remove rm = new Remove();
139+
140+
int[] remove = remove(indices);
141+
rm.setAttributeIndicesArray(remove);
142+
143+
FilteredClassifier fc = new FilteredClassifier();
144+
fc.setFilter(rm);
145+
146+
fc.setClassifier(classifier);
147+
fc.buildClassifier(training);
148+
149+
Evaluation eval = new Evaluation(training);
150+
eval.evaluateModel(fc, testing);
151+
152+
return eval.pctCorrect();
153+
}
154+
155+
private double evaluate(weka.classifiers.Classifier classifier) throws Exception {
156+
Evaluation eval = new Evaluation(training);
157+
158+
eval.evaluateModel(classifier, validation);
159+
// eval.crossValidateModel(classifier, training, NUM_FOLDS, new Random(1));
74160

75-
System.out.println(eval.pctCorrect());
76161
return eval.pctCorrect();
77162
}
78163

@@ -81,18 +166,18 @@ private int[] remove(Set<Integer> toKeep){
81166
List<Integer> toRemove = new ArrayList<Integer>();
82167

83168
for(int i=0; i<training.numAttributes(); i++){
84-
if (!toKeep.contains(i)){
169+
if (!toKeep.contains(i) && i != CLASS_INDEX){
85170
toRemove.add(i);
86171
}
87172
}
88173

89-
System.out.println("Removing features: " + toRemove);
90-
91174
// Convert list to int[]
92175
return toRemove.stream().mapToInt(i->i).toArray();
93176
}
94177

95178
public double classify(Set<Integer> indices) throws Exception {
179+
weka.classifiers.Classifier classifier = createClassifier();
180+
96181
Remove rm = new Remove();
97182

98183
int[] remove = remove(indices);
@@ -105,12 +190,24 @@ public double classify(Set<Integer> indices) throws Exception {
105190
// train and make predictions
106191
fc.buildClassifier(training);
107192

108-
Evaluation eval = new Evaluation(training);
109-
eval.evaluateModel(fc, training);
193+
return evaluate(fc);
194+
}
110195

111-
System.out.println(eval.pctCorrect());
196+
public int getNumFeatures(){
197+
return training.numAttributes();
198+
}
112199

113-
return eval.pctCorrect();
200+
public Set<Integer> getAllFeatureIndices() {
201+
int totalFeatures = training.numAttributes();
202+
203+
Set<Integer> features = IntStream.rangeClosed(0, totalFeatures - 1)
204+
.boxed().collect(Collectors.toSet());
205+
206+
// Class shouldnt be considered a feature
207+
features.remove(CLASS_INDEX);
208+
209+
// Return a set from 0..totalFeatures
210+
return features;
114211
}
115212

116213
}

0 commit comments

Comments
 (0)