diff --git a/oneflow/core/autograd/gradient_funcs/reshape.cpp b/oneflow/core/autograd/gradient_funcs/reshape.cpp index 313c3dd1ec7a96b9b4291cccdd5eb56e35106955..d504908e8d53bc2c4c578c726bab097c6cdf5f51 100644 --- a/oneflow/core/autograd/gradient_funcs/reshape.cpp +++ b/oneflow/core/autograd/gradient_funcs/reshape.cpp @@ -24,7 +24,11 @@ limitations under the License. namespace oneflow { namespace one { -class ReshapeOpExprGrad : public OpExprGradFunction { +struct ReshapeCaptureState : public AutoGradCaptureState { + DimVector input_shape_vec; +}; + +class ReshapeOpExprGrad : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override { const auto* fw_op_expr = dynamic_cast(&op); @@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction { return Maybe::Ok(); } - Maybe Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs, + Maybe Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { - ctx->SaveTensorForBackward(inputs.at(0)); + ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec(); return Maybe::Ok(); } - Maybe Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads, + Maybe Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { const auto& saved_tensors = ctx->SavedTensors(); in_grads->resize(1); - in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0))); + Shape shape(ctx->input_shape_vec); + in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape)); return Maybe::Ok(); } }; diff --git a/oneflow/core/common/just.h b/oneflow/core/common/just.h index 44f4af1b16e7a7f4ae4d4af48ae313474d407348..e26f097e0d3c37ba02a7dd1cfb120a9782a19839 100644 --- a/oneflow/core/common/just.h +++ b/oneflow/core/common/just.h @@ -90,62 +90,62 @@ typename std::remove_const::type>::type&& Remo #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) -#define JUST(...) \ - ::oneflow::private_details::RemoveRValConst(({ \ - auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ - if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ - return ::oneflow::private_details::JustErrorAddStackFrame( \ - ::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \ - __FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \ - } \ - std::forward(value_to_check_); \ +#define JUST(...) \ + ::oneflow::private_details::RemoveRValConst(({ \ + auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ + if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ + return ::oneflow::private_details::JustErrorAddStackFrame( \ + ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \ + __FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \ + } \ + std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() -#define CHECK_JUST(...) \ - ([&](const char* func_name) { \ - auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ - if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ - LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ - ::oneflow::private_details::JustErrorAddStackFrame( \ - ::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \ - func_name, OF_PP_STRINGIZE(__VA_ARGS__))); \ - } \ - return std::forward(value_to_check_); \ - })(__FUNCTION__) \ +#define CHECK_JUST(...) \ + ([&](const char* _just_closure_func_name_) { \ + auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ + if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ + LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ + ::oneflow::private_details::JustErrorAddStackFrame( \ + ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \ + _just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \ + } \ + return std::forward(_just_value_to_check_); \ + })(__FUNCTION__) \ .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() -#define JUST_MSG(value, ...) \ - ::oneflow::private_details::RemoveRValConst(({ \ - auto&& value_to_check_ = (value); \ - if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ - return ::oneflow::private_details::JustErrorAddMessage( \ - ::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \ - .AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \ - OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \ - } \ - std::forward(value_to_check_); \ +#define JUST_MSG(value, ...) \ + ::oneflow::private_details::RemoveRValConst(({ \ + auto&& _just_value_to_check_ = (value); \ + if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ + return ::oneflow::private_details::JustErrorAddMessage( \ + ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ + .AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \ + OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \ + } \ + std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() -#define CHECK_JUST_MSG(value, ...) \ - ([&](const char* func_name) { \ - auto&& value_to_check_ = (value); \ - if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \ - LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ - ::oneflow::private_details::JustErrorAddMessage( \ - ::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \ - .AddStackFrame(__FILE__, __LINE__, func_name), \ - OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \ - .error_proto()); \ - } \ - return std::forward(value_to_check_); \ - })(__FUNCTION__) \ +#define CHECK_JUST_MSG(value, ...) \ + ([&](const char* _just_closure_func_name_) { \ + auto&& _just_value_to_check_ = (value); \ + if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ + LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ + ::oneflow::private_details::JustErrorAddMessage( \ + ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ + .AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \ + OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \ + .error_proto()); \ + } \ + return std::forward(_just_value_to_check_); \ + })(__FUNCTION__) \ .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() -#define JUST_OPT(...) \ - ::oneflow::private_details::RemoveRValConst(({ \ - auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ - if (!value_to_check_.has_value()) { return NullOpt; } \ - std::forward(value_to_check_); \ +#define JUST_OPT(...) \ + ::oneflow::private_details::RemoveRValConst(({ \ + auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ + if (!_just_value_to_check_.has_value()) { return NullOpt; } \ + std::forward(_just_value_to_check_); \ })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() #else