Skip to content

Commit 78b7dbf

Browse files
committed
Update
1 parent f96d8ab commit 78b7dbf

File tree

5 files changed

+39
-16
lines changed

5 files changed

+39
-16
lines changed

examples/classification/IMDbSetsTextCategorizationDemo.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.feature_selection import chi2
77
from sklearn.feature_extraction.text import CountVectorizer
88
from tmu.models.classification.vanilla_classifier import TMClassifier
9+
from scipy.sparse import csr_matrix
910

1011
from tmu.tools import BenchmarkTimer
1112

@@ -63,26 +64,33 @@ def main(args):
6364
tokenizer=lambda s: s,
6465
token_pattern=None,
6566
ngram_range=(1, args.max_ngram),
67+
max_features=100000,
6668
lowercase=False,
6769
binary=True
6870
)
6971

70-
X_train = vectorizer_X.fit_transform(training_documents)
72+
X_train = vectorizer_X.fit_transform(training_documents).astype(np.uint32)
7173
Y_train = train_y.astype(np.uint32)
7274

73-
X_test = vectorizer_X.transform(testing_documents)
75+
X_test = vectorizer_X.transform(testing_documents).astype(np.uint32)
7476
Y_test = test_y.astype(np.uint32)
7577
_LOGGER.info("Producing bit representation... Done!")
7678

7779
_LOGGER.info("Selecting Features....")
7880

79-
SKB = SelectKBest(chi2, k=args.features)
80-
SKB.fit(X_train, Y_train)
81+
#SKB = SelectKBest(chi2, k=args.features)
82+
#SKB.fit(X_train, Y_train)
8183

82-
selected_features = SKB.get_support(indices=True)
83-
X_train = SKB.transform(X_train).astype(np.uint32)
84-
X_test = SKB.transform(X_test).astype(np.uint32)
84+
selected_features = np.arange(args.features)
85+
#selected_features = SKB.get_support(indices=True)
86+
#X_train = SKB.transform(X_train).astype(np.uint32)
87+
#X_test = SKB.transform(X_test).astype(np.uint32)
8588

89+
documents = [["movie", "all"], ["very", "good"], ["love", "the", "book"]]
90+
print(documents)
91+
concepts = vectorizer_X.transform(documents)
92+
print(concepts)
93+
8694
_LOGGER.info("Selecting Features.... Done!")
8795

8896
tm = TMClassifier(
@@ -91,7 +99,8 @@ def main(args):
9199
args.s,
92100
platform=args.platform,
93101
weighted_clauses=args.weighted_clauses,
94-
clause_drop_p=args.clause_drop_p
102+
clause_drop_p=args.clause_drop_p,
103+
sets=concepts#csr_matrix([[1,8],[0,1],[15,128]])
95104
)
96105

97106
for e in range(args.epochs):

tmu/clause_bank/clause_bank_sets.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
class ClauseBankSets(BaseClauseBank):
3333
def __init__(
3434
self,
35+
sets,
3536
seed: int,
3637
number_of_states,
3738
d: float,
@@ -50,6 +51,10 @@ def __init__(
5051
self.batching = batching
5152
self.incremental = incremental
5253

54+
self.sets = sets
55+
self.number_of_sets = self.sets.shape[0]
56+
print(self.sets, self.number_of_sets)
57+
5358
self.d = d
5459

5560
LOGGER.warning("reuse_random_feedback is not implemented yet")
@@ -136,6 +141,7 @@ def calculate_clause_outputs_predict(self, encoded_X, e):
136141
return self.clause_output
137142

138143
if e % 32 == 0:
144+
139145
lib.cbse_pack_X(
140146
ffi.cast("int *", encoded_X[0].indptr.ctypes.data),
141147
ffi.cast("int *", encoded_X[0].indices.ctypes.data),
@@ -144,16 +150,16 @@ def calculate_clause_outputs_predict(self, encoded_X, e):
144150
self.ptr_packed_X,
145151
self.number_of_literals
146152
)
153+
147154
lib.cbse_calculate_clause_outputs_predict_packed_X(
148155
self.ptr_packed_X,
149156
self.number_of_clauses,
150157
self.number_of_literals,
151158
self.ptr_clause_output_batch,
152159
self.ptr_clause_bank_included,
153-
self.ptr_clause_bank_included_length,
154-
# self.cbia_p,
155-
# self.cbial_p
160+
self.ptr_clause_bank_included_length
156161
)
162+
157163
lib.cbse_unpack_clause_output(
158164
e,
159165
self.ptr_clause_output,
@@ -241,7 +247,7 @@ def type_ii_feedback(
241247
self.ptr_Xi,
242248
self.number_of_features
243249
)
244-
250+
245251
lib.cbse_type_ii_feedback(
246252
update_p,
247253
ffi.cast("int *", clause_active.ctypes.data),

tmu/lib/src/ClauseBankSets.c

+6-3
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ void cbse_type_i_feedback(
212212
continue;
213213
}
214214

215+
// Calculate intersection of input and included sets...
215216
int clause_pos_base = j*number_of_literals*2;
216-
217217
int clause_output = 1;
218218
for (int k = 0; k < clause_bank_included_length[j]; ++k) {
219219
unsigned int clause_pos = clause_pos_base + k*2;
@@ -225,6 +225,10 @@ void cbse_type_i_feedback(
225225
}
226226
}
227227

228+
// Clause output is 1 if pop count is > 0
229+
230+
// Calculate intersection with each set, which becomes the truth values for updating...
231+
228232
if (clause_output && (clause_bank_included_length[j] <= max_included_literals)) {
229233
// Update state of included literals
230234
for (int k = 0; k < clause_bank_included_length[j]; ++k) {
@@ -344,8 +348,7 @@ void cbse_type_i_feedback(
344348
int clause_included_end_pos = clause_pos_base + clause_bank_included_length[j]*2;
345349
clause_bank_included[clause_included_pos] = clause_bank_included[clause_included_end_pos];
346350
clause_bank_included[clause_included_pos + 1] = clause_bank_included[clause_included_end_pos + 1];
347-
}
348-
351+
}
349352
}
350353
}
351354
}

tmu/models/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
absorbing_include=None,
9999
absorbing_exclude=None,
100100
squared_weight_update_p=False,
101+
sets=None,
101102
seed=None
102103
):
103104
self.seed = seed
@@ -107,7 +108,8 @@ def __init__(
107108
self.number_of_state_bits_ind = number_of_state_bits_ind
108109
self.T = int(T)
109110
self.s = s
110-
111+
self.sets = sets
112+
111113
self.confidence_driven_updating = confidence_driven_updating
112114
self.type_i_ii_ratio = type_i_ii_ratio
113115
if type_i_ii_ratio >= 1.0:
@@ -288,6 +290,7 @@ def _build_cpu_sets_bank(self, X: np.ndarray):
288290
from tmu.clause_bank.clause_bank_sets import ClauseBankSets
289291
clause_bank_type = ClauseBankSets
290292
clause_bank_args = dict(
293+
sets=self.sets,
291294
X_shape=X.shape,
292295
d=self.d,
293296
s=self.s,

tmu/models/classification/vanilla_classifier.py

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
literal_sampling=1.0,
6161
feedback_rate_excluded_literals=1,
6262
literal_insertion_state=-1,
63+
sets=None,
6364
seed=None
6465
):
6566
super().__init__(
@@ -92,6 +93,7 @@ def __init__(
9293
literal_sampling=literal_sampling,
9394
feedback_rate_excluded_literals=feedback_rate_excluded_literals,
9495
literal_insertion_state=literal_insertion_state,
96+
sets=sets,
9597
seed=seed
9698
)
9799
MultiClauseBankMixin.__init__(self, seed=seed)

0 commit comments

Comments
 (0)