Commit a49e202b authored by Megvii Engine Team

feat(functional/loss): add reduction choices to loss functions

GitOrigin-RevId: a29e6bb4cfeda8a5d56a50985f9dbc2d1f1be515
Parent 039727f8
@@ -6,8 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import numpy as np
from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, log
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
@@ -22,7 +25,26 @@ __all__ = [
]
def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
def _reduce_output(loss_fn):
r"""
Wrapper to apply canonical reductions to loss outputs.
"""
@functools.wraps(loss_fn)
def reduced_loss_fn(*args, reduction="mean", **kwargs):
loss = loss_fn(*args, **kwargs)
if reduction == "none":
return loss
elif reduction in ("mean", "sum"):
return _reduce(reduction)(loss)
else:
raise ValueError("{} is not a valid value for reduction".format(reduction))
return reduced_loss_fn
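
The wrapper above is the core of this change. A self-contained numpy analogue of the same pattern (illustrative only, not MegEngine code) shows how the `reduction` kwarg is intercepted before the raw per-element loss is reduced:

import functools
import numpy as np

def _reduce_output_np(loss_fn):
    # strip the `reduction` kwarg, call the raw loss, then reduce
    @functools.wraps(loss_fn)
    def reduced(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)
        if reduction == "none":
            return loss
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        raise ValueError("{} is not a valid value for reduction".format(reduction))
    return reduced

@_reduce_output_np
def l1(pred, label):
    return np.abs(pred - label)

assert l1(np.array([1.0, 3.0]), np.array([0.0, 1.0]), reduction="sum") == 3.0
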
@_reduce_output
def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
r"""
Calculates the mean absolute error (MAE) between
each element in the pred :math:`x` and label :math:`y`.
@@ -43,6 +65,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
:param pred: predicted result from model.
:param label: ground truth to compare.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
:return: loss value.
Examples:
@@ -66,10 +89,11 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
"""
diff = pred - label
return abs(diff).mean()
return abs(diff)
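
A minimal usage sketch of the new keyword; the `F.nn.l1_loss` path is assumed to mirror the `F.nn.cross_entropy` call used in the tests below and may differ across MegEngine versions:

import numpy as np
import megengine.functional as F
from megengine import tensor

pred = tensor(np.array([1.0, 2.0, 5.0], dtype="float32"))
label = tensor(np.array([0.0, 2.0, 3.0], dtype="float32"))

per_elem = F.nn.l1_loss(pred, label, reduction="none")  # element-wise |pred - label| -> [1., 0., 2.]
mean_loss = F.nn.l1_loss(pred, label)                   # default reduction="mean" -> 1.0
sum_loss = F.nn.l1_loss(pred, label, reduction="sum")   # -> 3.0
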
def square_loss(pred: Tensor, label: Tensor) -> Tensor:
@_reduce_output
def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
r"""
Calculates the mean squared error (squared L2 norm) between
each element in the pred :math:`x` and label :math:`y`.
@@ -90,6 +114,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
:param pred: predicted result from model.
:param label: ground truth to compare.
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
:return: loss value.
Shape:
@@ -118,15 +143,17 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
"""
diff = pred - label
return (diff ** 2).mean()
return diff ** 2
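
Likewise, `reduction="none"` now returns the per-element squared error. A quick numpy cross-check (same assumed `F.nn` export path as above):

import numpy as np
import megengine.functional as F
from megengine import tensor

x = np.random.randn(4, 3).astype("float32")
y = np.random.randn(4, 3).astype("float32")

loss_none = F.nn.square_loss(tensor(x), tensor(y), reduction="none")
np.testing.assert_allclose(loss_none.numpy(), (x - y) ** 2, rtol=1e-6)
np.testing.assert_allclose(
    F.nn.square_loss(tensor(x), tensor(y)).numpy(), ((x - y) ** 2).mean(), rtol=1e-6
)
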
@_reduce_output
def cross_entropy(
pred: Tensor,
label: Tensor,
axis: int = 1,
with_logits: bool = True,
label_smooth: float = 0,
reduction: str = "mean",
) -> Tensor:
r"""
Computes the multi-class cross entropy loss (using logits by default).
@@ -148,6 +175,7 @@ def cross_entropy(
:param axis: an axis along which softmax will be applied. Default: 1
:param with_logits: whether to apply softmax first. Default: True
:param label_smooth: a label smoothing parameter that can re-distribute the target distribution. Default: 0
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
:return: loss value.
Examples:
@@ -182,20 +210,21 @@ def cross_entropy(
ls = label_smooth
if with_logits:
logZ = logsumexp(pred, axis).mean()
primary_term = indexing_one_hot(pred, label, axis).mean()
logZ = logsumexp(pred, axis)
primary_term = indexing_one_hot(pred, label, axis)
else:
logZ = 0
primary_term = log(indexing_one_hot(pred, label, axis)).mean()
primary_term = log(indexing_one_hot(pred, label, axis))
if ls is None or type(ls) in (int, float) and ls == 0:
return logZ - primary_term
if not with_logits:
pred = log(pred)
return logZ - ls * pred.mean() - (1 - ls) * primary_term
return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term
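
For reference, a naive numpy transcription of the smoothed per-sample formula above (with_logits=True, non-zero label_smooth), i.e. what `reduction="none"` should now return:

import numpy as np

def ce_ref(logits, label, ls=0.1):
    # numerically naive logsumexp over the class axis, fine for illustration
    logZ = np.log(np.exp(logits).sum(axis=1))
    primary = logits[np.arange(len(label)), label]  # logit of the true class
    return logZ - ls * logits.mean(axis=1) - (1 - ls) * primary

logits = np.random.randn(8, 5).astype("float32")
label = np.random.randint(5, size=8)
per_sample = ce_ref(logits, label)   # shape (8,), matches reduction="none"
mean_loss = per_sample.mean()        # matches reduction="mean"
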
@_reduce_output
def binary_cross_entropy(
pred: Tensor, label: Tensor, with_logits: bool = True
pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
) -> Tensor:
r"""
Computes the binary cross entropy loss (using logits by default).
@@ -206,6 +235,7 @@ def binary_cross_entropy(
:param pred: `(N, *)`, where `*` means any number of additional dimensions.
:param label: `(N, *)`, same shape as the input.
:param with_logits: bool, whether to apply sigmoid first. Default: True
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
:return: loss value.
Examples:
@@ -229,13 +259,16 @@ def binary_cross_entropy(
"""
if not with_logits:
return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
return -(label * log(pred) + (1 - label) * log(1 - pred))
# logsigmoid(pred) and logsigmoid(-pred) have a common sub-expression;
# hopefully the backend will optimize this
return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()
return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))
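
The identity behind the logits branch can be sanity-checked in plain numpy, since 1 - sigmoid(x) = sigmoid(-x):

import numpy as np

def logsigmoid(x):
    # numerically naive version, fine for small illustrative values
    return np.log(1.0 / (1.0 + np.exp(-x)))

x = np.array([-2.0, 0.5, 3.0])
y = np.array([0.0, 1.0, 1.0])
lhs = -(y * np.log(1 / (1 + np.exp(-x))) + (1 - y) * np.log(1 - 1 / (1 + np.exp(-x))))
rhs = -(y * logsigmoid(x) + (1 - y) * logsigmoid(-x))
np.testing.assert_allclose(lhs, rhs, rtol=1e-6)
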
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
@_reduce_output
def hinge_loss(
pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
) -> Tensor:
r"""
Calculates the hinge loss, which is often used in SVM.
@@ -246,6 +279,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
:param pred: input tensor representing the predicted probability, shape is `(N, C)`.
:param label: input tensor representing the binary classification label, shape is `(N, C)`.
:param norm: specify the norm used to calculate the loss; should be "L1" or "L2".
:param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
:return: loss value.
Examples:
@@ -272,6 +306,6 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
# Converts binary labels to -1/1 labels.
loss = relu(1.0 - pred * label)
if norm == "L1":
return loss.sum(axis=1).mean()
return loss.sum(axis=1)
else:
return (loss ** 2).sum(axis=1).mean()
return (loss ** 2).sum(axis=1)
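
A plain-numpy reference of the per-sample hinge loss above (labels in {-1, +1}), matching the new `reduction="none"` output:

import numpy as np

pred = np.array([[0.5, -0.5], [0.3, 1.0]], dtype="float32")
label = np.array([[1, -1], [-1, 1]], dtype="float32")

margin = np.maximum(0.0, 1.0 - pred * label)   # relu(1 - pred * label)
l1_per_sample = margin.sum(axis=1)             # norm="L1", reduction="none"
l2_per_sample = (margin ** 2).sum(axis=1)      # norm="L2", reduction="none"
l1_mean = l1_per_sample.mean()                 # norm="L1", reduction="mean"
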
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine.functional as F
from megengine import tensor
@@ -43,3 +44,38 @@ def test_cross_entropy():
l_ref = ref(x, y)
l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
np.testing.assert_allclose(l.numpy(), l_ref)
def test_cross_entropy_reduction():
logits = np.random.randn(16, 10)
label = np.random.randint(10, size=[16])
logits = tensor(logits, dtype="float32")
label = tensor(label, dtype="int32")
perm = np.random.permutation(16)
logits_perm = tensor(logits[perm], dtype="float32")
label_perm = tensor(label[perm], dtype="int32")
loss = F.nn.cross_entropy(logits, label, reduction="none")
loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())
loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
np.testing.assert_allclose(loss.numpy().sum(), loss_sum.numpy(), rtol=2e-7)
loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / 16)
loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
loss_ls_none_reduce = F.nn.cross_entropy(
logits, label, reduction="none", label_smooth=0.1
)
np.testing.assert_allclose(
loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=2e-7
)
with pytest.raises(ValueError):
F.nn.cross_entropy(logits, label, reduction="MEAN")
with pytest.raises(ValueError):
F.nn.cross_entropy(logits, label, reduction="max")