From dba5d710c5c2aa6b89affc9c2a84ccef6383e1f2 Mon Sep 17 00:00:00 2001 From: wit-df Date: Mon, 3 Jul 2023 16:25:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=20linear=5Fregression.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- linear_regression.py | 32 -------------------------------- 1 file changed, 32 deletions(-) delete mode 100644 linear_regression.py diff --git a/linear_regression.py b/linear_regression.py deleted file mode 100644 index 429036d..0000000 --- a/linear_regression.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np - -# 线性回归(正规方程法) -class LinearRegression_Equation: - # 根据公式计算w,是n*1的2维矩阵 - def fit(self, X, y): - self.w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y) - - # 根据公式(4.1)计算预测结果,是m*1的2维矩阵 - def predict(self, X): - return X.dot(self.w) - -# 线性回归(梯度下降法) -class LinearRegression_GD: - def __init__(self, - N = 1000, # 训练轮数 - lr = 0.01 # 学习率 - ) -> None: - self.N = N - self.lr = lr - - def fit(self, X, y): - m, n = X.shape - w = np.zeros((n, 1)) - y = y.reshape(-1, 1) - for _ in range(self.N): - gradient = 1 / m * X.T.dot(X.dot(w) - y) - w -= self.lr * gradient - self.w = w - - def predict(self, X): - return X.dot(self.w) \ No newline at end of file -- GitLab