Skip to content

Commit 7b491a6

Browse files
Ben EvansBen Evans
Ben Evans
authored and
Ben Evans
committed
Changed stopping criteria to be no improvement over X iterations, instead of reaching some %
1 parent 7c7fe91 commit 7b491a6

5 files changed

+80
-28
lines changed

Diff for: Classifier.java

+11-8
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99

1010
/**
11-
* K-NN used as a classifier.
11+
* Classifier to use as the evaluation criteria
12+
* for feature selection. K-Nearest neigbhour is
13+
* implemented here.
1214
*/
1315
public class Classifier {
1416

@@ -35,24 +37,25 @@ public Classifier(Set<Instance> training, Set<Instance> testing){
3537
}
3638

3739
/**
38-
* Classifies and calculates the percentage
39-
* of correct classifications in the testingSet
40-
* against the training set.
40+
* Classifies and returns the percentage
41+
* of correct classifications using every feature
42+
* in the instances.
43+
*
4144
*/
4245
public double classify(){
4346
Instance sampleInstance = training.iterator().next();
4447
int totalFeatures = sampleInstance.getNumFeatures();
4548

46-
// To begin with all features are selected
49+
// We are using all features
4750
Set<Integer> allIndices = IntStream.rangeClosed(0, totalFeatures - 1)
4851
.boxed().collect(Collectors.toSet());
4952

5053
return classify(allIndices);
5154
}
5255
/**
53-
* Classifies and calculates the percentage
54-
* of correct classifications in the testingSet
55-
* against the training set.
56+
* Classifies and returns the percentage
57+
* of correct classifications using only the specified indices
58+
* for the instances.
5659
*/
5760
public double classify(Set<Integer> indices) {
5861
int correct = 0;

Diff for: FeatureSelection.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ public abstract class FeatureSelection {
1111
private Classifier classifier;
1212
protected Set<Instance> instances;
1313

14+
// The number of times to try extra subsets if no improvement is made
15+
protected final int MAX_ITERATIONS_WITHOUT_PROGRESS = 5;
16+
1417
public FeatureSelection(Set<Instance> instances){
1518
this.instances = instances;
1619
this.classifier = new Classifier(instances);
@@ -29,7 +32,7 @@ If a number of features is found (n), where n < m, with a higher accuracy,
2932
features. If numFeatures is >= original.size(), the original
3033
set is returned.
3134
*/
32-
public abstract Set<Integer> select(double minimumAccuracy);
35+
public abstract Set<Integer> select();
3336

3437
/**
3538
* Returns the feature in remaining features
@@ -86,7 +89,7 @@ protected double objectiveFunction(Set<Integer> selectedFeatures) {
8689
return classifier.classify(selectedFeatures);
8790
}
8891

89-
protected Set<Integer> getFeatures(){
92+
protected Set<Integer> getAllFeatureIndices(){
9093
// Extract an instance to check the amount of features, assumes all instances have same # of features
9194
Instance sampleInstance = instances.iterator().next();
9295
int totalFeatures = sampleInstance.getNumFeatures();
@@ -96,7 +99,7 @@ protected Set<Integer> getFeatures(){
9699
.boxed().collect(Collectors.toSet());
97100
}
98101

99-
public void compareAccuracy(Set<Integer> selectedIndices) {
102+
public void compareTestingAccuracy(Set<Integer> selectedIndices) {
100103
System.out.println("Classification accuracy on testing set using all features: " + classifier.classify());
101104
System.out.println("Classification accuracy on testing set using features " + selectedIndices + ": " + classifier.classify(selectedIndices));
102105
}

Diff for: SequentialBackwardsSelection.java

+24-6
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,37 @@ public SequentialBackwardsSelection(Set<Instance> instances){
1212

1313
@Override
1414
public Set<Integer> select(int maxNumFeatures) {
15-
return select((accuracy, size) -> size > maxNumFeatures);
15+
// While we have too many features or the accuracy is still improving
16+
return select((noImprovement, size) -> size > maxNumFeatures || noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS, maxNumFeatures);
1617
}
1718

1819
@Override
19-
public Set<Integer> select(double minimumAccuracy) {
20-
return select((accuracy, size) -> accuracy < minimumAccuracy);
20+
public Set<Integer> select() {
21+
// While the accuracy is still improving
22+
return select((noImprovement, size) -> noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS);
2123
}
2224

2325
public Set<Integer> select(Criteria criteria) {
26+
return select(criteria, instances.size());
27+
}
28+
29+
public Set<Integer> select(Criteria criteria, int maxNumFeatures) {
2430
// In this case we have no data to use, so return the empty set
2531
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
2632

2733
// To begin with all features are selected
28-
Set<Integer> selectedFeatures = getFeatures();
34+
Set<Integer> selectedFeatures = getAllFeatureIndices();
2935

3036
// Keep track of the best solution, so we never get worse
3137
double highestAccuracy = 0;
3238
Set<Integer> bestSoFar = new HashSet<>();
3339
double accuracy = objectiveFunction(selectedFeatures);
40+
double lastAccuracy = accuracy;
41+
42+
// Number of iterations with no improvement
43+
double noImprovement = 0;
3444

35-
while (criteria.evaluate(accuracy, selectedFeatures.size())){
45+
while (criteria.evaluate(noImprovement, selectedFeatures.size())){
3646
int feature = worst(selectedFeatures);
3747

3848
// No more valid features
@@ -43,11 +53,19 @@ public Set<Integer> select(Criteria criteria) {
4353

4454
accuracy = objectiveFunction(selectedFeatures);
4555

46-
if (accuracy > highestAccuracy) {
56+
// If this is the highest so far, and also valid (i.e < number of features required)
57+
if (accuracy > highestAccuracy && selectedFeatures.size() <= maxNumFeatures) {
4758
highestAccuracy = accuracy;
4859
// Make a copy, so we don't accidentally modify this
4960
bestSoFar = new HashSet<>(selectedFeatures);
5061
}
62+
63+
if (Double.compare(accuracy, lastAccuracy) <= 0){
64+
noImprovement++;
65+
} else{
66+
noImprovement = 0;
67+
}
68+
lastAccuracy = accuracy;
5169
}
5270

5371
return bestSoFar;

Diff for: SequentialFloatingForwardSelection.java

+24-7
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ public Set<Integer> select(int maxNumFeatures) {
1414
return select((accuracy, size) -> size < maxNumFeatures);
1515
}
1616

17-
public Set<Integer> select(double minimumAccuracy) {
18-
return select((accuracy, size) -> accuracy < minimumAccuracy);
17+
public Set<Integer> select() {
18+
return select((noImprovement, size) -> noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS);
1919
}
2020

2121
public Set<Integer> select(Criteria criteria) {
2222
// In this case we have no data to use, so return the empty set
2323
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
2424

2525
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
26-
Set<Integer> remainingFeatures = getFeatures();
26+
Set<Integer> remainingFeatures = getAllFeatureIndices();
2727

2828
// Subset of only selected features indices
2929
Set<Integer> selectedFeatures = new HashSet<>();
@@ -34,24 +34,34 @@ public Set<Integer> select(Criteria criteria) {
3434
double highestAccuracy = 0;
3535
Set<Integer> bestSoFar = new HashSet<>();
3636
double accuracy = objectiveFunction(selectedFeatures);
37+
double lastAccuracy = accuracy;
38+
39+
// Number of iterations with no improvement
40+
double noImprovement = 0;
3741

3842
while (criteria.evaluate(accuracy, selectedFeatures.size())){
3943
int feature = best(selectedFeatures, remainingFeatures);
44+
45+
System.out.println("Selected features are:" + selectedFeatures);
46+
System.out.println("Remaining features are:" + remainingFeatures);
47+
System.out.println("Adding feature: " + feature);
48+
4049
// No more valid features
4150
if (feature == -1) break;
4251

4352
selectedFeatures.add(feature);
4453
// Remove the feature so we do not keep selecting the same one
4554
remainingFeatures.remove(feature);
4655

47-
double lastAccuracy = objectiveFunction(selectedFeatures);
56+
double accuracyBeforeRemoval = objectiveFunction(selectedFeatures);
57+
4858

4959
// Now remove the worst features, while we are improving
5060
while(true){
5161
int worstFeature = worst(selectedFeatures);
5262

5363
// No more valid features
54-
if (feature == -1) break;
64+
if (worstFeature == -1) break;
5565

5666
selectedFeatures.remove(worstFeature);
5767
// Feature becomes available again
@@ -62,14 +72,14 @@ public Set<Integer> select(Criteria criteria) {
6272
double newAccuracy = objectiveFunction(selectedFeatures);
6373

6474
// If the accuracy did not improve, undo this step and continue adding features
65-
if (newAccuracy < lastAccuracy) {
75+
if (newAccuracy < accuracyBeforeRemoval) {
6676
selectedFeatures.add(worstFeature);
6777
remainingFeatures.remove(worstFeature);
6878
System.out.println("Accuracy did not improve, so undoing above step");
6979
break;
7080
}
7181

72-
lastAccuracy = newAccuracy;
82+
accuracyBeforeRemoval = newAccuracy;
7383
}
7484

7585
accuracy = objectiveFunction(selectedFeatures);
@@ -80,6 +90,13 @@ public Set<Integer> select(Criteria criteria) {
8090
bestSoFar = new HashSet<>(selectedFeatures);
8191
}
8292

93+
if (Double.compare(accuracy, lastAccuracy) <= 0){
94+
noImprovement++;
95+
} else{
96+
noImprovement = 0;
97+
}
98+
lastAccuracy = accuracy;
99+
83100
}
84101

85102
return bestSoFar;

Diff for: SequentialForwardSelection.java

+15-4
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ public Set<Integer> select(int maxNumFeatures) {
1414
return select((accuracy, size) -> size < maxNumFeatures);
1515
}
1616

17-
public Set<Integer> select(double minimumAccuracy) {
18-
return select((accuracy, size) -> accuracy < minimumAccuracy);
17+
public Set<Integer> select() {
18+
return select((noImprovement, size) -> noImprovement < MAX_ITERATIONS_WITHOUT_PROGRESS);
1919
}
2020

2121
public Set<Integer> select(Criteria criteria) {
2222
// In this case we have no data to use, so return the empty set
2323
if (instances == null || instances.isEmpty()) return new HashSet<Integer>();
2424

2525
// To begin with no features are selected, so all the indices from 0..totalFeatures are remaining
26-
Set<Integer> remainingFeatures = getFeatures();
26+
Set<Integer> remainingFeatures = getAllFeatureIndices();
2727

2828
// Subset of only selected features indices
2929
Set<Integer> selectedFeatures = new HashSet<>();
@@ -34,8 +34,12 @@ public Set<Integer> select(Criteria criteria) {
3434
double highestAccuracy = 0;
3535
Set<Integer> bestSoFar = new HashSet<>();
3636
double accuracy = objectiveFunction(selectedFeatures);
37+
double lastAccuracy = accuracy;
3738

38-
while (criteria.evaluate(accuracy, selectedFeatures.size())){
39+
// Number of iterations with no improvement
40+
double noImprovement = 0;
41+
42+
while (criteria.evaluate(noImprovement, selectedFeatures.size())){
3943
int feature = best(selectedFeatures, remainingFeatures);
4044
// No more valid features
4145
if (feature == -1) break;
@@ -51,6 +55,13 @@ public Set<Integer> select(Criteria criteria) {
5155
// Make a copy, so we don't accidentally modify this
5256
bestSoFar = new HashSet<>(selectedFeatures);
5357
}
58+
59+
if (Double.compare(accuracy, lastAccuracy) <= 0){
60+
noImprovement++;
61+
} else{
62+
noImprovement = 0;
63+
}
64+
lastAccuracy = accuracy;
5465
}
5566

5667
return bestSoFar;

0 commit comments

Comments
 (0)