提交 de26ff81 编写于 作者: wit-df's avatar wit-df

上传新文件

上级 d9e81c46
import numpy as np
from data_proce import one_hot
class Knn:
def __init__(self, k=3) -> None:
self.k = k
def fit(self, X, y):
self.X = X
self.y = one_hot(y)
def predict_proba(self, X_test):
if np.ndim(X_test) == 1:
X_test = np.expand_dims(X_test, axis=0)
result = []
for i in range(len(X_test)):
dist = (self.X - X_test[i]) ** 2
dist = dist.sum(axis=1)
samples_id = np.argsort(dist)[:self.k]
result.append(self.y[samples_id].mean(axis=0))
return np.array(result)
def predict(self, X_test):
return self.predict_proba(X_test).argmax(axis=1)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册