diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py index 93fd66b729eaa04aa0a4cf8b808e8cbe1e4a3e97..adc616f10af0f4d7c22d004b455a469cee96516d 100644 --- a/python_module/megengine/functional/nn.py +++ b/python_module/megengine/functional/nn.py @@ -443,7 +443,7 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: import megengine.functional as F inp = tensor(np.arange(1, 4, dtype=np.int32)) - out = F.one_hot(inp) + out = F.one_hot(inp, num_classes=4) print(out.numpy()) Outputs: