Skip to content

Commit 087ebf7

Browse files
Ben EvansBen Evans
Ben Evans
authored and
Ben Evans
committed
Reformatted code and added comments to TestAll
1 parent 31ad172 commit 087ebf7

7 files changed

+74
-45
lines changed

Diff for: TestAll.java

+25-7
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,15 @@
66

77

88
/**
9-
* Created by ben on 18/04/17.
9+
* Runs the four selection methods against a given
10+
* dataset. This is a useful class for running all the
11+
* code at once and checking the output.
12+
*
13+
* The only check this actually performs is that
14+
* the size of the subsets returned from the numFeatures
15+
* methods is less than or equal to the specified size.
1016
*/
17+
1118
public class TestAll {
1219

1320
// Since weka treats attributes and classes uniformly, must explicitly state class index
@@ -129,18 +136,29 @@ public void testSequentialFloatingBackwardSelectionNumFeatures() throws Exceptio
129136
*/
130137

131138
private FeatureSelection generateSelector(Selection method) throws Exception {
139+
FeatureSelection selector = null;
132140
switch (method){
133141
case SBS:
134-
return new SequentialBackwardSelection(FILE_NAME, CLASS_INDEX);
142+
selector = new SequentialBackwardSelection(FILE_NAME, CLASS_INDEX);
143+
break;
135144
case SFS:
136-
return new SequentialForwardSelection(FILE_NAME, CLASS_INDEX);
145+
selector = new SequentialForwardSelection(FILE_NAME, CLASS_INDEX);
146+
break;
137147
case SFBS:
138-
return new SequentialFloatingBackwardSelection(FILE_NAME, CLASS_INDEX);
148+
selector = new SequentialFloatingBackwardSelection(FILE_NAME, CLASS_INDEX);
149+
break;
139150
case SFFS:
140-
return new SequentialFloatingForwardSelection(FILE_NAME, CLASS_INDEX);
141-
default:
142-
return null;
151+
selector = new SequentialFloatingForwardSelection(FILE_NAME, CLASS_INDEX);
152+
break;
153+
}
154+
155+
// Special case for musk
156+
if(FILE_NAME.equals("musk.arff")){
157+
// There is a "giveaway" feature (molecule_name) which stores some class information
158+
selector.removeAttribute(0);
143159
}
160+
161+
return selector;
144162
}
145163

146164
private enum Selection {

Diff for: selection/Classifier.java

+19-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package selection;
22

3-
import weka.classifiers.*;
3+
import weka.classifiers.Evaluation;
44
import weka.classifiers.lazy.IBk;
55
import weka.classifiers.meta.FilteredClassifier;
66
import weka.core.Instances;
@@ -17,14 +17,14 @@
1717
import java.util.stream.IntStream;
1818

1919
/**
20-
* Reads isntances from file and splits them into
21-
*
22-
* - Training: Used to train the model (classifier)
23-
* - Validation: Used to check performance throughout, avoid overfitting to training
24-
* - Testing: Used only at the end to evaluate learnt performance
25-
*
26-
* This uses weka: http://www.cs.waikato.ac.nz/ml/weka/
27-
* both for the classifier and instances.
20+
* Reads instances from file and splits them into
21+
* <p>
22+
* - Training: Used to train the model (classifier)
23+
* - Validation: Used to check performance throughout, avoid overfitting to training
24+
* - Testing: Used only at the end to evaluate learnt performance
25+
* <p>
26+
* This uses weka: http://www.cs.waikato.ac.nz/ml/weka/
27+
* both for the classifier and instances.
2828
*/
2929
public class Classifier {
3030

@@ -40,7 +40,7 @@ public class Classifier {
4040
private int CLASS_INDEX;
4141

4242
public Classifier(String fileName) throws Exception {
43-
Instances instances = readArffFile(fileName);
43+
Instances instances = readArffFile(fileName);
4444
instances.randomize(new java.util.Random(123));
4545

4646
int trainSize = (int) Math.round(instances.numInstances() * 0.6);
@@ -53,7 +53,7 @@ public Classifier(String fileName) throws Exception {
5353
}
5454

5555
public Classifier(String trainingFileName, String testingFileName) throws Exception {
56-
Instances instances = readArffFile(trainingFileName);
56+
Instances instances = readArffFile(trainingFileName);
5757
instances.randomize(new java.util.Random(0));
5858

5959
int trainSize = (int) Math.round(instances.numInstances() * 0.8);
@@ -79,7 +79,7 @@ private Instances removeAttribute(int index, Instances instances) throws Excepti
7979
return Filter.useFilter(instances, remove);
8080
}
8181

82-
public void setClassIndex(int index){
82+
public void setClassIndex(int index) {
8383
this.CLASS_INDEX = index;
8484

8585
training.setClassIndex(CLASS_INDEX);
@@ -89,6 +89,7 @@ public void setClassIndex(int index){
8989

9090
/**
9191
* The method of classification to use
92+
*
9293
* @return
9394
* @throws Exception
9495
*/
@@ -185,22 +186,22 @@ private double evaluateOnTesting(weka.classifiers.Classifier classifier) throws
185186
* @param toKeep
186187
* @return
187188
*/
188-
private int[] remove(Set<Integer> toKeep){
189+
private int[] remove(Set<Integer> toKeep) {
189190

190191
List<Integer> toRemove = new ArrayList<Integer>();
191192

192-
for(int i=0; i<training.numAttributes(); i++){
193-
if (!toKeep.contains(i) && i != CLASS_INDEX){
193+
for (int i = 0; i < training.numAttributes(); i++) {
194+
if (!toKeep.contains(i) && i != CLASS_INDEX) {
194195
toRemove.add(i);
195196
}
196197
}
197198

198199
// Convert list to int[]
199-
return toRemove.stream().mapToInt(i->i).toArray();
200+
return toRemove.stream().mapToInt(i -> i).toArray();
200201
}
201202

202203

203-
private Instances readArffFile (String fileName) throws IOException{
204+
private Instances readArffFile(String fileName) throws IOException {
204205
BufferedReader reader = new BufferedReader(
205206
new FileReader("src/res/" + fileName));
206207

@@ -209,7 +210,7 @@ private Instances readArffFile (String fileName) throws IOException{
209210
return instances;
210211
}
211212

212-
public int getNumFeatures(){
213+
public int getNumFeatures() {
213214
return training.numAttributes();
214215
}
215216

Diff for: selection/FeatureSelection.java

+21-13
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,21 @@ public abstract class FeatureSelection {
1414

1515
// The number of iterations to try if no improvement is made
1616
protected final int MAX_ITERATIONS_WITHOUT_PROGRESS = 10;
17-
17+
private final boolean DEBUG = true;
1818
// The wrapped classifier to use
1919
private Classifier classifier;
2020

21-
private final boolean DEBUG = true;
22-
23-
public FeatureSelection (String fileName, int classIndex) throws Exception {
21+
public FeatureSelection(String fileName, int classIndex) throws Exception {
2422
this.classifier = new Classifier(fileName);
25-
//TODO: Only needed for musk
26-
this.classifier.removeAttribute(0);
2723
this.classifier.setClassIndex(classIndex);
2824
}
2925

30-
public FeatureSelection (String trainingFile, String testingFile, int classIndex) throws Exception {
26+
public FeatureSelection(String trainingFile, String testingFile, int classIndex) throws Exception {
3127
this.classifier = new Classifier(trainingFile, testingFile);
3228
this.classifier.setClassIndex(classIndex);
3329
}
3430

31+
3532
/**
3633
* Returns a subset of only the most important features,
3734
* the parameter specifies the maximum number of features to select (m).
@@ -117,6 +114,17 @@ protected double objectiveFunction(Set<Integer> selectedFeatures) throws Excepti
117114
return classifier.classify(selectedFeatures);
118115
}
119116

117+
/**
118+
* Removes the specified attribute, this is useful if the dataset
119+
* has extra "information" variables that give away the class.
120+
*
121+
* @param index
122+
* @throws Exception
123+
*/
124+
public void removeAttribute(int index) throws Exception {
125+
this.classifier.removeAttribute(0);
126+
}
127+
120128

121129
/***
122130
* Uses the testing instances to check the performance
@@ -136,8 +144,8 @@ public void compareTestingAccuracy(Set<Integer> selectedIndices) throws Exceptio
136144
* @param size
137145
* @param accuracy
138146
*/
139-
protected void printAccuracy (int size, double accuracy){
140-
if(DEBUG) System.out.println(size + ": " + accuracy);
147+
protected void printAccuracy(int size, double accuracy) {
148+
if (DEBUG) System.out.println(size + ": " + accuracy);
141149
}
142150

143151
/***
@@ -161,19 +169,19 @@ protected Set<Integer> getAllFeatureIndices() {
161169
* ===============
162170
*/
163171

164-
protected boolean greaterThan(double d1, double d2){
172+
protected boolean greaterThan(double d1, double d2) {
165173
return Double.compare(d1, d2) > 0;
166174
}
167175

168-
protected boolean lessThan(double d1, double d2){
176+
protected boolean lessThan(double d1, double d2) {
169177
return Double.compare(d1, d2) < 0;
170178
}
171179

172-
protected boolean lessThanOrEqualTo(double d1, double d2){
180+
protected boolean lessThanOrEqualTo(double d1, double d2) {
173181
return Double.compare(d1, d2) <= 0;
174182
}
175183

176-
protected boolean equalTo(double d1, double d2){
184+
protected boolean equalTo(double d1, double d2) {
177185
return Double.compare(d1, d2) == 0;
178186
}
179187

Diff for: selection/SequentialBackwardSelection.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
/**
77
* Performs Sequential Backward Selection (SBS)
8-
*
8+
* <p>
99
* - Starts with full set of features
1010
* - Repeatedly removes the "worst" feature until
1111
* stopping criteria is met,
@@ -15,6 +15,7 @@ public class SequentialBackwardSelection extends FeatureSelection {
1515
public SequentialBackwardSelection(String file, int classIndex) throws Exception {
1616
super(file, classIndex);
1717
}
18+
1819
public SequentialBackwardSelection(String training, String testing, int classIndex) throws Exception {
1920
super(training, testing, classIndex);
2021
}
@@ -92,6 +93,4 @@ private Set<Integer> select(Criteria criteria, int maxNumFeatures) throws Except
9293
}
9394

9495

95-
96-
9796
}

Diff for: selection/SequentialFloatingBackwardSelection.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
/**
77
* Performs Sequential Floating Backward Selection (SFBS)
8-
*
8+
* <p>
99
* - Starts with full set of features
1010
* - Removes the "worst" feature
1111
* - Performs SFS as long as the objective function increases
@@ -16,6 +16,7 @@ public class SequentialFloatingBackwardSelection extends FeatureSelection {
1616
public SequentialFloatingBackwardSelection(String file, int classIndex) throws Exception {
1717
super(file, classIndex);
1818
}
19+
1920
public SequentialFloatingBackwardSelection(String training, String testing, int classIndex) throws Exception {
2021
super(training, testing, classIndex);
2122
}
@@ -109,7 +110,7 @@ private Set<Integer> select(Criteria criteria, int maxNumFeatures) throws Except
109110

110111
// If the accuracy is higher than our previous best, or the same with fewer features, and it's a valid size (<= maxFeatures)
111112
if ((greaterThan(accuracy, highestAccuracy) || (equalTo(accuracy, highestAccuracy) && selectedFeatures.size() < bestSoFar.size()))
112-
&& selectedFeatures.size() <= maxNumFeatures) {
113+
&& selectedFeatures.size() <= maxNumFeatures) {
113114
highestAccuracy = accuracy;
114115
// Save our best set
115116
bestSoFar = new HashSet<>(selectedFeatures);

Diff for: selection/SequentialFloatingForwardSelection.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
/**
77
* Performs Sequential Floating Forward Selection (SFFS)
8-
*
8+
* <p>
99
* - Starts with empty set of features
1010
* - Adds the "best" feature
1111
* - Performs SBS as long as the objective function increases
@@ -16,6 +16,7 @@ public class SequentialFloatingForwardSelection extends FeatureSelection {
1616
public SequentialFloatingForwardSelection(String file, int classIndex) throws Exception {
1717
super(file, classIndex);
1818
}
19+
1920
public SequentialFloatingForwardSelection(String training, String testing, int classIndex) throws Exception {
2021
super(training, testing, classIndex);
2122
}

Diff for: selection/SequentialForwardSelection.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
/**
77
* Performs Sequential Forward Selection (SFS)
8-
*
8+
* <p>
99
* - Starts with empty set of features
1010
* - Adds the "best" feature until stopping criteria is met
1111
*/
@@ -14,6 +14,7 @@ public class SequentialForwardSelection extends FeatureSelection {
1414
public SequentialForwardSelection(String file, int classIndex) throws Exception {
1515
super(file, classIndex);
1616
}
17+
1718
public SequentialForwardSelection(String training, String testing, int classIndex) throws Exception {
1819
super(training, testing, classIndex);
1920
}

0 commit comments

Comments
 (0)