Skip to content

Commit da90aa1

Browse files
Minimal KNN
1 parent 193032c commit da90aa1

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

mlfromscratch/supervised_learning/k_nearest_neighbors.py

+9-17
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,21 @@ class KNN():
1414
def __init__(self, k=5):
1515
self.k = k
1616

17-
def _vote(self, neighbors):
17+
def _vote(self, neighbor_labels):
1818
""" Return the most common class among the neighbor samples """
19-
counts = np.bincount(neighbors[:, 1].astype('int'))
19+
counts = np.bincount(neighbor_labels.astype('int'))
2020
return counts.argmax()
2121

2222
def predict(self, X_test, X_train, y_train):
2323
y_pred = np.empty(X_test.shape[0])
2424
# Determine the class of each sample
2525
for i, test_sample in enumerate(X_test):
26-
# Two columns [distance, label], for each observed sample
27-
neighbors = np.empty((X_train.shape[0], 2))
28-
# Calculate the distance from each observed sample to the
29-
# sample we wish to predict
30-
for j, observed_sample in enumerate(X_train):
31-
distance = euclidean_distance(test_sample, observed_sample)
32-
label = y_train[j]
33-
# Add neighbor information
34-
neighbors[j] = [distance, label]
35-
# Sort the list of observed samples from lowest to highest distance
36-
# and select the k first
37-
k_nearest_neighbors = neighbors[neighbors[:, 0].argsort()][:self.k]
38-
# Get the most common class among the neighbors
39-
label = self._vote(k_nearest_neighbors)
40-
y_pred[i] = label
26+
# Sort the training samples by their distance to the test sample and get the K nearest
27+
idx = np.argsort([euclidean_distance(test_sample, x) for x in X_train])[:self.k]
28+
# Extract the labels of the K nearest neighboring training samples
29+
k_nearest_neighbors = np.array([y_train[i] for i in idx])
30+
# Label sample as the most common class label
31+
y_pred[i] = self._vote(k_nearest_neighbors)
32+
4133
return y_pred
4234

mlfromscratch/unsupervised_learning/dcgan.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,6 @@ def save_imgs(self, epoch):
168168

169169
if __name__ == '__main__':
170170
dcgan = DCGAN()
171-
dcgan.train(epochs=200000, batch_size=32, save_interval=50)
171+
dcgan.train(epochs=200000, batch_size=64, save_interval=50)
172172

173173

0 commit comments

Comments
 (0)