diff --git a/knn.py b/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..37f43515871567534055f0debb19b93d3e1db08a --- /dev/null +++ b/knn.py @@ -0,0 +1,24 @@ +import numpy as np +from data_proce import one_hot + +class Knn: + def __init__(self, k=3) -> None: + self.k = k + + def fit(self, X, y): + self.X = X + self.y = one_hot(y) + + def predict_proba(self, X_test): + if np.ndim(X_test) == 1: + X_test = np.expand_dims(X_test, axis=0) + result = [] + for i in range(len(X_test)): + dist = (self.X - X_test[i]) ** 2 + dist = dist.sum(axis=1) + samples_id = np.argsort(dist)[:self.k] + result.append(self.y[samples_id].mean(axis=0)) + return np.array(result) + + def predict(self, X_test): + return self.predict_proba(X_test).argmax(axis=1) \ No newline at end of file