提交 611b8fde 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!14 bugfix for logistic regression: code in md, and some math

Merge pull request !14 from dyonghan/logistic-regression
......@@ -62,7 +62,7 @@ Iris数据集是模式识别最著名的数据集之一。数据集包含3类,
概括统计:
```
Min Max Mean SD Class Correlation
Min Max Mean SD Class Correlation
sepal length: 4.3 7.9 5.84 0.83 0.7826
sepal width: 2.0 4.4 3.05 0.43 -0.4194
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
......@@ -185,12 +185,12 @@ plt.ylabel('p')
对于每个样本$N_i$,模型的计算方式如下:
$$
Z_i = W * X_i + b \\
Z_i = W \cdot X_i + b \\
P_{i} = sigmoid(Z_{i}) = \frac{1}{1 + e^{-Z_{i}}} \\
loss = -\frac{1}n\sum_i[Y_{i} * ln(P_{i}) + (1 - Y_{i})ln(1 - P_{i})]
$$
其中,$X_i$是1D Tensor(含4个元素),$Z_i$是1D Tensor(含1个元素),$Y_i$是真实类别(2个类别{0, 1}中的一个),$P_i$是1D Tensor(含1个元素,表示属于类别1的概率,值域为[0, 1])。
其中,$X_i$是1D Tensor(含4个元素),$Z_i$是1D Tensor(含1个元素),$Y_i$是真实类别(2个类别{0, 1}中的一个),$P_i$是1D Tensor(含1个元素,表示属于类别1的概率,值域为[0, 1]),$loss$是标量
```python
......@@ -225,7 +225,7 @@ model.train(5, ds_train, callbacks=[LossMonitor(per_print_times=ds_train.get_dat
然后计算模型在测试集上精度,测试集上的精度达到了1.0左右,即逻辑回归模型学会了区分2类鸢尾花。
```python
x = model.predict(ms.Tensor(X_test))
x = model.predict(ms.Tensor(X_test)).asnumpy()
pred = np.round(1 / (1 + np.exp(-x)))
correct = np.equal(pred, Y_test)
acc = np.mean(correct)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册