#include "caffe/util/math_functions.hpp"

namespace caffe {

/**
 * @brief Element-wise Nesterov momentum update, performed in place.
 *
 * For every index i visited by CUDA_KERNEL_LOOP(i, N):
 *   h_new = momentum * h[i] + local_rate * g[i]
 *   h[i]  = h_new
 *   g[i]  = (1 + momentum) * h_new - momentum * h_old
 *
 * @param N          number of elements in g and h
 * @param g          in: gradient; out: the step computed above
 * @param h          in: previous history; out: updated history
 * @param momentum   momentum coefficient
 * @param local_rate learning rate applied to the gradient
 *
 * All intermediates are kept in Dtype so that the double instantiation is
 * not silently rounded through single precision.
 */
template <typename Dtype>
__global__ void NesterovUpdate(int N, Dtype* g, Dtype* h,
    Dtype momentum, Dtype local_rate) {
  CUDA_KERNEL_LOOP(i, N) {
    // Keep the old history value: it is needed after h[i] is overwritten.
    const Dtype hi = h[i];
    const Dtype hi_new = momentum * hi + local_rate * g[i];
    h[i] = hi_new;
    g[i] = (Dtype(1) + momentum) * hi_new - momentum * hi;
  }
}

/**
 * @brief Host-side launcher for NesterovUpdate.
 *
 * Launches the kernel over N elements and then runs CUDA_POST_KERNEL_CHECK.
 * g and h are handed straight to a __global__ kernel, so they must be
 * device-accessible pointers.
 *
 * NOTE(review): the launch configuration uses Caffe's usual
 * CAFFE_GET_BLOCKS / CAFFE_CUDA_NUM_THREADS helpers; confirm that they are
 * provided by the headers this file includes.
 */
template <typename Dtype>
void nesterov_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
    Dtype local_rate) {
  NesterovUpdate<Dtype>  // NOLINT_NEXT_LINE(whitespace/operators)
      <<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
      N, g, h, momentum, local_rate);
  CUDA_POST_KERNEL_CHECK;
}

// Explicit instantiations for the two supported floating-point types.
template void nesterov_update_gpu<float>(int, float*, float*, float, float);
template void nesterov_update_gpu<double>(int, double*, double*, double,
    double);

}  // namespace caffe