未验证 提交 fe31e54e 编写于 作者: O oneflow-ci-bot 提交者: GitHub

Merge branch 'master' into save_load_by_pickle

......@@ -24,7 +24,11 @@ limitations under the License.
namespace oneflow {
namespace one {
class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
......@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
return Maybe<void>::Ok();
}
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& saved_tensors = ctx->SavedTensors();
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0)));
Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok();
}
};
......
......@@ -173,6 +173,12 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
if (user_op_conf.has_input("bias_correction1", 0)) {
fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
}
if (user_op_conf.has_input("bias_correction2", 0)) {
fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
}
} else if (user_op_conf.op_type_name() == "rmsprop_update") {
const bool centered = user_op_conf.attr<bool>("centered");
fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f))
......
......@@ -156,7 +156,6 @@ import oneflow.framework.register_python_callback
INVALID_SPLIT_AXIS = oneflow._oneflow_internal.INVALID_SPLIT_AXIS
register_class_method_util.RegisterMethod4Class()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
import oneflow.framework.env_util as env_util
import oneflow.framework.scope_util as scope_util
import oneflow.framework.session_context as session_ctx
......@@ -166,6 +165,7 @@ if not env_util.HasAllMultiClientEnvVars():
env_util.SetDefaultMultiClientEnvVars()
oneflow._oneflow_internal.SetIsMultiClient(True)
env_util.api_env_init()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
session_ctx.OpenDefaultSession(
MultiClientSession(oneflow._oneflow_internal.NewSessionId())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册