diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 41c77c1ead045fd79eb43d7d8bd7e4472ebf58c1..6a4fadd130436351181dd91097b12f744cc82b00 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -64,7 +64,7 @@ void ComputeInterceptor::IncreaseReady(int64_t up_id) { // source node has no upstream, data_is_ready is send by carrier or others if (is_source_ && up_id == -1) { - it->second.second = GetTaskNode()->max_run_times(); + it->second.second += GetTaskNode()->max_run_times(); return; } @@ -121,16 +121,6 @@ bool ComputeInterceptor::CanWriteOutput() { return true; } -// only source node need reset -bool ComputeInterceptor::ShouldReset() { - if (is_source_ && step_ == node_->max_run_times()) { - VLOG(3) << "Interceptor " << GetInterceptorId() - << " should reset for step: " << step_ << "."; - return true; - } - return false; -} - void ComputeInterceptor::SendDataReadyToDownStream() { for (auto& outs : out_buffs_) { auto down_id = outs.first; @@ -186,24 +176,7 @@ void ComputeInterceptor::RunOps() { } void ComputeInterceptor::Run() { - // If there is no limit, source interceptor can be executed - // an unlimited number of times. - // Now source node can only run max_run_times. - if (ShouldReset()) { - for (auto& out_buff : out_buffs_) { - // buffer is using - if (out_buff.second.second != 0) { - VLOG(3) << "Interceptor " << GetInterceptorId() - << " out buffer for downstream: " << out_buff.first - << "'s counter is: " << out_buff.second.second - << ". Cannot be reset."; - return; - } - } - step_ = 0; // reset - } - - while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { + while (IsInputReady() && CanWriteOutput()) { VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; RunOps(); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index ae253f844aab4ea1b30f4e05857c04635a3a9e7b..fb82ce76c7bdb851c32b1959121059cfca041b94 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -39,7 +39,6 @@ class ComputeInterceptor : public Interceptor { void DecreaseBuff(int64_t down_id); bool IsInputReady(); bool CanWriteOutput(); - bool ShouldReset(); void Run(); void Compute(const InterceptorMessage& msg);