Commit 20a6f920 authored by chengtbf

fix conflict


Former-commit-id: 2461f94bfcc57716b5a9bc1222c438ec36c1900d
......
@@ -114,11 +114,12 @@ if(WIN32)
else()
# build oneflow.run
add_custom_target(copy_raw_oneflow_run
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PROJECT_SOURCE_DIR}/scripts/oneflow.run ${PROJECT_BINARY_DIR})
add_custom_target(oneflow_run ALL
COMMAND tar zcf - compiler runtime scheduler -C ${PROJECT_BINARY_DIR} >> ${PROJECT_BINARY_DIR}/oneflow.run
DEPENDS ${main_targets} copy_raw_oneflow_run)
#add_custom_target(copy_raw_oneflow_run
# COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PROJECT_SOURCE_DIR}/scripts/oneflow.run ${PROJECT_BINARY_DIR})
#
#add_custom_target(oneflow_run ALL
# COMMAND tar zcf - compiler runtime scheduler -C ${PROJECT_BINARY_DIR} >> ${PROJECT_BINARY_DIR}/oneflow.run
# DEPENDS ${main_targets} copy_raw_oneflow_run)
endif()
# build test
......
......
@@ -141,12 +141,13 @@ void FwDataCompActor::Act() {
in_.pop();
mut_num_of_read_empty() = in_.empty();
}
if (expected_piece_id() == JobDesc::Singleton()->total_piece_num()) {
in_desc_id_ = -2;
AsyncSendMsgToModelAndModelTmpProducer();
AsyncSendEORDMsgForAllProducedRegstDesc();
TrySwitchToZombie();
}
TODO();
// if (expected_piece_id() == JobDesc::Singleton()->total_piece_num()) {
// in_desc_id_ = -2;
// AsyncSendMsgToModelAndModelTmpProducer();
// AsyncSendEORDMsgForAllProducedRegstDesc();
// TrySwitchToZombie();
//}
}
REGISTER_ACTOR(kDataCompTask, true, FwDataCompActor);
......
......
@@ -26,6 +26,10 @@ class BalancedSplitter final {
Range At(int64_t idx) const;
int64_t BaseBeginIdx() const { return change_pos_; }
int64_t BasePartSize() const { return size_per_range_; }
int64_t BiggerPartSize() const { return size_per_range_ + 1; }
private:
int64_t size_per_range_;
int64_t change_pos_;
......
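For context, a minimal standalone sketch of the balanced-split rule these accessors suggest: the real BalancedSplitter implementation is not shown in this diff, and the helper RangeAt plus its parameter names are invented for illustration (the mapping size_per_range_ == base, change_pos_ == rem is an assumption, though it matches how BaseBeginIdx() is used by the boxing code further down).

#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical illustration: split `total` items into `range_num` contiguous
// ranges whose sizes differ by at most one. With base = total / range_num and
// rem = total % range_num, ranges [0, rem) get base + 1 items (the "bigger"
// parts) and ranges [rem, range_num) get base items.
std::pair<int64_t, int64_t> RangeAt(int64_t total, int64_t range_num,
                                    int64_t idx) {
  assert(idx >= 0 && idx < range_num);
  int64_t base = total / range_num;
  int64_t rem = total % range_num;  // number of bigger parts
  int64_t begin = idx < rem ? idx * (base + 1)
                            : rem * (base + 1) + (idx - rem) * base;
  int64_t size = idx < rem ? base + 1 : base;
  return {begin, begin + size};  // half-open range [begin, end)
}
// e.g. total = 10, range_num = 4 -> ranges [0,3) [3,6) [6,8) [8,10)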
......
@@ -139,6 +139,8 @@ inline uint32_t NewRandomSeed() {
(ParallelPolicy::kModelParallel)(ParallelPolicy::kDataParallel)
#define FOR_RANGE(type, i, begin, end) for (type i = begin; i < end; ++i)
#define FOR_EACH(it, container) \
for (auto it = container.begin(); it != container.end(); ++it)
void RedirectStdoutAndStderrToGlogDir();
void CloseStdoutAndStderr();
......
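For reference, a tiny usage sketch of the two loop macros added above; it is self-contained, and the vector contents and bounds are made up purely for illustration.

#include <iostream>
#include <vector>

#define FOR_RANGE(type, i, begin, end) for (type i = begin; i < end; ++i)
#define FOR_EACH(it, container) \
  for (auto it = container.begin(); it != container.end(); ++it)

int main() {
  std::vector<int> v = {3, 1, 4};
  FOR_RANGE(size_t, i, 0, v.size()) { std::cout << v[i] << ' '; }  // 3 1 4
  FOR_EACH(it, v) { std::cout << *it << ' '; }                     // 3 1 4
  std::cout << std::endl;
  return 0;
}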
#include "oneflow/core/graph/loss_accumulate_compute_task_node.h"
namespace oneflow {
void AccCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("acc", 1, kMaxRegisterNum);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_ACCUMULATE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class AccCompTaskNode : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(AccCompTaskNode);
AccCompTaskNode() = default;
virtual ~AccCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/backward_compute_task_node.h"
#include "oneflow/core/graph/chain_graph.h"
namespace oneflow {
void BackwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
bool need_in_diff = false;
chain_node()->ForEachNodeOnOutEdge([&](const ChainNode* out_node) {
if (dynamic_cast<const BackwardChainNode*>(out_node)) {
need_in_diff = true;
}
});
if (need_in_diff) { ProduceRegst("in_diff", 1, kMaxRegisterNum); }
ProduceRegst("model_diff", 1, kMaxRegisterNum);
ProduceRegst("activation_diff", 1, 1);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_BACKWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class BackwardCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BackwardCompTaskNode);
BackwardCompTaskNode() = default;
~BackwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override { return TodoTaskType::kBackward; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BACKWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/operator/boxing_op.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/graph/chain_node.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
namespace {
void FwCompleteBoxOpConfDataData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0);
conf->mutable_data_split_box();
}
void FwCompleteBoxOpConfDataModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(0);
conf->mutable_clone_box();
void BoxingTaskNode::Init(int64_t machine_id) {
set_machine_id(machine_id);
set_thrd_loc_id(IDMgr::Singleton()->BoxingThrdLocId());
}
void FwCompleteBoxOpConfModelData(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1);
conf->mutable_data_split_box();
void BoxingTaskNode::ProduceAllRegstsAndBindEdges() {
for (TaskEdge* out_edge : out_edges()) {
std::string name = "boxing_out_" + std::to_string(out_edge->edge_id());
auto out_regst = ProduceRegst(name, 1, kMaxRegisterNum);
out_edge->AddRegst(name, out_regst);
}
ProduceRegst("middle", 1, 1);
}
void FwCompleteBoxOpConfModelModel(BoxingOpConf* conf) {
conf->mutable_concat_box()->set_axis(1);
conf->mutable_clone_box();
void BoxingTaskNode::ConsumeAllRegsts() {
for (TaskEdge* in_edge : in_edges()) {
std::string name = "boxing_in_" + std::to_string(in_edge->edge_id());
auto in_regst = in_edge->GetSoleRegst();
ConsumeRegst(name, in_regst);
}
}
void FwCompleteBoxOpConfAddClone(BoxingOpConf* conf) {
conf->mutable_add_box();
conf->mutable_clone_box();
void BoxingTaskNode::Build() {
HashMap<const ChainNode*, std::vector<EdgeInfo>> in_chain2edge_info;
InitChain2SortedEdgeInfo(&TaskNode::in_edges, &TaskNode::SoleInEdge,
&TaskEdge::src_node, &in_chain2edge_info);
HashMap<const ChainNode*, std::vector<EdgeInfo>> out_chain2edge_info;
InitChain2SortedEdgeInfo(&TaskNode::out_edges, &TaskNode::SoleOutEdge,
&TaskEdge::dst_node, &out_chain2edge_info);
for (const auto& in_pair : in_chain2edge_info) {
for (const auto& out_pair : out_chain2edge_info) {
BuildWithChainPair(in_pair.first, in_pair.second, out_pair.first,
out_pair.second);
}
}
}
} // namespace
#define DEFINE_BLD_BOXING_OP_CONF_METHOD(x, y) \
void x::BldBoxingOpConfWith##y( \
const std::string& lbn, const std::vector<EdgeInfo>& sorted_in_edges, \
int64_t in_parallel_num, int64_t in_edge_first, int64_t in_edge_last, \
const std::vector<EdgeInfo>& sorted_out_edges, int64_t out_parallel_num, \
int64_t* used_out_edge_begin, BoxingOpConf* conf)
void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
EnrollAllRegstAndBindRelatedEdge();
FwVirtualBuild();
}
void BoxingTaskNode::EnrollAllRegstAndBindRelatedEdge() {
for (TaskEdge* edge : in_edges()) {
std::string name = "boxing_in_" + edge->edge_id_str();
ConsumeRegstDesc(name, GetRelatedRegst(edge));
}
for (TaskEdge* edge : out_edges()) {
std::string name = "boxing_out_" + edge->edge_id_str();
auto regst_desc = NewProducedRegstDesc(name, 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(regst_desc, edge);
DEFINE_BLD_BOXING_OP_CONF_METHOD(InBoxingTaskNode, DataConcatAndDataSplit) {
*used_out_edge_begin = 0;
conf->set_out_num(sorted_out_edges.size());
conf->mutable_concat_box()->set_axis(0);
BoxSplitConf* split_conf = conf->mutable_split_box();
split_conf->set_axis(0);
split_conf->set_left_bound_size(0);
split_conf->set_right_bound_size(0);
BalancedSplitter bs(JobDesc::Singleton()->ParallelPieceSize(),
out_parallel_num);
split_conf->set_base_part_size(bs.BasePartSize());
int64_t out_prlll_id_min = sorted_out_edges.front().parallel_id_min;
int64_t out_prlll_id_max = sorted_out_edges.back().parallel_id_max;
if (bs.BaseBeginIdx() <= out_prlll_id_min) {
split_conf->set_bigger_part_num(0);
split_conf->set_base_part_num(out_prlll_id_max - out_prlll_id_min + 1);
} else if (out_prlll_id_min < bs.BaseBeginIdx()
&& bs.BaseBeginIdx() <= out_prlll_id_max) {
split_conf->set_bigger_part_num(bs.BaseBeginIdx() - out_prlll_id_min);
split_conf->set_base_part_num(out_prlll_id_max - bs.BaseBeginIdx() + 1);
} else if (out_prlll_id_max < bs.BaseBeginIdx()) {
split_conf->set_bigger_part_num(out_prlll_id_max - out_prlll_id_min + 1);
split_conf->set_base_part_num(0);
} else {
UNEXPECTED_RUN();
}
NewProducedRegstDesc("middle", 1);
}
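To make the branch structure above concrete, a hypothetical worked example; the numbers are invented, and it assumes BalancedSplitter places the bigger parts before BaseBeginIdx(), as in the standalone sketch given earlier.

// Suppose ParallelPieceSize() == 10 and out_parallel_num == 4:
//   BalancedSplitter bs(10, 4)  ->  part sizes {3, 3, 2, 2},
//   bs.BasePartSize() == 2, bs.BaseBeginIdx() == 2.
// If this boxing node's sorted_out_edges cover parallel ids [1, 3]:
//   out_prlll_id_min == 1, out_prlll_id_max == 3, so the middle branch fires:
//     bigger_part_num = 2 - 1     = 1   // id 1 still gets a size-3 part
//     base_part_num   = 3 - 2 + 1 = 2   // ids 2 and 3 get size-2 parts
// Together with base_part_size == 2, the split box emits pieces of sizes
// {3, 2, 2} for the three out edges this node serves.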
void BoxingTaskNode::FwInitChain2SortedEdgesMaps(
Chain2EdgesMap* chain2sorted_edges,
const std::unordered_set<TaskEdge*>& (TaskNode::*in_out_edges)() const,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const) {
chain2sorted_edges->clear();
HashMap<const TaskEdge*, const StageNode*> edge2stage;
for (const TaskEdge* edge : (this->*in_out_edges)()) {
const TaskNode* pred_succ_node = (edge->*src_dst_node)();
while (pred_succ_node->chain_node() == chain_node()) {
pred_succ_node = ((pred_succ_node->*SoleEdge)()->*src_dst_node)();
DEFINE_BLD_BOXING_OP_CONF_METHOD(OutBoxingTaskNode, DataConcatAndDataSplit) {}
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndClone) { TODO(); }
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndModelSplit) {
TODO();
}
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndDataSplit) {
TODO();
}
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndClone) {
TODO();
}
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndDataSplit) { TODO(); }
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndModelSplit) { TODO(); }
DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndClone) { TODO(); }
void BoxingTaskNode::InitChain2SortedEdgeInfo(
const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const,
TaskEdge* (TaskNode::*SoleEdge)() const,
TaskNode* (TaskEdge::*SoleNode)() const,
HashMap<const ChainNode*, std::vector<EdgeInfo>>* chain2sorted_edge_info) {
chain2sorted_edge_info->clear();
for (const TaskEdge* edge : (this->*GetEdges)()) {
EdgeInfo edge_info;
edge_info.edge = edge;
edge_info.parallel_id_min = std::numeric_limits<int64_t>::max();
edge_info.parallel_id_max = std::numeric_limits<int64_t>::min();
std::queue<const TaskNode*> node_queue;
node_queue.push((edge->*SoleNode)());
const ChainNode* chain = nullptr;
while (node_queue.empty() == false) {
const TaskNode* cur_node = node_queue.front();
node_queue.pop();
auto cur_comp_node = dynamic_cast<const CompTaskNode*>(cur_node);
if (cur_comp_node) {
edge_info.parallel_id_min =
std::min(edge_info.parallel_id_min, cur_comp_node->parallel_id());
edge_info.parallel_id_max =
std::max(edge_info.parallel_id_max, cur_comp_node->parallel_id());
if (chain == nullptr) { chain = cur_comp_node->chain_node(); }
} else {
for (const TaskEdge* cur_edge : (cur_node->*GetEdges)()) {
node_queue.push((cur_edge->*SoleNode)());
}
}
}
(*chain2sorted_edges)[pred_succ_node->chain_node()].push_back(edge);
edge2stage[edge] = pred_succ_node->stage_node();
(*chain2sorted_edge_info)[chain].push_back(edge_info);
}
for (auto& pair : *chain2sorted_edges) {
std::vector<const TaskEdge*>& edges = pair.second;
for (auto& pair : *chain2sorted_edge_info) {
std::vector<EdgeInfo>& edges = pair.second;
std::sort(edges.begin(), edges.end(),
[&edge2stage](const TaskEdge* lhs, const TaskEdge* rhs) {
const StageNode* lhs_stage = edge2stage.at(lhs);
const StageNode* rhs_stage = edge2stage.at(rhs);
CHECK(lhs_stage->chain_node() == rhs_stage->chain_node());
return lhs_stage->parallel_range().begin()
< rhs_stage->parallel_range().begin();
[&](const EdgeInfo& lhs, const EdgeInfo& rhs) {
return lhs.parallel_id_min < rhs.parallel_id_min;
});
}
}
void BoxingTaskNode::FwSortEdgesInnerStage(
std::vector<const TaskEdge*>* edges_to_be_sorted,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const) {
auto GetPredSuccCompTaskNode = [&](const TaskEdge* edge) {
const TaskNode* node = (edge->*src_dst_node)();
const CompTaskNode* ret = nullptr;
while (ret = dynamic_cast<const CompTaskNode*>(node), ret == nullptr) {
node = ((node->*SoleEdge)()->*src_dst_node)();
}
return ret;
};
std::sort(edges_to_be_sorted->begin(), edges_to_be_sorted->end(),
[&](const TaskEdge* lhs, const TaskEdge* rhs) {
const CompTaskNode* lhs_node = GetPredSuccCompTaskNode(lhs);
const CompTaskNode* rhs_node = GetPredSuccCompTaskNode(rhs);
return lhs_node->parallel_id() < rhs_node->parallel_id();
});
}
void BoxingTaskNode::FwBuildChainSortedEdgesPair(
const ChainEdgesPair& chain_sorted_in_edges,
const ChainEdgesPair& chain_sorted_out_edges) {
// useful vars
const ChainNode* in_chain = chain_sorted_in_edges.first;
const auto& sorted_in_edges = chain_sorted_in_edges.second;
const ChainNode* out_chain = chain_sorted_out_edges.first;
const auto& sorted_out_edges = chain_sorted_out_edges.second;
// 4 case
ParallelPolicy in_policy = in_chain->parallel_desc()->policy();
ParallelPolicy out_policy = out_chain->parallel_desc()->policy();
void (*CompleteBoxOp)(BoxingOpConf*);
if (in_policy == kDataParallel && out_policy == kDataParallel) {
CompleteBoxOp = &FwCompleteBoxOpConfDataData;
} else if (in_policy == kDataParallel && out_policy == kModelParallel) {
CompleteBoxOp = &FwCompleteBoxOpConfDataModel;
} else if (in_policy == kModelParallel && out_policy == kDataParallel) {
CompleteBoxOp = &FwCompleteBoxOpConfModelData;
} else if (in_policy == kModelParallel && out_policy == kModelParallel) {
CompleteBoxOp = &FwCompleteBoxOpConfModelModel;
} else if (in_policy == kFakerMdUpdt) {
CHECK_EQ(out_policy, kModelParallel);
CompleteBoxOp = &FwCompleteBoxOpConfAddClone;
} else {
CHECK_EQ(in_policy, kFakerLossRecord);
CHECK_EQ(out_policy, kDataParallel);
CompleteBoxOp = &FwCompleteBoxOpConfAddClone;
}
// func 4 construct boxing_op in this node
auto ConstructBoxingOp = [&](const std::string& lbn) {
OperatorConf op_conf;
op_conf.set_name("boxing_op_" + NewUniqueId());
BoxingOpConf* box_conf = op_conf.mutable_boxing_conf();
box_conf->set_lbn(lbn);
box_conf->set_in_num(sorted_in_edges.size());
box_conf->set_out_num(sorted_out_edges.size());
CompleteBoxOp(box_conf);
return OpMgr::Singleton()->AddOp(op_conf);
};
// lbns
void BoxingTaskNode::BuildWithChainPair(
const ChainNode* in_chain, const std::vector<EdgeInfo>& sorted_in_edges,
const ChainNode* out_chain, const std::vector<EdgeInfo>& sorted_out_edges) {
std::vector<std::string> lbns = FindLbnsBetween(in_chain, out_chain);
if (lbns.at(0) == kPackedBlobName) {
CHECK_EQ(lbns.size(), 1);
lbns.clear();
auto in_regst_0 = GetRelatedRegst(sorted_in_edges.at(0));
in_regst_0->ForEachLbn(
[&](const std::string& lbn) { lbns.push_back(lbn); });
}
// Enroll Lbn
auto middle_regst = GetProducedRegstDesc("middle");
auto middle_regst = GetProducedRegst("middle");
for (const std::string& lbn : lbns) {
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = ConstructBoxingOp(lbn);
// ibn
for (size_t i = 0; i < sorted_in_edges.size(); ++i) {
auto in_regst = GetRelatedRegst(sorted_in_edges.at(i));
int64_t used_in_edge_begin = -1;
int64_t used_out_edge_begin = -1;
node->mut_op() =
NewBoxingOp(lbn, in_chain, out_chain, sorted_in_edges, sorted_out_edges,
&used_in_edge_begin, &used_out_edge_begin);
CHECK_NE(used_in_edge_begin, -1);
CHECK_NE(used_out_edge_begin, -1);
for (size_t i = 0; i < node->op()->input_bns().size(); ++i) {
auto regst = sorted_in_edges[used_in_edge_begin + i].edge->GetSoleRegst();
const std::string& ibn = node->op()->input_bns().at(i);
node->BindBnInOpAndRegst(ibn, in_regst);
node->BindBnInOpAndRegst(ibn, regst);
}
// obn
for (size_t i = 0; i < sorted_out_edges.size(); ++i) {
auto out_regst = GetRelatedRegst(sorted_out_edges.at(i));
for (size_t i = 0; i < node->op()->output_bns().size(); ++i) {
auto regst =
sorted_out_edges[used_out_edge_begin + i].edge->GetSoleRegst();
const std::string& obn = node->op()->output_bns().at(i);
out_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(obn, out_regst);
regst->AddLbn(lbn);
node->BindBnInOpAndRegst(obn, regst);
}
// dtbn
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
middle_regst->EnrollLbn(node->op()->Lbn4BnInOp(dtbn));
middle_regst->AddLbn(node->op()->Lbn4BnInOp(dtbn));
node->BindBnInOpAndRegst(dtbn, middle_regst);
}
node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), nullptr);
}
}
void BoxingTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) {
exec_gph().ConstForEachNode([this](const ExecNode* exec_node) {
exec_node->op()->InferBlobDesc4FwBlobs(
exec_node->GetBlobDesc4BnInOpFunc(),
chain_node()->parallel_desc()->policy(), 0, 0);
});
}
namespace {
std::shared_ptr<RegstDesc> GetBpRegstFromFwRegst(
std::shared_ptr<RegstDesc> fw_regst) {
const TaskEdge* fw_edge = GetRelatedTaskEdge(fw_regst);
const TaskEdge* bp_edge = fw_edge->related_fwbp_edge();
if (bp_edge == nullptr) { return nullptr; }
return GetRelatedRegst(bp_edge);
}
} // namespace
void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
EnrollAllRegstAndBindRelatedEdge();
GetFwNode()->exec_gph().ConstForEachNode([&](const ExecNode* fw_node) {
std::unique_ptr<ExecNode> bp_node(new ExecNode);
bp_node->mut_op() = fw_node->op();
bool need_enroll = true;
// in_diff
for (const std::string& ibn : fw_node->op()->input_bns()) {
std::string idbn = GenDiffBn(ibn);
const std::string& lbn = fw_node->op()->Lbn4BnInOp(ibn);
auto in_regst = fw_node->GetRegstFromBnInOp(ibn);
auto in_diff_regst = GetBpRegstFromFwRegst(in_regst);
if (!in_diff_regst) {
need_enroll = false;
break;
}
in_diff_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
}
if (need_enroll == false) { return; }
// out_diff
for (const std::string& obn : fw_node->op()->output_bns()) {
std::string odbn = GenDiffBn(obn);
auto out_regst = fw_node->GetRegstFromBnInOp(obn);
auto out_diff_regst = GetBpRegstFromFwRegst(out_regst);
bp_node->BindBnInOpAndRegst(odbn, out_diff_regst);
}
// data tmp
for (const std::string& dtbn : fw_node->op()->data_tmp_bns()) {
const std::string& lbn = fw_node->op()->Lbn4BnInOp(dtbn);
auto bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(dtbn, bp_middle_regst);
}
mut_exec_gph().EnrollNode(std::move(bp_node));
});
mut_exec_gph().UpdateSourceAndSink();
}
void BoxingTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) {
for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) {
auto in_regst = GetRelatedRegst(fw_in_edge);
if (auto in_diff_regst = GetBpRegstFromFwRegst(in_regst)) {
in_diff_regst->CopyBlobDescFrom(in_regst.get());
}
std::shared_ptr<Operator> BoxingTaskNode::NewBoxingOp(
const std::string& lbn, const ChainNode* in_chain,
const ChainNode* out_chain, const std::vector<EdgeInfo>& sorted_in_edges,
const std::vector<EdgeInfo>& sorted_out_edges, int64_t* used_in_edge_begin,
int64_t* used_out_edge_begin) {
BldBoxingOpConfMthd method = in_chain->GetMthdForBldBoxingOpConfTo(out_chain);
OperatorConf op_conf;
op_conf.set_name("boxing_op_" + NewUniqueId());
BoxingOpConf* boxing_conf = op_conf.mutable_boxing_conf();
boxing_conf->set_lbn(lbn);
int64_t in_edge_last = -1;
for (int64_t i = 0; i < sorted_in_edges.size(); ++i) {
auto in_regst = sorted_in_edges[i].edge->GetSoleRegst();
if (in_regst->GetBlobDesc(lbn) == nullptr) { continue; }
if (*used_in_edge_begin == -1) { *used_in_edge_begin = i; }
in_edge_last = i;
}
auto fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
auto bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->CopyBlobDescFrom(fw_middle_regst.get());
CHECK_NE(in_edge_last, -1);
boxing_conf->set_in_num(in_edge_last - (*used_in_edge_begin) + 1);
(this->*method)(lbn, sorted_in_edges,
in_chain->parallel_desc()->parallel_num(),
*used_in_edge_begin, in_edge_last, sorted_out_edges,
out_chain->parallel_desc()->parallel_num(),
used_out_edge_begin, boxing_conf);
return OpMgr::Singleton()->AddOp(op_conf);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class ChainNode;
class BoxingTaskNode : public TaskNode {
public:
struct EdgeInfo {
const TaskEdge* edge;
int64_t parallel_id_min;
int64_t parallel_id_max;
};
OF_DISALLOW_COPY_AND_MOVE(BoxingTaskNode);
BoxingTaskNode() = default;
virtual ~BoxingTaskNode() = default;
std::string VisualStr() const override {
return TaskNode::VisualStr() + "Boxing";
}
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
DeviceType GetDeviceType() const override { return DeviceType::kCPU; }
protected:
virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
}
using ChainEdgesPair =
std::pair<const ChainNode*, std::vector<const TaskEdge*>>;
using Chain2EdgesMap =
HashMap<const ChainNode*, std::vector<const TaskEdge*>>;
void FwInitChain2SortedEdgesMaps(
Chain2EdgesMap* chain2sorted_edges,
const std::unordered_set<TaskEdge*>& (TaskNode::*in_out_edges)() const,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const);
void FwSortEdgesInnerStage(std::vector<const TaskEdge*>* edges_to_be_sorted,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const);
void FwBuildChainSortedEdgesPair(
const ChainEdgesPair& chain_sorted_in_edges,
const ChainEdgesPair& chain_sorted_out_edges);
virtual void FwVirtualBuild() = 0;
void Init(int64_t machine_id);
TodoTaskType GetTaskType() const override { return TodoTaskType::kBoxing; }
void ProduceAllRegstsAndBindEdges() override;
void ConsumeAllRegsts() override;
void Build() override;
#define DECLARE_BLD_BOXING_OP_CONF_METHOD(x) \
void BldBoxingOpConfWith##x( \
const std::string& lbn, const std::vector<EdgeInfo>& sorted_in_edges, \
int64_t in_parallel_num, int64_t in_edge_first, int64_t in_edge_last, \
const std::vector<EdgeInfo>& sorted_out_edges, int64_t out_parallel_num, \
int64_t* used_out_edge_begin, BoxingOpConf*)
#define DECLARE_VIRTUAL_BLD_BOXING_OP_CONF_METHOD(x) \
virtual DECLARE_BLD_BOXING_OP_CONF_METHOD(x) = 0
DECLARE_BLD_BOXING_OP_CONF_METHOD(DataConcatAndDataSplit);
DECLARE_VIRTUAL_BLD_BOXING_OP_CONF_METHOD(DataConcatAndDataSplit);
DECLARE_BLD_BOXING_OP_CONF_METHOD(DataConcatAndClone);
DECLARE_BLD_BOXING_OP_CONF_METHOD(DataConcatAndModelSplit);
DECLARE_BLD_BOXING_OP_CONF_METHOD(ModelConcatAndDataSplit);
DECLARE_BLD_BOXING_OP_CONF_METHOD(ModelConcatAndClone);
DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndDataSplit);
DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndModelSplit);
DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndClone);
private:
void InitChain2SortedEdgeInfo(
const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const,
TaskEdge* (TaskNode::*SoleEdge)() const,
TaskNode* (TaskEdge::*SoleNode)() const,
HashMap<const ChainNode*, std::vector<EdgeInfo>>*);
void BuildWithChainPair(const ChainNode* in_chain,
const std::vector<EdgeInfo>& sorted_in_edges,
const ChainNode* out_chain,
const std::vector<EdgeInfo>& sorted_out_edges);
std::shared_ptr<Operator> NewBoxingOp(
const std::string& lbn, const ChainNode* in_chain,
const ChainNode* out_chain, const std::vector<EdgeInfo>& sorted_in_edges,
const std::vector<EdgeInfo>& sorted_out_edges,
int64_t* used_in_edge_begin, int64_t* used_out_edge_begin);
};
#define OVERRIDE_BLD_BOXING_OP_METHOD(x) \
DECLARE_BLD_BOXING_OP_CONF_METHOD(x) override
class InBoxingTaskNode final : public BoxingTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(InBoxingTaskNode);
InBoxingTaskNode() = default;
~InBoxingTaskNode() = default;
OVERRIDE_BLD_BOXING_OP_METHOD(DataConcatAndDataSplit);
private:
OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts);
};
class OutBoxingTaskNode final : public BoxingTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(OutBoxingTaskNode);
OutBoxingTaskNode() = default;
~OutBoxingTaskNode() = default;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void FwInferBlobDescInProducedRegsts(TaskGraph*);
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void BpInferBlobDescInProducedRegsts(TaskGraph*);
OVERRIDE_BLD_BOXING_OP_METHOD(DataConcatAndDataSplit);
void EnrollAllRegstAndBindRelatedEdge();
TaskType task_type() const override { return kBoxingTask; }
private:
};
} // namespace oneflow
......
......
@@ -19,19 +19,13 @@ struct Chain {
using ChainIt = std::list<Chain>::iterator;
using Logical2ChainItMap = HashMap<const LogicalNode*, ChainIt>;
void SetChainNodeWithChainIt(ChainNode* chain_node, ChainIt chain_it) {
CHECK(!chain_it->nodes.empty());
chain_node->mut_parallel_desc() = chain_it->nodes.front()->parallel_desc();
for (const LogicalNode* logical_node : chain_it->nodes) {
chain_node->mut_op_vec().push_back(logical_node->op());
}
}
void SetChainNodeWithChainIt(ChainNode* chain_node, ChainIt chain_it) {}
void InitChains(const LogicalGraph& logi_gph, std::list<Chain>* chain_list,
void InitChains(std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
chain_list->clear();
logical2chain_it->clear();
logi_gph.ConstForEachNode([&](const LogicalNode* node) {
LogicalGraph::Singleton()->ForEachNode([&](const LogicalNode* node) {
// Init one Chain with one Node
chain_list->emplace_back();
logical2chain_it->insert({node, --chain_list->end()});
......
@@ -39,8 +33,8 @@ void InitChains(const LogicalGraph& logi_gph, std::list<Chain>* chain_list,
cur_chain.nodes = {node};
});
// Init ancestors
logi_gph.ConstTopoForEachNode([&](const LogicalNode* node) {
ChainIt cur_chain = logical2chain_it->at(&(*node));
LogicalGraph::Singleton()->TopoForEachNode([&](LogicalNode* node) {
ChainIt cur_chain = logical2chain_it->at(node);
cur_chain->ancestors.clear();
cur_chain->ancestors_and_this.clear();
cur_chain->ancestors_and_this.insert(cur_chain->nodes.begin(),
......
@@ -59,8 +53,8 @@ void InitChains(const LogicalGraph& logi_gph, std::list<Chain>* chain_list,
cur_chain->ancestors.end());
});
// Init descendants
logi_gph.ConstReverseTopoForEachNode([&](const LogicalNode* node) {
ChainIt cur_chain = logical2chain_it->at(&(*node));
LogicalGraph::Singleton()->ReverseTopoForEachNode([&](LogicalNode* node) {
ChainIt cur_chain = logical2chain_it->at(node);
cur_chain->descendants.clear();
cur_chain->descendants_and_this.clear();
cur_chain->descendants_and_this.insert(cur_chain->nodes.begin(),
......
@@ -85,15 +79,10 @@ void ModelMergeChains(std::list<Chain>* chain_list,
for (auto& pair : *logical2chain_it) {
// Get cur_node, pred_node
const LogicalNode* cur_node = pair.first;
if (cur_node->op()->IsElemWise() == false) { continue; }
if (cur_node->op()->IsElemWiseOp() == false) { continue; }
if (cur_node->parallel_desc()->policy() != kModelParallel) { continue; }
const LogicalNode* pred_node = cur_node->SoleInEdge()->src_node();
CHECK(pred_node->parallel_desc()->Equal(cur_node->parallel_desc().get()))
<< "the ParallelConf of "
<< "\"" << pred_node->op()->op_name() << "\" "
<< "and "
<< "\"" << cur_node->op()->op_name() << "\" "
<< "should be the same";
CHECK(pred_node->parallel_desc()->Equal(cur_node->parallel_desc().get()));
// Get chain
ChainIt pred_chain = logical2chain_it->at(pred_node);
ChainIt cur_chain = pair.second;
......
@@ -195,14 +184,14 @@ bool DoOneDataMerge(const std::vector<const LogicalNode*>& data_parallel_node,
return false;
}
void DataMergeChains(const LogicalGraph& logical_gph,
std::list<Chain>* chain_list,
void DataMergeChains(std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
std::vector<const LogicalNode*> data_parallel_node;
for (const auto& pair : *logical2chain_it) {
const LogicalNode* cur_logi_node = pair.first;
if (cur_logi_node->parallel_desc()->policy() != kDataParallel) { continue; }
if (!cur_logi_node->IsChainMergeable()) { continue; }
if (cur_logi_node->op()->IsLossOp()) { continue; }
if (cur_logi_node->op()->IsDataLoaderOp()) { continue; }
data_parallel_node.push_back(cur_logi_node);
}
while (DoOneDataMerge(data_parallel_node, chain_list, logical2chain_it)) {}
......
@@ -210,32 +199,24 @@ void DataMergeChains(const LogicalGraph& logical_gph,
} // namespace
std::string ChainNode::ConcatedOpsName() const {
std::stringstream ss;
for (auto op : op_vec_) { ss << "\\n" << op->op_name(); }
if (!op_vec_.empty()) {
return ss.str().substr(2);
} else {
return node_id_str();
}
}
bool ChainNode::HasOpWithModelOrModelTmpBlob() const {
for (std::shared_ptr<Operator> op : op_vec_) {
if (!op->model_bns().empty() || !op->model_tmp_bns().empty()) {
return true;
}
ChainGraph::ChainGraph(bool is_train) {
BuildFwStruct();
if (is_train) {
BuildBwStruct();
BuildLossRecordStruct();
}
return false;
BuildModelStruct(is_train);
BuildRnnStruct();
ToDotWithAutoFilePath();
}
ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
void ChainGraph::BuildFwStruct() {
// Build Chain
std::list<Chain> chain_list;
Logical2ChainItMap logical2chain_it;
InitChains(*logical_gph, &chain_list, &logical2chain_it);
InitChains(&chain_list, &logical2chain_it);
ModelMergeChains(&chain_list, &logical2chain_it);
DataMergeChains(*logical_gph, &chain_list, &logical2chain_it);
DataMergeChains(&chain_list, &logical2chain_it);
// Init chain_nodes
auto HashChainIt = [](const ChainIt& chain_it) {
return std::hash<Chain*>()(&(*chain_it));
......
@@ -243,16 +224,29 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
HashMap<ChainIt, ChainNode*, decltype(HashChainIt)> chain_it2chain_node(
11, HashChainIt);
HashMap<ChainNode*, std::unordered_set<ChainNode*>> chain_node2pred;
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end();
++chain_it) {
ChainNode* chain_node = NewNode();
FOR_EACH(chain_it, chain_list) {
ChainNode* chain_node = nullptr;
if (chain_it->nodes.size() == 1) {
std::shared_ptr<const Operator> op = chain_it->nodes[0]->op();
if (op->IsLossOp()) {
chain_node = NewNode<LossChainNode>();
} else if (op->IsDataLoaderOp()) {
chain_node = NewNode<SourceChainNode>();
} else {
// do nothing
}
}
if (chain_node == nullptr) { chain_node = NewNode<ForwardChainNode>(); }
chain_it2chain_node[chain_it] = chain_node;
chain_node2pred[chain_node] = {};
SetChainNodeWithChainIt(chain_node, chain_it);
CHECK(!chain_it->nodes.empty());
chain_node->mut_parallel_desc() = chain_it->nodes.front()->parallel_desc();
for (const LogicalNode* logical_node : chain_it->nodes) {
chain_node->mut_op_vec().push_back(logical_node->op());
}
}
// Record the predecessor
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end();
++chain_it) {
FOR_EACH(chain_it, chain_list) {
ChainNode* chain_node = chain_it2chain_node.at(chain_it);
for (const LogicalNode* logi_node : chain_it->nodes) {
for (auto logi_in_edge : logi_node->in_edges()) {
......
@@ -271,66 +265,125 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
Connect(pred_node, NewEdge(), cur_node);
}
}
// Post processing
UpdateSourceAndSink();
SetInOutLbn4AllChainNodeInDataTaskGraph();
ToDotWithAutoFilePath();
}
void ChainGraph::SetInOutLbn4AllChainNodeInDataTaskGraph() {
HashMap<ChainNode*, std::unordered_set<std::string>> chain2produced_lbns;
// Init chain2produced_lbns and Set InputLbns
ForEachNode([&](ChainNode* cur_node) {
auto& produced_lbns = chain2produced_lbns[cur_node];
for (std::shared_ptr<Operator> op : cur_node->op_vec()) {
for (const std::string& obn : op->output_bns()) {
const std::string& lbn = op->Lbn4BnInOp(obn);
produced_lbns.insert(lbn);
}
void ChainGraph::BuildBwStruct() {
HashSet<ForwardChainNode*> fw_nodes_that_need_bw;
TopoForEachNode([&](ChainNode* chain_node) {
auto fw_chain_node = dynamic_cast<ForwardChainNode*>(chain_node);
if (fw_chain_node == nullptr) { return; }
if (fw_chain_node->HasOpWithModelOrModelTmpBlob()) {
CHECK(fw_nodes_that_need_bw.insert(fw_chain_node).second);
return;
}
for (std::shared_ptr<Operator> op : cur_node->op_vec()) {
for (const std::string& ibn : op->input_bns()) {
const std::string& lbn = op->Lbn4BnInOp(ibn);
if (produced_lbns.find(lbn) == produced_lbns.end()) {
cur_node->mut_input_lbns().push_back(lbn);
}
for (ChainEdge* edge : fw_chain_node->in_edges()) {
auto fw_pred_node = static_cast<ForwardChainNode*>(edge->src_node());
if (fw_nodes_that_need_bw.find(fw_pred_node)
!= fw_nodes_that_need_bw.end()) {
CHECK(fw_nodes_that_need_bw.insert(fw_chain_node).second);
return;
}
}
SortAndRemoveDuplication(&(cur_node->mut_input_lbns()));
});
// Set OutputLbns
ForEachNode([&](ChainNode* cur_node) {
const auto& produced_lbns = chain2produced_lbns.at(cur_node);
for (ChainEdge* out_edge : cur_node->out_edges()) {
for (const std::string& lbn : out_edge->dst_node()->input_lbns()) {
if (produced_lbns.find(lbn) != produced_lbns.end()) {
cur_node->mut_output_lbns().push_back(lbn);
}
}
for (ForwardChainNode* fw_node : fw_nodes_that_need_bw) {
BackwardChainNode* bw_node = NewNode<BackwardChainNode>();
bw_node->mut_op_vec() = fw_node->op_vec();
bw_node->mut_parallel_desc() = fw_node->parallel_desc();
fw_node->set_bw_node(bw_node);
bw_node->set_fw_node(fw_node);
}
std::list<ChainEdge*> fw_edges;
ForEachEdge([&](ChainEdge* edge) { fw_edges.push_back(edge); });
for (ChainEdge* fw_edge : fw_edges) {
auto fw_src_node = dynamic_cast<ForwardChainNode*>(fw_edge->src_node());
if (fw_src_node == nullptr) { continue; }
auto fw_dst_node = dynamic_cast<ForwardChainNode*>(fw_edge->dst_node());
ChainNode* bw_src_node = fw_src_node->bw_node();
if (bw_src_node == nullptr) { continue; }
if (fw_dst_node == nullptr) {
Connect(fw_edge->dst_node(), NewEdge(), bw_src_node);
} else {
ChainNode* bw_dst_node = fw_dst_node->bw_node();
if (bw_dst_node == nullptr) { continue; }
Connect(bw_dst_node, NewEdge(), bw_src_node);
}
SortAndRemoveDuplication(&(cur_node->mut_output_lbns()));
}
for (ForwardChainNode* fw_node : fw_nodes_that_need_bw) {
BackwardChainNode* bw_node = fw_node->bw_node();
Connect<ChainNode>(fw_node, NewEdge(), bw_node);
}
}
void ChainGraph::BuildLossRecordStruct() {
ForEachChainNode<LossChainNode>([&](LossChainNode* loss_chain) {
// Loss Accumulate Chain
OperatorConf loss_acc_op_conf;
loss_acc_op_conf.set_name("loss_acc_" + NewUniqueId());
loss_acc_op_conf.mutable_accumulate_conf();
auto loss_acc_op = OpMgr::Singleton()->AddOp(loss_acc_op_conf);
auto loss_acc_chain = NewNode<LossAccChainNode>();
loss_acc_chain->mut_op_vec() = {loss_acc_op};
loss_acc_chain->mut_parallel_desc() = loss_chain->parallel_desc();
Connect<ChainNode>(loss_chain, NewEdge(), loss_acc_chain);
// Loss Record Chain
OperatorConf loss_record_op_conf;
loss_record_op_conf.set_name("loss_record_" + NewUniqueId());
loss_record_op_conf.mutable_loss_record_conf();
auto loss_record_op = OpMgr::Singleton()->AddOp(loss_record_op_conf);
ParallelConf loss_record_pr_conf;
loss_record_pr_conf.set_policy(kDataParallel);
loss_record_pr_conf.add_device_name(
IDMgr::Singleton()->MachineName4MachineId(0) + ":0");
auto loss_record_chain = NewNode<LossRecordChainNode>();
loss_record_chain->mut_op_vec() = {loss_record_op};
loss_record_chain->mut_parallel_desc().reset(
new ParallelDesc(loss_record_pr_conf));
Connect<ChainNode>(loss_acc_chain, NewEdge(), loss_record_chain);
});
}
std::vector<std::string> FindLbnsBetween(const ChainNode* src_node,
const ChainNode* dst_node) {
std::vector<std::string> matching_lbns;
for (const std::string& src_node_output_lbn : src_node->output_lbns()) {
for (const std::string& dst_node_input_lbn : dst_node->input_lbns()) {
if (src_node_output_lbn != dst_node_input_lbn) { continue; }
matching_lbns.push_back(src_node_output_lbn);
break;
void ChainGraph::BuildModelStruct(bool is_train) {
ForEachChainNode<ForwardChainNode>([&](ForwardChainNode* fw_chain) {
if (fw_chain->HasOpWithModelOrModelTmpBlob() == false) { return; }
// Model Update Chain
auto md_updt_chain = NewNode<MdUpdtChainNode>();
md_updt_chain->mut_op_vec() = {OpMgr::Singleton()->ModelUpdateOp()};
md_updt_chain->mut_parallel_desc() = fw_chain->parallel_desc();
Connect<ChainNode>(md_updt_chain, NewEdge(), fw_chain);
// Model Save Chain
OperatorConf model_save_op_conf;
model_save_op_conf.set_name("md_save_" + NewUniqueId());
for (std::shared_ptr<const Operator> op : fw_chain->op_vec()) {
for (const std::string& mbn : op->model_bns()) {
const std::string& lbn = op->Lbn4BnInOp(mbn);
model_save_op_conf.mutable_model_save_conf()->add_lbns(lbn);
}
}
}
CHECK_NE(matching_lbns.size(), 0);
return matching_lbns;
auto model_save_op = OpMgr::Singleton()->AddOp(model_save_op_conf);
auto md_save_chain = NewNode<MdSaveChainNode>();
md_save_chain->mut_op_vec() = {model_save_op};
auto md_save_pr_desc = new ParallelDesc(*(fw_chain->parallel_desc()));
if (fw_chain->parallel_desc()->policy() == ParallelPolicy::kDataParallel) {
md_save_pr_desc->RemoveNeedlessDevice(1);
}
md_save_chain->mut_parallel_desc().reset(md_save_pr_desc);
Connect<ChainNode>(md_updt_chain, NewEdge(), md_save_chain);
// Model Diff Accumulate Chain
if (is_train == false) { return; }
BackwardChainNode* bw_chain = fw_chain->bw_node();
Connect<ChainNode>(md_updt_chain, NewEdge(), bw_chain);
OperatorConf md_diff_acc_op_conf;
md_diff_acc_op_conf.set_name("md_diff_acc_" + NewUniqueId());
md_diff_acc_op_conf.mutable_accumulate_conf();
auto md_diff_acc_op = OpMgr::Singleton()->AddOp(md_diff_acc_op_conf);
auto md_diff_acc_chain = NewNode<MdDiffAccChainNode>();
md_diff_acc_chain->mut_op_vec() = {md_diff_acc_op};
md_diff_acc_chain->mut_parallel_desc() = fw_chain->parallel_desc();
Connect<ChainNode>(bw_chain, NewEdge(), md_diff_acc_chain);
Connect<ChainNode>(md_diff_acc_chain, NewEdge(), md_updt_chain);
});
}
std::string ChainEdge::VisualStr() const {
std::vector<std::string> lbns = FindLbnsBetween(src_node(), dst_node());
std::stringstream ss;
for (const std::string& lbn : lbns) { ss << "\\n" << lbn; }
return ss.str().substr(2);
}
void ChainGraph::BuildRnnStruct() {}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#include "oneflow/core/graph/chain_node.h"
#include "oneflow/core/graph/logical_graph.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
class ChainEdge;
class ChainNode final : public Node<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainNode);
ChainNode() = default;
~ChainNode() = default;
std::string ConcatedOpsName() const;
std::string ChainTag() const {
std::string chain_tag = op_vec_.front()->op_name();
StringReplace(&chain_tag, '/', '_');
return chain_tag;
}
std::shared_ptr<Operator> SoleOp() const {
CHECK_EQ(op_vec_.size(), 1);
return op_vec_.front();
}
const std::vector<std::shared_ptr<Operator>>& op_vec() const {
return op_vec_;
}
std::vector<std::shared_ptr<Operator>>& mut_op_vec() { return op_vec_; }
std::shared_ptr<const ParallelDesc> parallel_desc() const {
return parallel_desc_;
}
std::shared_ptr<const ParallelDesc>& mut_parallel_desc() {
return parallel_desc_;
}
const std::vector<std::string>& input_lbns() const { return input_lbns_; }
std::vector<std::string>& mut_input_lbns() { return input_lbns_; }
const std::vector<std::string>& output_lbns() const { return output_lbns_; }
std::vector<std::string>& mut_output_lbns() { return output_lbns_; }
bool IsLossNode() const {
return op_vec_.size() == 1 && op_vec_.front()->IsLossOp();
}
bool IsRecordNode() const {
return op_vec_.size() == 1 && op_vec_.front()->IsRecordOp();
}
std::string VisualStr() const { return ConcatedOpsName(); }
bool HasOpWithModelOrModelTmpBlob() const;
private:
std::vector<std::shared_ptr<Operator>> op_vec_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
std::vector<std::string> input_lbns_;
std::vector<std::string> output_lbns_;
};
class ChainEdge final : public Edge<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainEdge);
ChainEdge() = default;
~ChainEdge() = default;
std::string VisualStr() const override;
private:
};
class ChainGraph final : public Graph<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainGraph);
ChainGraph() = default;
~ChainGraph() = default;
ChainGraph(const LogicalGraph* logical_gph);
ChainGraph(bool is_train);
const char* TypeName() const override { return "ChainGraph"; }
private:
void SetInOutLbn4AllChainNodeInDataTaskGraph();
};
template<typename ChainNodeType>
void ForEachChainNode(std::function<void(ChainNodeType*)> Handler) {
// the Handler may call "NewNode"
std::vector<ChainNodeType*> valid_nodes;
ForEachNode([&](ChainNode* chain_node) {
auto valid_node = dynamic_cast<ChainNodeType*>(chain_node);
if (valid_node != nullptr) { valid_nodes.push_back(valid_node); }
});
for (ChainNodeType* valid_node : valid_nodes) { Handler(valid_node); }
}
std::vector<std::string> FindLbnsBetween(const ChainNode*, const ChainNode*);
void BuildFwStruct();
void BuildBwStruct();
void BuildLossRecordStruct();
void BuildModelStruct(bool is_train);
void BuildRnnStruct();
};
} // namespace oneflow
......
#include "oneflow/core/graph/chain_node.h"
#include "oneflow/core/graph/backward_compute_task_node.h"
#include "oneflow/core/graph/forward_compute_task_node.h"
#include "oneflow/core/graph/loss_accumulate_compute_task_node.h"
#include "oneflow/core/graph/loss_compute_task_node.h"
#include "oneflow/core/graph/loss_record_compute_task_node.h"
#include "oneflow/core/graph/model_diff_accumulate_compute_task_node.h"
#include "oneflow/core/graph/model_save_compute_task_node.h"
#include "oneflow/core/graph/model_update_compute_task_node.h"
#include "oneflow/core/graph/source_compute_task_node.h"
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
namespace {
BldBoxingOpConfMthd GetBldBoxingOpConfMethodByFwParallelPolicy(
const ChainNode* in_chain, const ChainNode* out_chain) {
ParallelPolicy in_policy = in_chain->parallel_desc()->policy();
ParallelPolicy out_policy = out_chain->parallel_desc()->policy();
if (in_policy == kDataParallel && out_policy == kDataParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithDataConcatAndDataSplit;
} else if (in_policy == kDataParallel && out_policy == kModelParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithDataConcatAndClone;
} else if (in_policy == kModelParallel && out_policy == kDataParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithModelConcatAndDataSplit;
} else if (in_policy == kModelParallel && out_policy == kModelParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithModelConcatAndClone;
} else {
LOG(FATAL) << "in " << in_policy << " out " << out_policy;
}
}
BldBoxingOpConfMthd GetBldBoxingOpConfMethodByBwParallelPolicy(
const ChainNode* in_chain, const ChainNode* out_chain) {
ParallelPolicy in_policy = in_chain->parallel_desc()->policy();
ParallelPolicy out_policy = out_chain->parallel_desc()->policy();
if (in_policy == kDataParallel && out_policy == kDataParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithDataConcatAndDataSplit;
} else if (in_policy == kDataParallel && out_policy == kModelParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithAddAndDataSplit;
} else if (in_policy == kModelParallel && out_policy == kDataParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithDataConcatAndModelSplit;
} else if (in_policy == kModelParallel && out_policy == kModelParallel) {
return &BoxingTaskNode::BldBoxingOpConfWithAddAndModelSplit;
} else {
LOG(FATAL) << "out_diff " << in_policy << " in_diff " << out_policy;
}
}
std::vector<std::string> FindLbnsBetweenChainPair(
const ChainNode* in_chain,
const std::vector<std::string>& (Operator::*GetOutLbns)() const,
const ChainNode* out_chain,
const std::vector<std::string>& (Operator::*GetInLbns)() const) {
HashSet<std::string> out_lbns_in_chain;
for (std::shared_ptr<const Operator> op : in_chain->op_vec()) {
for (const std::string& bn_in_op : (op.get()->*GetOutLbns)()) {
const std::string& lbn = op->Lbn4BnInOp(bn_in_op);
CHECK(out_lbns_in_chain.insert(lbn).second);
}
}
std::vector<std::string> result;
for (std::shared_ptr<const Operator> op : out_chain->op_vec()) {
for (const std::string& bn_in_op : (op.get()->*GetInLbns)()) {
const std::string& lbn = op->Lbn4BnInOp(bn_in_op);
if (out_lbns_in_chain.find(lbn) != out_lbns_in_chain.end()) {
result.push_back(lbn);
}
}
}
SortAndRemoveDuplication(&result);
return result;
}
std::vector<std::string> FindLbnsBetweenFw(const ChainNode* in_chain,
const ChainNode* out_chain) {
return FindLbnsBetweenChainPair(in_chain, &Operator::output_bns, out_chain,
&Operator::input_bns);
}
std::vector<std::string> FindLbnsBetweenBw(const ChainNode* in_chain,
const ChainNode* out_chain) {
return FindLbnsBetweenChainPair(in_chain, &Operator::input_diff_bns,
out_chain, &Operator::output_diff_bns);
}
} // namespace
std::shared_ptr<const Operator> ChainNode::SoleOp() const {
CHECK_EQ(op_vec_.size(), 1);
return op_vec_.front();
}
const std::vector<std::shared_ptr<const Operator>>& ChainNode::op_vec() const {
return op_vec_;
}
std::shared_ptr<const ParallelDesc> ChainNode::parallel_desc() const {
return parallel_desc_;
}
std::shared_ptr<const ParallelDesc>& ChainNode::mut_parallel_desc() {
return parallel_desc_;
}
std::string ChainNode::VisualStr() const {
std::stringstream ss;
ss << TypeName();
for (auto op : op_vec_) { ss << "\\n" << op->op_name(); }
return ss.str();
}
bool ChainNode::HasOpWithModelOrModelTmpBlob() const {
for (std::shared_ptr<const Operator> op : op_vec_) {
if (!op->model_bns().empty() || !op->model_tmp_bns().empty()) {
return true;
}
}
return false;
}
void ChainNode::GenSortedCompTaskNodes(CompTaskNodeHandler Handler) const {
int64_t parallel_idx = 0;
int64_t parallel_num = parallel_desc_->parallel_num();
for (int64_t machine_id : parallel_desc_->sorted_machine_ids()) {
for (int64_t dev_phy_id : parallel_desc_->sorted_dev_phy_ids(machine_id)) {
CompTaskNode* comp_task_node = NewCompTaskNode();
comp_task_node->set_machine_id(machine_id);
comp_task_node->set_thrd_loc_id(dev_phy_id);
comp_task_node->set_chain_node(this);
comp_task_node->mut_parallel_ctx().set_parallel_id(parallel_idx++);
comp_task_node->mut_parallel_ctx().set_parallel_num(parallel_num);
comp_task_node->mut_parallel_ctx().set_policy(parallel_desc_->policy());
Handler(comp_task_node);
}
}
}
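As a sanity check on the nested loops above, a small hypothetical layout; the machine and device counts are invented for illustration, not taken from any real ParallelDesc in this diff.

// Hypothetical parallel_desc with 2 machines x 2 devices (parallel_num == 4):
//   machine 0, device 0 -> parallel_id 0
//   machine 0, device 1 -> parallel_id 1
//   machine 1, device 0 -> parallel_id 2
//   machine 1, device 1 -> parallel_id 3
// Each comp task node also receives parallel_num == 4 and the desc's policy
// in its ParallelContext, so downstream code knows its slice of the work.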
#define DEFINE_VIRTUAL_METHOD(x) \
const char* x##ChainNode::TypeName() const { return #x "ChainNode"; } \
BldSubTskGphMthd x##ChainNode::GetMthdForBldSubTskGphTo( \
const ChainNode* node) const { \
return node->GetMthdForBldSubTskGphFrom##x(this); \
} \
BldBoxingOpConfMthd x##ChainNode::GetMthdForBldBoxingOpConfTo( \
const ChainNode* node) const { \
return node->GetMthdForBldBoxingOpConfFrom##x(this); \
} \
std::vector<std::string> x##ChainNode::FindLbnsTo(const ChainNode* node) \
const { \
return node->FindLbnsFrom##x(this); \
} \
BldSubTskGphMthd ChainNode::GetMthdForBldSubTskGphFrom##x(const ChainNode*) \
const { \
UNEXPECTED_RUN(); \
return nullptr; \
} \
BldBoxingOpConfMthd ChainNode::GetMthdForBldBoxingOpConfFrom##x( \
const ChainNode*) const { \
UNEXPECTED_RUN(); \
return nullptr; \
} \
std::vector<std::string> ChainNode::FindLbnsFrom##x(const ChainNode*) \
const { \
UNEXPECTED_RUN(); \
return {}; \
} \
CompTaskNode* x##ChainNode::NewCompTaskNode() const { \
return new x##CompTaskNode; \
}
OF_PP_FOR_EACH_TUPLE(DEFINE_VIRTUAL_METHOD, CHAIN_TYPE_SEQ)
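The macro above wires up a double-dispatch scheme: an edge asks its source chain GetMthdForBldSubTskGphTo(dst), which bounces to the destination's GetMthdForBldSubTskGphFrom##x(src) overload, so the chosen builder depends on the dynamic types of both endpoints. A minimal standalone sketch of that pattern follows; the types Fwd/Bwd and the method names are illustrative stand-ins, not the real ChainNode hierarchy.

#include <iostream>

// Illustrative double dispatch: the result depends on the dynamic types of
// both `this` and `dst`, but each virtual call dispatches on only one of them.
struct Node {
  virtual ~Node() = default;
  // "To" entry point: dispatch on the source's type first ...
  virtual const char* MethodTo(const Node* dst) const = 0;
  // ... then the "From" overloads dispatch on the destination's type.
  virtual const char* MethodFromFwd(const Node*) const { return "unexpected"; }
  virtual const char* MethodFromBwd(const Node*) const { return "unexpected"; }
};

struct Fwd : Node {
  const char* MethodTo(const Node* dst) const override {
    return dst->MethodFromFwd(this);
  }
  const char* MethodFromFwd(const Node*) const override { return "boxing"; }
};

struct Bwd : Node {
  const char* MethodTo(const Node* dst) const override {
    return dst->MethodFromBwd(this);
  }
  const char* MethodFromFwd(const Node*) const override { return "one_to_one"; }
  const char* MethodFromBwd(const Node*) const override { return "boxing"; }
};

int main() {
  Fwd f;
  Bwd b;
  std::cout << f.MethodTo(&b) << "\n";  // Bwd::MethodFromFwd -> "one_to_one"
  std::cout << f.MethodTo(&f) << "\n";  // Fwd::MethodFromFwd -> "boxing"
  return 0;
}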
// ForwardChainNode
BldSubTskGphMthd ForwardChainNode::GetMthdForBldSubTskGphFromForward(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldSubTskGphMthd ForwardChainNode::GetMthdForBldSubTskGphFromSource(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldSubTskGphMthd ForwardChainNode::GetMthdForBldSubTskGphFromMdUpdt(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByOneToOne;
}
BldBoxingOpConfMthd ForwardChainNode::GetMthdForBldBoxingOpConfFromForward(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByFwParallelPolicy(node, this);
}
BldBoxingOpConfMthd ForwardChainNode::GetMthdForBldBoxingOpConfFromSource(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByFwParallelPolicy(node, this);
}
std::vector<std::string> ForwardChainNode::FindLbnsFromForward(
const ChainNode* node) const {
return FindLbnsBetweenFw(node, this);
}
std::vector<std::string> ForwardChainNode::FindLbnsFromSource(
const ChainNode* node) const {
return FindLbnsBetweenFw(node, this);
}
// BackwardChainNode
BldSubTskGphMthd BackwardChainNode::GetMthdForBldSubTskGphFromForward(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByOneToOne;
}
BldSubTskGphMthd BackwardChainNode::GetMthdForBldSubTskGphFromBackward(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldSubTskGphMthd BackwardChainNode::GetMthdForBldSubTskGphFromLoss(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldSubTskGphMthd BackwardChainNode::GetMthdForBldSubTskGphFromMdUpdt(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByOneToOne;
}
BldBoxingOpConfMthd BackwardChainNode::GetMthdForBldBoxingOpConfFromBackward(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByBwParallelPolicy(node, this);
}
BldBoxingOpConfMthd BackwardChainNode::GetMthdForBldBoxingOpConfFromLoss(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByBwParallelPolicy(node, this);
}
std::vector<std::string> BackwardChainNode::FindLbnsFromBackward(
const ChainNode* node) const {
return FindLbnsBetweenBw(node, this);
}
std::vector<std::string> BackwardChainNode::FindLbnsFromLoss(
const ChainNode* node) const {
return FindLbnsBetweenBw(node, this);
}
// LossChainNode
BldSubTskGphMthd LossChainNode::GetMthdForBldSubTskGphFromForward(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldSubTskGphMthd LossChainNode::GetMthdForBldSubTskGphFromSource(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldBoxingOpConfMthd LossChainNode::GetMthdForBldBoxingOpConfFromForward(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByFwParallelPolicy(node, this);
}
BldBoxingOpConfMthd LossChainNode::GetMthdForBldBoxingOpConfFromSource(
const ChainNode* node) const {
return GetBldBoxingOpConfMethodByFwParallelPolicy(node, this);
}
std::vector<std::string> LossChainNode::FindLbnsFromForward(
const ChainNode* node) const {
return FindLbnsBetweenFw(node, this);
}
std::vector<std::string> LossChainNode::FindLbnsFromSource(
const ChainNode* node) const {
return FindLbnsBetweenFw(node, this);
}
// LossAccChainNode
BldSubTskGphMthd LossAccChainNode::GetMthdForBldSubTskGphFromLoss(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByOneToOne;
}
// LossRecordChainNode
BldSubTskGphMthd LossRecordChainNode::GetMthdForBldSubTskGphFromLossAcc(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByBoxing;
}
BldBoxingOpConfMthd LossRecordChainNode::GetMthdForBldBoxingOpConfFromLossAcc(
const ChainNode*) const {
return &BoxingTaskNode::BldBoxingOpConfWithAddAndClone;
}
std::vector<std::string> LossRecordChainNode::FindLbnsFromLossAcc(
const ChainNode*) const {
return {kPackedBlobName};
}
// MdUpdtChainNode
BldSubTskGphMthd MdUpdtChainNode::GetMthdForBldSubTskGphFromMdDiffAcc(
const ChainNode*) const {
if (parallel_desc()->policy() == ParallelPolicy::kDataParallel) {
return &TaskGraph::BldSubTskGphByBoxing;
} else if (parallel_desc()->policy() == ParallelPolicy::kModelParallel) {
return &TaskGraph::BldSubTskGphByOneToOne;
} else {
UNEXPECTED_RUN();
}
}
BldBoxingOpConfMthd MdUpdtChainNode::GetMthdForBldBoxingOpConfFromMdDiffAcc(
const ChainNode*) const {
return &BoxingTaskNode::BldBoxingOpConfWithAddAndClone;
}
std::vector<std::string> MdUpdtChainNode::FindLbnsFromMdDiffAcc(
const ChainNode*) const {
return {kPackedBlobName};
}
// MdSaveChainNode
BldSubTskGphMthd MdSaveChainNode::GetMthdForBldSubTskGphFromMdUpdt(
const ChainNode*) const {
if (parallel_desc()->parallel_num() == 1) {
return &TaskGraph::BldSubTskGphBySelectOneSourceToSoleSink;
} else {
return &TaskGraph::BldSubTskGphByOneToOne;
}
}
// MdDiffAccChainNode
BldSubTskGphMthd MdDiffAccChainNode::GetMthdForBldSubTskGphFromBackward(
const ChainNode*) const {
return &TaskGraph::BldSubTskGphByOneToOne;
}
std::vector<std::string> FindLbnsBetween(const ChainNode* in_chain,
const ChainNode* out_chain) {
return in_chain->FindLbnsTo(out_chain);
}
std::string ChainEdge::VisualStr() const { return ""; }
BldSubTskGphMthd ChainEdge::GetMthdForBldSubTskGph() const {
return src_node()->GetMthdForBldSubTskGphTo(dst_node());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_CHAIN_NODE_H_
#define ONEFLOW_CORE_GRAPH_CHAIN_NODE_H_
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/compute_task_node.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
class ChainEdge;
class TaskGraph;
using CompTaskNodeHandler = std::function<void(CompTaskNode*)>;
using BldSubTskGphMthd = void (TaskGraph::*)(
const ChainNode* src_chain, const ChainNode* dst_chain,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_in_box,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_out_box);
using BldBoxingOpConfMthd = void (BoxingTaskNode::*)(
const std::string& lbn,
const std::vector<BoxingTaskNode::EdgeInfo>& sorted_in_edges,
int64_t in_parallel_num, int64_t in_edge_first, int64_t in_edge_last,
const std::vector<BoxingTaskNode::EdgeInfo>& sorted_out_edges,
int64_t out_parallel_num, int64_t* used_out_edge_begin, BoxingOpConf*);
#define CHAIN_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(Forward) \
OF_PP_MAKE_TUPLE_SEQ(Backward) \
OF_PP_MAKE_TUPLE_SEQ(Source) \
OF_PP_MAKE_TUPLE_SEQ(Loss) \
OF_PP_MAKE_TUPLE_SEQ(LossAcc) \
OF_PP_MAKE_TUPLE_SEQ(LossRecord) \
OF_PP_MAKE_TUPLE_SEQ(MdUpdt) \
OF_PP_MAKE_TUPLE_SEQ(MdSave) \
OF_PP_MAKE_TUPLE_SEQ(MdDiffAcc)
class ChainNode : public Node<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainNode);
virtual ~ChainNode() = default;
// op_vec_
std::shared_ptr<const Operator> SoleOp() const;
const std::vector<std::shared_ptr<const Operator>>& op_vec() const;
std::vector<std::shared_ptr<const Operator>>& mut_op_vec() { return op_vec_; }
// parallel_desc_
std::shared_ptr<const ParallelDesc> parallel_desc() const;
std::shared_ptr<const ParallelDesc>& mut_parallel_desc();
// util
virtual const char* TypeName() const = 0;
std::string VisualStr() const;
bool HasOpWithModelOrModelTmpBlob() const;
void GenSortedCompTaskNodes(CompTaskNodeHandler) const;
// To
virtual BldSubTskGphMthd GetMthdForBldSubTskGphTo(const ChainNode*) const = 0;
virtual BldBoxingOpConfMthd GetMthdForBldBoxingOpConfTo(
const ChainNode*) const = 0;
virtual std::vector<std::string> FindLbnsTo(const ChainNode*) const = 0;
// From
#define DECLARE_VIRTUAL_FROM_METHOD(x) \
virtual BldSubTskGphMthd GetMthdForBldSubTskGphFrom##x(const ChainNode*) \
const; \
virtual BldBoxingOpConfMthd GetMthdForBldBoxingOpConfFrom##x( \
const ChainNode*) const; \
virtual std::vector<std::string> FindLbnsFrom##x(const ChainNode*) const;
OF_PP_FOR_EACH_TUPLE(DECLARE_VIRTUAL_FROM_METHOD, CHAIN_TYPE_SEQ);
#undef DECLARE_VIRTUAL_METHOD
protected:
ChainNode() = default;
virtual CompTaskNode* NewCompTaskNode() const = 0;
private:
std::vector<std::shared_ptr<const Operator>> op_vec_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
};
class BackwardChainNode;
#define OVERRIDE_PURE_VIRTUAL_METHOD() \
const char* TypeName() const override; \
BldSubTskGphMthd GetMthdForBldSubTskGphTo(const ChainNode*) const override; \
BldBoxingOpConfMthd GetMthdForBldBoxingOpConfTo(const ChainNode*) \
const override; \
std::vector<std::string> FindLbnsTo(const ChainNode*) const override; \
CompTaskNode* NewCompTaskNode() const override;
#define OVERRIDE_FROM_METHOD(x, y) x##From##y(const ChainNode*) const override;
class ForwardChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(ForwardChainNode);
ForwardChainNode() = default;
~ForwardChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
BackwardChainNode* bw_node() const { return bw_node_; }
void set_bw_node(BackwardChainNode* val) { bw_node_ = val; }
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(Forward)(Source)(MdUpdt));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
OVERRIDE_FROM_METHOD, (BldBoxingOpConfMthd GetMthdForBldBoxingOpConf),
(Forward)(Source));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(std::vector<std::string> FindLbns),
(Forward)(Source));
private:
BackwardChainNode* bw_node_;
};
class BackwardChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(BackwardChainNode);
BackwardChainNode() = default;
~BackwardChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
ForwardChainNode* fw_node() const { return fw_node_; }
void set_fw_node(ForwardChainNode* val) { fw_node_ = val; }
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(Forward)(Backward)(Loss)(MdUpdt));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
OVERRIDE_FROM_METHOD, (BldBoxingOpConfMthd GetMthdForBldBoxingOpConf),
(Backward)(Loss));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(std::vector<std::string> FindLbns),
(Backward)(Loss));
private:
ForwardChainNode* fw_node_;
};
class SourceChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(SourceChainNode);
SourceChainNode() = default;
~SourceChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
};
class LossChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossChainNode);
LossChainNode() = default;
~LossChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(Forward)(Source));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
OVERRIDE_FROM_METHOD, (BldBoxingOpConfMthd GetMthdForBldBoxingOpConf),
(Forward)(Source));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(std::vector<std::string> FindLbns),
(Forward)(Source));
};
class LossAccChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossAccChainNode);
LossAccChainNode() = default;
~LossAccChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(Loss));
};
class LossRecordChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossRecordChainNode);
LossRecordChainNode() = default;
~LossRecordChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(LossAcc));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
OVERRIDE_FROM_METHOD, (BldBoxingOpConfMthd GetMthdForBldBoxingOpConf),
(LossAcc));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(std::vector<std::string> FindLbns),
(LossAcc));
};
class MdUpdtChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdUpdtChainNode);
MdUpdtChainNode() = default;
~MdUpdtChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(MdDiffAcc));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
OVERRIDE_FROM_METHOD, (BldBoxingOpConfMthd GetMthdForBldBoxingOpConf),
(MdDiffAcc));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(std::vector<std::string> FindLbns),
(MdDiffAcc));
};
class MdSaveChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdSaveChainNode);
MdSaveChainNode() = default;
~MdSaveChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(MdUpdt));
};
class MdDiffAccChainNode final : public ChainNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdDiffAccChainNode);
MdDiffAccChainNode() = default;
~MdDiffAccChainNode() = default;
OVERRIDE_PURE_VIRTUAL_METHOD();
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(OVERRIDE_FROM_METHOD,
(BldSubTskGphMthd GetMthdForBldSubTskGph),
(Backward));
};
std::vector<std::string> FindLbnsBetween(const ChainNode* in_chain,
const ChainNode* out_chain);
class ChainEdge final : public Edge<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainEdge);
ChainEdge() = default;
~ChainEdge() = default;
std::string VisualStr() const override;
BldSubTskGphMthd GetMthdForBldSubTskGph() const;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/operator/clone_op.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
std::string CompTaskNode::VisualStr() const {
std::stringstream ss;
ss << TaskNode::VisualStr() << "Compute"
<< ":" << stage_node()->machine_id_str() << ":" << thrd_loc_id_str()
<< "\\n"
<< chain_node()->VisualStr();
return ss.str();
}
std::string CompTaskNode::device_name() const {
return IDMgr::Singleton()->MachineName4MachineId(stage_node()->machine_id())
+ ":"
+ std::to_string(
IDMgr::Singleton()->DevPhyId4ThrdLocId(thrd_loc_id()));
}
void SortByParallelId(std::vector<CompTaskNode*>* comp_node_vec) {
std::sort(comp_node_vec->begin(), comp_node_vec->end(),
[](const CompTaskNode* lhs, const CompTaskNode* rhs) {
return lhs->parallel_id() < rhs->parallel_id();
});
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
namespace oneflow {
class CompTaskNode : public TaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);
CompTaskNode() = default;
virtual ~CompTaskNode() = default;
// Getters and Setters
int64_t parallel_id() const { return parallel_id_; }
void set_parallel_id(int64_t parallel_id) { parallel_id_ = parallel_id; }
bool IsLossNode() const { return chain_node()->IsLossNode(); }
std::string VisualStr() const override;
std::string device_name() const;
virtual void FillProtoWithParallelInfo(
TaskProto* proto,
std::function<int64_t(const ChainNode*)> MeaninglessTaskCnt4Chain) const {
UNEXPECTED_RUN();
}
protected:
virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
    auto fw_comp_node = static_cast<CompTaskNode*>(fw_node);
    parallel_id_ = fw_comp_node->parallel_id_;
}
private:
int64_t parallel_id_;
};
void SortByParallelId(std::vector<CompTaskNode*>* comp_node_vec);
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
void SortByParallelId(std::vector<CompTaskNode*>* node_vec) {
std::sort(node_vec->begin(), node_vec->end(),
[](const CompTaskNode* lhs, const CompTaskNode* rhs) {
return lhs->parallel_ctx().parallel_id()
< rhs->parallel_ctx().parallel_id();
});
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
namespace oneflow {
class ChainNode;
class CompTaskNode : public TaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CompTaskNode);
CompTaskNode() = default;
virtual ~CompTaskNode() = default;
virtual void FixThrdLocId() {}
// parallel_ctx_
int64_t parallel_id() const { return parallel_ctx_.parallel_id(); }
const ParallelContext& parallel_ctx() const { return parallel_ctx_; }
ParallelContext& mut_parallel_ctx() { return parallel_ctx_; }
// chain_node_
const ChainNode* chain_node() const { return chain_node_; }
void set_chain_node(const ChainNode* val) { chain_node_ = val; }
protected:
private:
ParallelContext parallel_ctx_;
const ChainNode* chain_node_;
};
void SortByParallelId(std::vector<CompTaskNode*>* node_vec);
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/copy_comm_net_op.h"
#include "oneflow/core/operator/copy_hd_op.h"
namespace oneflow {
void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) {
auto out_regst = NewProducedRegstDesc("copy_out", 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(out_regst, SoleOutEdge());
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
ConsumeRegstDesc("copy_in", in_regst);
out_regst->CopyLbnFrom(in_regst.get());
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = AddOp();
if (IsFwNode()) {
node->BindBnInOpAndRegst(node->op()->SoleIbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst);
} else {
node->BindBnInOpAndRegst(node->op()->SoleOdbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleIdbn(), out_regst);
}
mut_exec_gph().UpdateSourceAndSink();
}
void CopyTaskNode::InferBlobDescInProducedRegsts(TaskGraph*) {
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
std::shared_ptr<RegstDesc> out_regst = GetRelatedRegst(SoleOutEdge());
out_regst->CopyBlobDescFrom(in_regst.get());
}
void CopyHDTaskNode::SetFwInCopy() {
CHECK(IsFwNode());
is_fw_in_copy_ = true;
}
void CopyHDTaskNode::SetFwOutCopy() {
CHECK(IsFwNode());
is_fw_in_copy_ = false;
}
std::shared_ptr<Operator> CopyHDTaskNode::AddOp() const {
OperatorConf op_conf;
op_conf.set_name("copy_hd_" + NewUniqueId());
CopyHdOpConf* copy_hd_conf = op_conf.mutable_copy_hd_conf();
copy_hd_conf->set_type(IsH2D() ? CopyHdOpConf::H2D : CopyHdOpConf::D2H);
  return OpMgr::Singleton()->AddOp(op_conf);
}
void CopyHdTaskNode::Init(const CompTaskNode* comp_task,
CopyHdOpConf::Type copy_type) {
set_machine_id(comp_task->machine_id());
set_thrd_loc_id(comp_task->thrd_loc_id());
}
std::shared_ptr<Operator> CopyCommNetTaskNode::AddOp() const {
OperatorConf op_conf;
op_conf.set_name("comm_net_" + NewUniqueId());
op_conf.mutable_copy_comm_net_conf();
  return OpMgr::Singleton()->AddOp(op_conf);
}
void CopyCommNetTaskNode::Init(int64_t machine_id) {
set_machine_id(machine_id);
set_thrd_loc_id(IDMgr::Singleton()->CommNetThrdLocId());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
......@@ -11,49 +11,22 @@ class CopyTaskNode : public TaskNode {
CopyTaskNode() = default;
virtual ~CopyTaskNode() = default;
protected:
virtual std::shared_ptr<Operator> AddOp() const = 0;
void ProduceAllRegstsAndBindEdges() override { TODO(); }
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void InferBlobDescInProducedRegsts(TaskGraph*) override;
};
class CopyHDTaskNode final : public CopyTaskNode {
class CopyHdTaskNode final : public CopyTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(CopyHDTaskNode);
CopyHDTaskNode() = default;
~CopyHDTaskNode() = default;
OF_DISALLOW_COPY_AND_MOVE(CopyHdTaskNode);
CopyHdTaskNode() = default;
~CopyHdTaskNode() = default;
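  // Note: the forward in-copy moves data host->device and the forward out-copy
  // device->host; for a backward node the direction of each copy is mirrored,
  // which is what IsH2D() below encodes.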
bool IsH2D() const {
return ((IsFwInCopy() && IsFwNode()) || (IsFwOutCopy() && IsBpNode()));
}
bool IsD2H() const { return !IsH2D(); }
bool IsFwInCopy() const { return is_fw_in_copy_; }
bool IsFwOutCopy() const { return !is_fw_in_copy_; }
void SetFwInCopy();
void SetFwOutCopy();
std::string VisualStr() const override {
return TaskNode::VisualStr() + "CopyHD";
}
TodoTaskType GetTaskType() const override { return TodoTaskType::kCopyHd; }
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
void Init(const CompTaskNode*, CopyHdOpConf::Type);
private:
std::shared_ptr<Operator> AddOp() const override;
void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
is_fw_in_copy_ = static_cast<CopyHDTaskNode*>(fw_node)->is_fw_in_copy_;
}
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyHDTaskNode>();
}
TaskType task_type() const override { return kCopyHdTask; }
bool is_fw_in_copy_;
};
class CopyCommNetTaskNode final : public CopyTaskNode {
......@@ -62,25 +35,13 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
CopyCommNetTaskNode() = default;
~CopyCommNetTaskNode() = default;
std::string VisualStr() const override {
return TaskNode::VisualStr() + "CommNet";
TodoTaskType GetTaskType() const override {
return TodoTaskType::kCopyCommNet;
}
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
DeviceType GetDeviceType() const override { return DeviceType::kCPU; }
void Init(int64_t machine_id);
private:
std::shared_ptr<Operator> AddOp() const override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyCommNetTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
set_stage_node(fw_node->SoleInEdge()->src_node()->stage_node());
set_task_id();
}
TaskType task_type() const override { return kCopyCommNetTask; }
};
} // namespace oneflow
......
#include "oneflow/core/graph/data_comp_task_node.h"
namespace oneflow {
void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
Lbn2NodeBnMap lbn2producer;
Lbn2NodeBnMap extern_in_lbn2consumer;
FwBuildFromUserOps(&lbn2producer, &extern_in_lbn2consumer);
mut_exec_gph().UpdateSourceAndSink();
// Enroll Produced Regsts
if (!out_edges().empty()) {
auto out_regst = NewProducedRegstDesc("out", 1, kMaxRegisterNum);
BindProducedRegstAndOutEdge(out_regst, SoleOutEdge());
}
NewProducedRegstDesc("activation", 1, kMaxRegisterNum);
NewProducedRegstDesc("data_tmp", 1, kMaxRegisterNum);
NewProducedRegstDesc("model_tmp", 1);
NewProducedRegstDesc("model", 3, kMaxRegisterNum);
NewProducedRegstDesc("loss", 1, kMaxRegisterNum);
// Enroll Lbn
FwSetExecNodeFromInRegst(extern_in_lbn2consumer);
FwEnrollLbn2OutRegst(lbn2producer);
FwEnrollLbn2ActivationRegst();
FwEnrollLbn2ModelAndTmpRegsts(); // model model_tmp data_tmp
}
void DataCompTaskNode::FwInferBlobDescInProducedRegsts(TaskGraph*) {
exec_gph().ConstTopoForEachNode([this](const ExecNode* node) {
node->op()->InferBlobDesc4FwBlobs(
node->GetBlobDesc4BnInOpFunc(), chain_node()->parallel_desc()->policy(),
parallel_id(), chain_node()->parallel_desc()->parallel_num());
});
if (IsLossNode()) {
auto out_regst = GetRelatedRegst(SoleOutEdge());
auto in_regst = GetRelatedRegst(SoleInEdge());
out_regst->CopyBlobDescFrom(in_regst.get());
}
}
void DataCompTaskNode::FwBuildFromUserOps(
Lbn2NodeBnMap* lbn2producer, Lbn2NodeBnMap* extern_in_lbn2consumer) {
for (std::shared_ptr<Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode();
cur_node->mut_op() = op;
for (const std::string& obn : op->output_bns()) {
const std::string& lbn = op->Lbn4BnInOp(obn);
CHECK(lbn2producer->insert({lbn, {cur_node, obn}}).second);
}
}
mut_exec_gph().ForEachNode([&](ExecNode* cur_node) {
for (const std::string& ibn : cur_node->op()->input_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(ibn);
auto producer_it = lbn2producer->find(lbn);
if (producer_it != lbn2producer->end()) {
ExecEdge* edge = mut_exec_gph().NewEdge();
edge->set_lbn(lbn);
edge->mut_src_bn() = producer_it->second.second;
edge->mut_dst_bn() = ibn;
Connect(producer_it->second.first, edge, cur_node);
} else {
CHECK(extern_in_lbn2consumer->insert({lbn, {cur_node, ibn}}).second);
}
}
});
}
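// Note on FwBuildFromUserOps: the first loop records which exec node produces each
// output lbn; the second pass then wires every consumer to its in-chain producer
// with an ExecEdge, while inputs whose producer lives outside this chain are
// collected into extern_in_lbn2consumer so FwSetExecNodeFromInRegst can bind them
// to the "in" regst.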
void DataCompTaskNode::FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& extern_in_lbn2consumer) {
if (extern_in_lbn2consumer.empty()) { return; }
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
ConsumeRegstDesc("in", in_regst);
for (const auto& pair : extern_in_lbn2consumer) {
ExecNode* node = pair.second.first;
const std::string& ibn = pair.second.second;
node->BindBnInOpAndRegst(ibn, in_regst);
}
}
void DataCompTaskNode::FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer) {
if (IsLossNode()) {
FwEnrollLbn2OutRegstWhenLoss();
} else {
FwEnrollLbn2OutRegstWhenNotLoss(lbn2producer);
}
}
void DataCompTaskNode::FwEnrollLbn2OutRegstWhenLoss() {
ExecNode* exec_node = exec_gph().SoleNode();
// loss regst
std::shared_ptr<RegstDesc> loss_regst = GetProducedRegstDesc("loss");
for (const std::string& obn : exec_node->op()->output_bns()) {
loss_regst->EnrollLbn(exec_node->op()->Lbn4BnInOp(obn));
exec_node->BindBnInOpAndRegst(obn, loss_regst);
}
// out regst
if (!out_edges().empty()) {
std::shared_ptr<RegstDesc> out_regst = GetRelatedRegst(SoleOutEdge());
for (const std::string& idbn : exec_node->op()->input_diff_bns()) {
const std::string& lbn = exec_node->op()->Lbn4BnInOp(idbn);
out_regst->EnrollLbn(lbn);
exec_node->BindBnInOpAndRegst(idbn, out_regst);
}
}
}
void DataCompTaskNode::FwEnrollLbn2OutRegstWhenNotLoss(
const Lbn2NodeBnMap& lbn2producer) {
if (out_edges().empty()) { return; }
std::shared_ptr<RegstDesc> out_regst = GetRelatedRegst(SoleOutEdge());
for (const std::string& lbn : chain_node()->output_lbns()) {
const std::pair<ExecNode*, std::string>& producer = lbn2producer.at(lbn);
ExecNode* node = producer.first;
const std::string& obn = producer.second;
out_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(obn, out_regst);
}
}
void DataCompTaskNode::FwEnrollLbn2ActivationRegst() {
auto activation_regst = GetProducedRegstDesc("activation");
exec_gph().ConstForEachEdge([&](const ExecEdge* edge) {
activation_regst->EnrollLbn(edge->lbn());
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_regst);
});
}
void DataCompTaskNode::FwEnrollLbn2ModelAndTmpRegsts() {
auto data_tmp_regst = GetProducedRegstDesc("data_tmp");
auto model_tmp_regst = GetProducedRegstDesc("model_tmp");
auto model_regst = GetProducedRegstDesc("model");
ConsumeRegstDesc("model_tmp", model_tmp_regst);
ConsumeRegstDesc("model", model_regst);
mut_exec_gph().ForEachNode([&](ExecNode* node) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
const std::string& lbn = node->op()->Lbn4BnInOp(dtbn);
data_tmp_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
}
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
const std::string& lbn = node->op()->Lbn4BnInOp(mtbn);
model_tmp_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
for (const std::string& mbn : node->op()->model_bns()) {
const std::string& lbn = node->op()->Lbn4BnInOp(mbn);
model_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(mbn, model_regst);
}
});
}
void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
BpBuildExecGraph();
// New produced registers
auto in_diff_regst = NewProducedRegstDesc("in_diff", 1, kMaxRegisterNum);
if (!out_edges().empty()) {
BindProducedRegstAndOutEdge(in_diff_regst, SoleOutEdge());
}
NewProducedRegstDesc("model_diff", 1, kMaxRegisterNum);
NewProducedRegstDesc("activation_diff", 1);
// Subscribe
ConsumeRegstDesc("activation",
GetFwNode()->GetProducedRegstDesc("activation"));
ConsumeRegstDesc("data_tmp", GetFwNode()->GetProducedRegstDesc("data_tmp"));
ConsumeRegstDesc("model", GetFwNode()->GetConsumedRegstDesc("model"));
ConsumeRegstDesc("model_tmp", GetFwNode()->GetConsumedRegstDesc("model_tmp"));
ConsumeRegstDesc("in", GetFwNode()->GetConsumedRegstDesc("in"));
ConsumeRegstDesc("out_diff", GetRelatedRegst(SoleInEdge()));
ConsumeRegstDesc("out", GetFwNode()->GetProducedRegstDesc("out"));
// Enroll Lbn
BpEnrollLbn2ProducedRegst();
}
void DataCompTaskNode::BpInferBlobDescInProducedRegsts(TaskGraph*) {
// in_diff_regst
auto in_diff_regst = GetProducedRegstDesc("in_diff");
auto in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
in_diff_regst->CopyBlobDescFrom(in_regst.get());
// model_diff_regst
if (auto md_diff_regst = GetProducedRegstDesc("model_diff")) {
md_diff_regst->CopyBlobDescFrom(
GetFwNode()->GetConsumedRegstDesc("model").get());
}
// activation_diff_regst
if (auto acti_diff_regst = GetProducedRegstDesc("activation_diff")) {
auto acti_regst = GetFwNode()->GetProducedRegstDesc("activation");
acti_diff_regst->CopyBlobDescFrom(acti_regst.get());
}
}
void DataCompTaskNode::BpBuildExecGraph() {
const ExecGraph& fw_gph = GetFwNode()->exec_gph();
HashMap<const ExecNode*, ExecNode*> fw_node2bp_node;
fw_gph.ConstForEachNode([&](const ExecNode* fw_node) {
ExecNode* bp_node = mut_exec_gph().NewNode();
bp_node->mut_op() = fw_node->op();
CHECK(fw_node2bp_node.emplace(fw_node, bp_node).second);
});
fw_gph.ConstForEachEdge([&](const ExecEdge* fw_edge) {
ExecEdge* bp_edge = mut_exec_gph().NewEdge();
bp_edge->set_lbn(fw_edge->lbn());
bp_edge->mut_src_bn() = GenDiffBn(fw_edge->dst_bn());
bp_edge->mut_dst_bn() = GenDiffBn(fw_edge->src_bn());
Connect(fw_node2bp_node.at(fw_edge->dst_node()), bp_edge,
fw_node2bp_node.at(fw_edge->src_node()));
});
mut_exec_gph().UpdateSourceAndSink();
}
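// BpBuildExecGraph mirrors the forward exec graph: each forward exec node gets a
// backward twin sharing the same op, and each forward edge u->v becomes a backward
// edge v->u whose src/dst bns are the GenDiffBn counterparts of the forward bns.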
void DataCompTaskNode::BpEnrollLbn2ProducedRegst() {
BpEnrollLbn2ActivationDiffRegst();
BpSetExecNodeFromOutDiffRegst();
BpEnrollLbn2InDiffRegst();
BpEnrollLbn2ModelDiffRegst();
}
void DataCompTaskNode::BpEnrollLbn2ActivationDiffRegst() {
auto activation_regst = GetFwNode()->GetProducedRegstDesc("activation");
auto activation_diff_regst = GetProducedRegstDesc("activation_diff");
activation_diff_regst->CopyLbnFrom(activation_regst.get());
exec_gph().ConstForEachEdge([&](const ExecEdge* edge) {
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_diff_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_diff_regst);
edge->src_node()->BindBnInOpAndRegst(GenUnDiffBn(edge->src_bn()),
activation_regst);
edge->dst_node()->BindBnInOpAndRegst(GenUnDiffBn(edge->dst_bn()),
activation_regst);
});
}
void DataCompTaskNode::BpSetExecNodeFromOutDiffRegst() {
auto out_diff_regst = GetRelatedRegst(SoleInEdge());
auto out_regst = GetFwNode()->GetProducedRegstDesc("out");
mut_exec_gph().ForEachNode([&](ExecNode* bp_node) {
std::unordered_set<std::string> found_bns;
for (ExecEdge* edge : bp_node->in_edges()) {
found_bns.insert(edge->dst_bn());
}
for (const std::string& odbn : bp_node->op()->output_diff_bns()) {
if (found_bns.find(odbn) != found_bns.end()) { continue; }
bp_node->BindBnInOpAndRegst(odbn, out_diff_regst);
bp_node->BindBnInOpAndRegst(GenUnDiffBn(odbn), out_regst);
}
});
}
void DataCompTaskNode::BpEnrollLbn2InDiffRegst() {
auto in_regst = GetRelatedRegst(GetFwNode()->SoleInEdge());
auto in_diff_regst = GetProducedRegstDesc("in_diff");
mut_exec_gph().ForEachNode([&](ExecNode* bp_node) {
std::unordered_set<std::string> found_bns;
for (ExecEdge* edge : bp_node->out_edges()) {
found_bns.insert(edge->src_bn());
}
for (const std::string& idbn : bp_node->op()->input_diff_bns()) {
if (found_bns.find(idbn) != found_bns.end()) { continue; }
const std::string& lbn = bp_node->op()->Lbn4BnInOp(idbn);
in_diff_regst->EnrollLbn(lbn);
bp_node->BindBnInOpAndRegst(idbn, in_diff_regst);
bp_node->BindBnInOpAndRegst(GenUnDiffBn(idbn), in_regst);
}
});
}
void DataCompTaskNode::BpEnrollLbn2ModelDiffRegst() {
auto data_tmp_regst = GetFwNode()->GetProducedRegstDesc("data_tmp");
auto model_tmp_regst = GetFwNode()->GetProducedRegstDesc("model_tmp");
auto model_diff_regst = GetProducedRegstDesc("model_diff");
auto model_regst = GetConsumedRegstDesc("model");
mut_exec_gph().ForEachNode([&](ExecNode* node) {
for (const std::string& dtbn : node->op()->data_tmp_bns()) {
node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
}
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
for (const std::string& mdbn : node->op()->model_diff_bns()) {
const std::string& lbn = node->op()->Lbn4BnInOp(mdbn);
model_diff_regst->EnrollLbn(lbn);
node->BindBnInOpAndRegst(mdbn, model_diff_regst);
}
for (const std::string& mbn : node->op()->model_bns()) {
node->BindBnInOpAndRegst(mbn, model_regst);
}
});
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
namespace oneflow {
class DataCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(DataCompTaskNode);
DataCompTaskNode() = default;
~DataCompTaskNode() = default;
void ToProto(TaskProto* proto, std::function<int64_t(const ChainNode*)>
MeaninglessTaskCnt4Chain) const override {
TaskNode::ToProto(proto, MeaninglessTaskCnt4Chain);
FillProtoWithParallelInfo(proto, MeaninglessTaskCnt4Chain);
}
void FillProtoWithParallelInfo(TaskProto* proto,
std::function<int64_t(const ChainNode*)>
MeaninglessTaskCnt4Chain) const override {
auto parallel_desc = chain_node()->parallel_desc();
proto->set_parallel_policy(parallel_desc->policy());
proto->set_parallel_id(parallel_id());
proto->set_parallel_num(parallel_desc->parallel_num()
- MeaninglessTaskCnt4Chain(chain_node()));
}
bool IsMeaningLess() const override {
if (IsFwNode()) {
if (chain_node()->IsRecordNode()) {
return false;
} else {
return TaskNode::IsMeaningLess();
}
} else {
return TaskNode::IsMeaningLess() || GetFwNode()->IsMeaningLess();
}
}
private:
OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferBlobDescInProducedRegsts);
using Lbn2NodeBnMap = HashMap<std::string, std::pair<ExecNode*, std::string>>;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph);
void FwInferBlobDescInProducedRegsts(TaskGraph* gph);
void FwBuildFromUserOps(Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetExecNodeFromInRegst(const Lbn2NodeBnMap& extern_in_lbn2consumer);
void FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer);
void FwEnrollLbn2OutRegstWhenLoss();
void FwEnrollLbn2OutRegstWhenNotLoss(const Lbn2NodeBnMap& lbn2producer);
void FwEnrollLbn2ActivationRegst();
void FwEnrollLbn2ModelAndTmpRegsts();
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void BpInferBlobDescInProducedRegsts(TaskGraph*);
void BpBuildExecGraph();
void BpEnrollLbn2ProducedRegst();
void BpEnrollLbn2ActivationDiffRegst();
void BpSetExecNodeFromOutDiffRegst();
void BpEnrollLbn2InDiffRegst();
void BpEnrollLbn2ModelDiffRegst();
TaskType task_type() const override { return kDataCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<DataCompTaskNode>();
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
#include "oneflow/core/graph/data_task_graph.h"
namespace oneflow {
class DataCompTaskNode;
DataTaskGraph::DataTaskGraph(const std::string& name,
const DLNetConf& dl_net_conf,
const Placement& placement, bool need_bp) {
mut_name() = name;
LogicalGraph logical_gph(dl_net_conf, placement);
auto chain_gph = of_make_unique<ChainGraph>(&logical_gph);
BuildFromChainGph<DataCompTaskNode>(std::move(chain_gph), need_bp);
BuildExecAndEnrollLbn2Regsts();
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class DataTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(DataTaskGraph);
DataTaskGraph() = delete;
~DataTaskGraph() = default;
DataTaskGraph(const std::string& name, const DLNetConf& dl_net_conf,
const Placement& placement, bool need_bp);
const char* TypeName() const override { return "DataTaskGraph"; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
......@@ -2,24 +2,14 @@
namespace oneflow {
void ExecEdge::set_lbn(const std::string& lbn) { lbn_ = lbn; }
void ExecNode::BindBnInOpAndRegst(const std::string& bn_in_op,
std::weak_ptr<RegstDesc> regst) {
CHECK(bn_in_op2regst_.emplace(bn_in_op, regst).second);
}
std::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc()
const {
return [this](const std::string& bn_in_op) {
return GetBlobDesc4BnInOp(bn_in_op);
};
}
void ExecNode::GetBnInOp2DataType(
google::protobuf::Map<std::string, DataType>* pbmap) const {
for (const auto& pair : bn_in_op2regst_) {
const std::string& bn_in_op = pair.first;
BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op);
if (blob_desc) {
CHECK(pbmap->insert({bn_in_op, blob_desc->data_type()}).second);
}
}
return std::bind(&ExecNode::GetBlobDesc4BnInOp, this, std::placeholders::_1);
}
void ExecNode::ToProto(ExecNodeProto* ret) const {
......@@ -38,15 +28,11 @@ BlobDesc* ExecNode::GetBlobDesc4BnInOp(const std::string& bn_in_op) const {
if (it == this->bn_in_op2regst_.end()) { return nullptr; }
std::shared_ptr<RegstDesc> regst = it->second.lock();
const std::string& lbn = this->op()->Lbn4BnInOp(bn_in_op);
return regst->GetMutBlobDesc(lbn);
return regst->MutBlobDesc(lbn);
}
void ExecGraph::ToExecSequence(ExecSequence* ret) const {
ConstTopoForEachNode([&](const ExecNode* node) {
if (!node->bn_in_op2regst().empty()) {
node->ToProto(ret->add_exec_node());
}
});
TopoForEachNode([&](ExecNode* node) { node->ToProto(ret->add_exec_node()); });
}
} // namespace oneflow
......@@ -23,7 +23,7 @@ class ExecEdge final : public Edge<ExecNode, ExecEdge> {
const std::string& dst_bn() const { return dst_bn_; }
// Setters
void set_lbn(const std::string& lbn);
void set_lbn(const std::string& lbn) { lbn_ = lbn; }
std::string& mut_src_bn() { return src_bn_; }
std::string& mut_dst_bn() { return dst_bn_; }
......@@ -43,23 +43,11 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::shared_ptr<Operator> op() const { return op_; }
std::shared_ptr<Operator>& mut_op() { return op_; }
void BindBnInOpAndRegst(const std::string& bn_in_op,
std::weak_ptr<RegstDesc> regst) {
CHECK(bn_in_op2regst_.emplace(bn_in_op, regst).second);
}
std::shared_ptr<RegstDesc> GetRegstFromBnInOp(
const std::string& bn_in_op) const {
return bn_in_op2regst_.at(bn_in_op).lock();
}
const HashMap<std::string, std::weak_ptr<RegstDesc>>& bn_in_op2regst() const {
return bn_in_op2regst_;
}
void BindBnInOpAndRegst(const std::string&, std::weak_ptr<RegstDesc>);
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOpFunc() const;
void GetBnInOp2DataType(google::protobuf::Map<std::string, DataType>*) const;
std::string VisualStr() const { return op_->op_name(); }
std::string VisualStr() const override { return op_->op_name(); }
void ToProto(ExecNodeProto* ret) const;
private:
......
#include "oneflow/core/graph/forward_compute_task_node.h"
namespace oneflow {
void ForwardCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("out", 1, kMaxRegisterNum);
ProduceRegst("activation", 1, kMaxRegisterNum);
ProduceRegst("data_tmp", 1, kMaxRegisterNum);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_FORWARD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_FORWARD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class ForwardCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(ForwardCompTaskNode);
ForwardCompTaskNode() = default;
~ForwardCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override { return TodoTaskType::kForward; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_FORWARD_COMPUTE_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_GRAPH_H_
#include <gflags/gflags.h>
#include "oneflow/core/common/str_util.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
......@@ -15,17 +14,13 @@ class Graph {
Graph() = default;
virtual ~Graph() = default;
// For Each Node
void ForEachNode(std::function<void(NodeType*)>);
void TopoForEachNode(std::function<void(NodeType*)>);
void ReverseTopoForEachNode(std::function<void(NodeType*)>);
void ConstForEachNode(std::function<void(const NodeType*)>) const;
void ConstTopoForEachNode(std::function<void(const NodeType*)>) const;
void ConstReverseTopoForEachNode(std::function<void(const NodeType*)>) const;
// For Each Edge
void ForEachEdge(std::function<void(EdgeType*)>);
void ConstForEachEdge(std::function<void(const EdgeType*)>) const;
// For Each
void ForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachNode(std::function<void(NodeType*)> NodeHandler,
std::function<bool(NodeType*)> IsNodeReady) const;
void TopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;
// Getters
const std::unordered_set<NodeType*>& source_nodes() const;
......@@ -35,150 +30,90 @@ class Graph {
NodeType* SoleNode() const;
size_t node_num() const { return nodes_.size(); }
size_t edge_num() const { return edges_.size(); }
virtual const char* TypeName() const { return "Not Defined"; }
virtual const char* TypeName() const { return ""; }
// Setters
NodeType* NewNode();
template<typename DerivedNodeType = NodeType>
DerivedNodeType* NewNode();
EdgeType* NewEdge();
void EnrollNode(NodeType*);
void EnrollNode(std::unique_ptr<NodeType>&&);
void EnrollEdge(EdgeType*);
void EnrollEdge(std::unique_ptr<EdgeType>&&);
void UpdateSourceAndSink();
void AddAllocatedNode(NodeType*);
void AddAllocatedEdge(EdgeType*);
// ToDot
template<typename StreamT>
void ToDotWithStream(StreamT& out_stream) const;
void ToDotWithFilePath(const std::string& file_path) const;
void ToDotWithAutoFilePath() const;
void ToDotWithStream(StreamT& out_stream);
void ToDotWithFilePath(const std::string& file_path);
void ToDotWithAutoFilePath();
private:
class TopoIterator;
class ReverseTopoIterator;
TopoIterator begin() { return source_nodes_; }
TopoIterator end() { return std::unordered_set<NodeType*>(); }
ReverseTopoIterator rbegin() { return sink_nodes_; }
ReverseTopoIterator rend() { return std::unordered_set<NodeType*>(); }
//
std::unordered_set<NodeType*> source_nodes_;
std::unordered_set<NodeType*> sink_nodes_;
std::vector<std::unique_ptr<NodeType>> nodes_;
std::vector<std::unique_ptr<EdgeType>> edges_;
};
template<typename NodeType, typename EdgeType>
class Graph<NodeType, EdgeType>::TopoIterator final {
public:
// OF_DISALLOW_COPY_AND_MOVE(TopoIterator);
TopoIterator() = default;
~TopoIterator() = default;
TopoIterator(const std::unordered_set<NodeType*>& source_nodes) {
for (NodeType* node : source_nodes) { bfs_queue_.push(node); }
}
NodeType& operator*() { return *(bfs_queue_.front()); }
NodeType* operator->() { return &(*(*this)); }
TopoIterator& operator++();
bool operator!=(const TopoIterator&) const;
private:
std::queue<NodeType*> bfs_queue_;
HashMap<NodeType*, int32_t> visited_cnt_;
};
template<typename NodeType, typename EdgeType>
class Graph<NodeType, EdgeType>::ReverseTopoIterator final {
public:
// OF_DISALLOW_COPY_AND_MOVE(ReverseTopoIterator);
ReverseTopoIterator() = default;
~ReverseTopoIterator() = default;
ReverseTopoIterator(const std::unordered_set<NodeType*>& sink_nodes) {
for (NodeType* node : sink_nodes) { bfs_queue_.push(node); }
}
NodeType& operator*() { return *(bfs_queue_.front()); }
NodeType* operator->() { return &(*(*this)); }
ReverseTopoIterator& operator++();
bool operator!=(const ReverseTopoIterator&) const;
private:
std::queue<NodeType*> bfs_queue_;
HashMap<NodeType*, int32_t> visited_cnt_;
};
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachNode(
std::function<void(NodeType*)> NodeHandler) const {
for (auto& x : nodes_) { NodeHandler(x.get()); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachNode(
std::function<void(NodeType*)> func) {
for (auto& x : nodes_) { func(x.get()); }
std::function<void(NodeType*)> NodeHandler,
std::function<bool(NodeType*)> IsNodeReady) const {
std::queue<NodeType*> node_queue;
HashSet<NodeType*> nodes_pushed;
for (auto& x : nodes_) {
if (IsNodeReady(x.get())) {
node_queue.push(x.get());
CHECK(nodes_pushed.insert(x.get()).second);
}
}
while (node_queue.empty() == false) {
NodeType* cur_node = node_queue.front();
node_queue.pop();
NodeHandler(cur_node);
cur_node->ForEachNodeOnInOutEdge([&](NodeType* candidate) {
if (nodes_pushed.find(candidate) == nodes_pushed.end()
&& IsNodeReady(candidate)) {
node_queue.push(candidate);
CHECK(nodes_pushed.insert(candidate).second);
}
});
}
}
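// The predicate-guarded ForEachNode above is a plain BFS: it seeds the queue with
// every node that is already "ready", handles nodes FIFO, and enqueues a neighbor
// the first time it turns ready; the topological traversals below only differ in
// the readiness predicate they pass in.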
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
std::function<void(NodeType*)> func) {
for (TopoIterator it = begin(); it != end(); ++it) { func(&(*it)); }
std::function<void(NodeType*)> NodeHandler) const {
HashMap<NodeType*, size_t> node2cnt;
auto IncreaseCnt = [&](NodeType* node) { node2cnt[node] += 1; };
auto MyNodeHandler = [&](NodeType* node) {
NodeHandler(node);
node->ForEachNodeOnOutEdge(IncreaseCnt);
};
ForEachNode(MyNodeHandler, [&](NodeType* node) {
return node->in_edges().size() == node2cnt[node];
});
}
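// TopoForEachNode is Kahn-style: node2cnt[node] counts how many predecessors of
// node have already been handled, and node becomes ready once that count equals
// its in-degree, so e.g. with edges a->c and b->c, c is visited only after both a
// and b. A hypothetical call site might look like:
//   graph.TopoForEachNode([](TaskNode* node) { node->Build(); });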
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
std::function<void(NodeType*)> func) {
for (ReverseTopoIterator it = rbegin(); it != rend(); ++it) { func(&(*it)); }
std::function<void(NodeType*)> NodeHandler) const {
HashMap<NodeType*, size_t> node2cnt;
auto IncreaseCnt = [&](NodeType* node) { node2cnt[node] += 1; };
auto MyNodeHandler = [&](NodeType* node) {
NodeHandler(node);
node->ForEachNodeOnInEdge(IncreaseCnt);
};
ForEachNode(MyNodeHandler, [&](NodeType* node) {
return node->out_edges().size() == node2cnt[node];
});
}
#define OF_DEFINE_CONST_FOR_EACH_NODE(FuncName) \
template<typename NodeType, typename EdgeType> \
void Graph<NodeType, EdgeType>::Const##FuncName( \
std::function<void(const NodeType*)> func) const { \
auto cast_this = const_cast<Graph<NodeType, EdgeType>*>(this); \
cast_this->FuncName(std::bind(func, std::placeholders::_1)); \
}
OF_DEFINE_CONST_FOR_EACH_NODE(ForEachNode);
OF_DEFINE_CONST_FOR_EACH_NODE(TopoForEachNode);
OF_DEFINE_CONST_FOR_EACH_NODE(ReverseTopoForEachNode);
#undef OF_DEFINE_CONST_FOR_EACH_NODE
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachEdge(
std::function<void(EdgeType*)> func) {
for (auto& x : edges_) { func(x.get()); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ConstForEachEdge(
std::function<void(const EdgeType*)> func) const {
auto cast_this = const_cast<Graph<NodeType, EdgeType>*>(this);
cast_this->ForEachEdge(std::bind(func, std::placeholders::_1));
}
template<typename NodeType, typename EdgeType>
const std::unordered_set<NodeType*>& Graph<NodeType, EdgeType>::source_nodes()
const {
return source_nodes_;
}
template<typename NodeType, typename EdgeType>
const std::unordered_set<NodeType*>& Graph<NodeType, EdgeType>::sink_nodes()
const {
return sink_nodes_;
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::SoleSourceNode() const {
CHECK_EQ(source_nodes_.size(), 1);
return *(source_nodes_.begin());
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::SoleSinkNode() const {
CHECK_EQ(sink_nodes_.size(), 1);
return *(sink_nodes_.begin());
std::function<void(EdgeType*)> EdgeHandler) const {
for (auto& x : edges_) { EdgeHandler(x.get()); }
}
template<typename NodeType, typename EdgeType>
......@@ -188,57 +123,38 @@ NodeType* Graph<NodeType, EdgeType>::SoleNode() const {
}
template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::NewNode() {
NodeType* ret = new NodeType;
EnrollNode(ret);
template<typename DerivedNodeType>
DerivedNodeType* Graph<NodeType, EdgeType>::NewNode() {
DerivedNodeType* ret = new DerivedNodeType;
AddAllocatedNode(ret);
return ret;
}
template<typename NodeType, typename EdgeType>
EdgeType* Graph<NodeType, EdgeType>::NewEdge() {
EdgeType* ret = new EdgeType;
EnrollEdge(ret);
AddAllocatedEdge(ret);
return ret;
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::EnrollNode(NodeType* node) {
void Graph<NodeType, EdgeType>::AddAllocatedNode(NodeType* node) {
nodes_.emplace_back(node);
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::EnrollNode(std::unique_ptr<NodeType>&& node) {
nodes_.push_back(std::move(node));
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::EnrollEdge(EdgeType* edge) {
void Graph<NodeType, EdgeType>::AddAllocatedEdge(EdgeType* edge) {
edges_.emplace_back(edge);
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::EnrollEdge(std::unique_ptr<EdgeType>&& edge) {
edges_.push_back(std::move(edge));
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::UpdateSourceAndSink() {
source_nodes_.clear();
sink_nodes_.clear();
for (const std::unique_ptr<NodeType>& node : nodes_) {
if (node->in_edges().empty()) { source_nodes_.insert(node.get()); }
if (node->out_edges().empty()) { sink_nodes_.insert(node.get()); }
}
}
template<typename NodeType, typename EdgeType>
template<typename StreamT>
void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) const {
void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) {
out_stream << "digraph {\n";
this->ConstForEachNode([&](const NodeType* node) {
this->ForEachNode([&](NodeType* node) {
out_stream << "\"" << node->VisualStr() << "\"\n";
});
this->ConstForEachEdge([&](const EdgeType* edge) {
this->ForEachEdge([&](const EdgeType* edge) {
out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> "
<< "\"" << edge->dst_node()->VisualStr() << "\""
<< "[label=\"" << edge->VisualStr() << "\"];\n";
......@@ -248,7 +164,7 @@ void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) const {
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ToDotWithFilePath(
const std::string& file_path) const {
const std::string& file_path) {
std::string dir_name = Dirname(file_path);
if (!LocalFS()->IsDirectory(dir_name)) {
LocalFS()->RecursivelyCreateDir(dir_name);
......@@ -258,63 +174,12 @@ void Graph<NodeType, EdgeType>::ToDotWithFilePath(
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() const {
void Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() {
std::string file_path =
LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot";
ToDotWithFilePath(file_path);
}
template<typename NodeType>
bool IsNotEqual4BfsQueue(const std::queue<NodeType*>& lhs,
const std::queue<NodeType*>& rhs) {
if (lhs.empty() != rhs.empty()) { return true; }
if (lhs.empty() == false && rhs.empty() == false) {
return lhs.front() != rhs.front();
}
return false;
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::TopoIterator::operator++() -> TopoIterator& {
NodeType* cur_node = bfs_queue_.front();
bfs_queue_.pop();
for (EdgeType* out_edge : cur_node->out_edges()) {
NodeType* dst_node = out_edge->dst_node();
visited_cnt_[dst_node] += 1;
if (visited_cnt_.at(dst_node) == dst_node->in_edges().size()) {
bfs_queue_.push(dst_node);
}
}
return *this;
}
template<typename NodeType, typename EdgeType>
bool Graph<NodeType, EdgeType>::TopoIterator::operator!=(
const TopoIterator& rhs) const {
return IsNotEqual4BfsQueue(bfs_queue_, rhs.bfs_queue_);
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::ReverseTopoIterator::operator++()
-> ReverseTopoIterator& {
NodeType* cur_node = bfs_queue_.front();
bfs_queue_.pop();
for (EdgeType* in_edge : cur_node->in_edges()) {
NodeType* src_node = in_edge->src_node();
visited_cnt_[src_node] += 1;
if (visited_cnt_.at(src_node) == src_node->out_edges().size()) {
bfs_queue_.push(src_node);
}
}
return *this;
}
template<typename NodeType, typename EdgeType>
bool Graph<NodeType, EdgeType>::ReverseTopoIterator::operator!=(
const ReverseTopoIterator& rhs) const {
return IsNotEqual4BfsQueue(bfs_queue_, rhs.bfs_queue_);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_GRAPH_H_
#include "oneflow/core/graph/graph.h"
namespace oneflow {
class TestEdge;
class TestNode final : public Node<TestNode, TestEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TestNode);
TestNode(int64_t node_id_) { test_node_id_ = node_id_; }
~TestNode() = default;
int64_t test_node_id() const { return test_node_id_; }
private:
int64_t test_node_id_;
};
class TestEdge final : public Edge<TestNode, TestEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TestEdge);
TestEdge() = default;
~TestEdge() = default;
private:
};
class TestGraph final : public Graph<TestNode, TestEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TestGraph);
TestGraph() = delete;
~TestGraph() = default;
TestGraph(const std::vector<std::vector<int64_t>>& graph_conf) {
std::vector<TestNode*> node_id2node;
for (size_t i = 0; i < graph_conf.size(); ++i) {
TestNode* cur_node = new TestNode(i);
EnrollNode(cur_node);
node_id2node.push_back(cur_node);
}
for (size_t i = 0; i < graph_conf.size(); ++i) {
TestNode* src_node = node_id2node[i];
for (size_t j = 0; j < graph_conf[i].size(); ++j) {
TestEdge* edge = NewEdge();
TestNode* dst_node = node_id2node[graph_conf[i][j]];
Connect(src_node, edge, dst_node);
}
}
UpdateSourceAndSink();
}
};
using NodeIdPair = std::pair<int64_t, int64_t>;
void DoOneTestGraph(const TestGraph& test_graph,
const std::vector<std::vector<int64_t>>& graph_conf) {
int64_t node_num = graph_conf.size();
  // 1. Check whether the traversal results satisfy the topological order
HashMap<int64_t, int64_t> node_id2order, node_id2rorder;
auto NodePairHash = [](const NodeIdPair& val) {
return val.first ^ val.second;
};
std::unordered_set<NodeIdPair, decltype(NodePairHash)> edges_node_pair(
11, NodePairHash);
int64_t order = 0;
test_graph.ConstTopoForEachNode([&](const TestNode* node) {
node_id2order.emplace(node->test_node_id(), order);
++order;
});
ASSERT_EQ(node_id2order.size(), node_num);
order = 0;
test_graph.ConstReverseTopoForEachNode([&](const TestNode* node) {
node_id2rorder.emplace(node->test_node_id(), order);
++order;
});
ASSERT_EQ(node_id2rorder.size(), node_num);
  // Method:
  //    for every directed edge <u, v>, check that
  //    node u's order is smaller than node v's
int64_t edge_num = 0;
for (int64_t src_node_id = 0; src_node_id < node_num; ++src_node_id) {
for (int64_t dst_node_id : graph_conf[src_node_id]) {
// check topo order
int64_t src_ord = node_id2order.at(src_node_id);
int64_t dst_ord = node_id2order.at(dst_node_id);
ASSERT_LT(src_ord, dst_ord);
// check reverse-topo order
src_ord = node_id2rorder.at(src_node_id);
dst_ord = node_id2rorder.at(dst_node_id);
ASSERT_GE(src_ord, dst_ord);
//
++edge_num;
edges_node_pair.insert(std::make_pair(src_node_id, dst_node_id));
}
}
  // 2. Check whether the Graph getters return all of the nodes and edges
ASSERT_EQ(test_graph.node_num(), node_num);
ASSERT_EQ(test_graph.edge_num(), edge_num);
std::unordered_set<int64_t> node_ids;
test_graph.ConstForEachNode([&](const TestNode* cur_node) {
int64_t cur_node_id = cur_node->test_node_id();
ASSERT_TRUE(node_ids.insert(cur_node_id).second);
ASSERT_LT(cur_node_id, node_num);
ASSERT_GE(cur_node_id, 0);
});
test_graph.ConstForEachEdge([&](const TestEdge* cur_edge) {
int64_t src_node_id = cur_edge->src_node()->test_node_id();
int64_t dst_node_id = cur_edge->dst_node()->test_node_id();
ASSERT_TRUE(edges_node_pair.count(std::make_pair(src_node_id, dst_node_id))
> 0);
});
}
TEST(TestGraph, test_graph_node_num_7) {
std::vector<std::vector<int64_t>> graph_conf;
for (int64_t i = 0; i < 7; ++i) {
graph_conf.push_back(std::vector<int64_t>());
}
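  // Directed edges added below: 2->1, 2->0, 2->3, 1->0, 0->4 and 5->6, so nodes 5
  // and 6 form a component disconnected from nodes 0..4.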
graph_conf[2].push_back(1);
graph_conf[2].push_back(0);
graph_conf[2].push_back(3);
graph_conf[1].push_back(0);
graph_conf[0].push_back(4);
graph_conf[5].push_back(6);
TestGraph test_graph(graph_conf);
DoOneTestGraph(test_graph, graph_conf);
}
} // namespace oneflow
#include "oneflow/core/graph/in_boxing_task_node.h"
namespace oneflow {
void InBoxingTaskNode::FwVirtualBuild() {
Chain2EdgesMap chain2sorted_in_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_in_edges, &TaskNode::in_edges,
&TaskEdge::src_node, &TaskNode::SoleInEdge);
ChainEdgesPair chain_sorted_out_edges;
chain_sorted_out_edges.first = chain_node();
chain_sorted_out_edges.second.assign(out_edges().begin(), out_edges().end());
FwSortEdgesInnerStage(&chain_sorted_out_edges.second, &TaskEdge::dst_node,
&TaskNode::SoleOutEdge);
for (const ChainEdgesPair& chain_sorted_in_edges : chain2sorted_in_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
}
mut_exec_gph().UpdateSourceAndSink();
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/boxing_task_node.h"
namespace oneflow {
class InBoxingTaskNode final : public BoxingTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(InBoxingTaskNode);
InBoxingTaskNode() = default;
~InBoxingTaskNode() = default;
private:
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<InBoxingTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node);
}
void FwVirtualBuild() override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/logical_graph.h"
#include <iostream>
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
const Placement& placement) {
const LogicalNode* LogicalGraph::GetProducerNode(const std::string& lbn) {
return lbn2producer_.at(lbn);
}
LogicalGraph::LogicalGraph() {
HashMap<LogicalEdge*, std::string> edge2lbn;
HashMap<LogicalEdge*, std::string> edge2ibn;
NaiveBuildGraphStruct(dl_net_conf, &edge2lbn, &edge2ibn);
FillNodeWithParallelDesc(placement);
NaiveBuildGraphStruct(&edge2lbn, &edge2ibn);
FillNodeWithParallelDesc();
AddCloneNodes(edge2lbn, edge2ibn);
ForEachNode([&](LogicalNode* node) {
for (const std::string& obn : node->op()->output_bns()) {
const std::string& lbn = node->op()->Lbn4BnInOp(obn);
CHECK(lbn2producer_.emplace(lbn, node).second);
}
});
ToDotWithAutoFilePath();
}
void LogicalGraph::NaiveBuildGraphStruct(
const DLNetConf& dl_net_conf, HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2ibn) {
const DLNetConf& dlnet_conf = JobDesc::Singleton()->dlnet_conf();
HashMap<std::string, LogicalNode*> lbn2producer;
// Process Op
for (int op_i = 0; op_i < dl_net_conf.op_size(); ++op_i) {
const OperatorConf& cur_op_conf = dl_net_conf.op(op_i);
// Construct cur node
for (const OperatorConf& cur_op_conf : dlnet_conf.op()) {
LogicalNode* cur_node = NewNode();
cur_node->mut_op() = OpMgr::Singleton()->AddOp(cur_op_conf);
// Connect input node
for (const std::string& obn : cur_node->op()->output_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(obn);
CHECK(lbn2producer.emplace(lbn, cur_node).second);
}
}
ForEachNode([&](LogicalNode* cur_node) {
for (const std::string& ibn : cur_node->op()->input_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(ibn);
LogicalNode* pred_node = lbn2producer.at(lbn);
......@@ -33,34 +44,33 @@ void LogicalGraph::NaiveBuildGraphStruct(
CHECK(edge2ibn->emplace(edge, ibn).second);
Connect(pred_node, edge, cur_node);
}
// Construct output
for (const std::string& obn : cur_node->op()->output_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(obn);
CHECK(lbn2producer.emplace(lbn, cur_node).second);
}
}
lbn2producer.clear();
// Post Processing
UpdateSourceAndSink();
});
}
void LogicalGraph::FillNodeWithParallelDesc(const Placement& placement) {
void LogicalGraph::FillNodeWithParallelDesc() {
const Placement& placement = JobDesc::Singleton()->placement();
HashMap<std::string, LogicalNode*> op_name2node;
ForEachNode([&](LogicalNode* logical_node) {
const std::string& op_name = logical_node->op()->op_name();
CHECK(op_name2node.emplace(op_name, logical_node).second);
});
for (int gid = 0; gid < placement.placement_group_size(); ++gid) {
const PlacementGroup& cur_group = placement.placement_group(gid);
for (int li = 0; li < cur_group.op_set().op_name_size(); ++li) {
const std::string& op_name = cur_group.op_set().op_name(li);
auto it = op_name2node.find(op_name);
CHECK(it != op_name2node.end());
for (const PlacementGroup& cur_group : placement.placement_group()) {
for (const std::string& op_name : cur_group.op_set().op_name()) {
LogicalNode* node = op_name2node.at(op_name);
auto parallel_desc_raw_ptr = new ParallelDesc(cur_group.parallel_conf());
it->second->op()->FixParallelDesc(parallel_desc_raw_ptr);
it->second->mut_parallel_desc().reset(parallel_desc_raw_ptr);
node->op()->FixParallelDesc(parallel_desc_raw_ptr);
node->mut_parallel_desc().reset(parallel_desc_raw_ptr);
}
}
ForEachNode([&](LogicalNode* cur_node) {
if (cur_node->op()->IsElemWiseOp()) {
LogicalNode* pred_node = cur_node;
while (pred_node->op()->IsElemWiseOp()) {
pred_node = pred_node->SoleInEdge()->src_node();
}
cur_node->mut_parallel_desc() = pred_node->parallel_desc();
}
});
}
void LogicalGraph::AddCloneNodes(
......@@ -104,6 +114,7 @@ void LogicalGraph::CollectCloneInfos(
void LogicalGraph::AddOneCloneNode(
const CloneInfo& clone_info,
const HashMap<LogicalEdge*, std::string>& edge2ibn) {
if (clone_info.pred_node->op()->IsDataLoaderOp()) { return; }
LogicalNode* clone_node = NewNode();
clone_node->mut_op() = clone_info.clone_op;
clone_node->mut_parallel_desc() = clone_info.pred_node->parallel_desc();
......
......@@ -17,7 +17,7 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
LogicalNode() = default;
~LogicalNode() = default;
std::shared_ptr<Operator> op() const { return op_; }
std::shared_ptr<const Operator> op() const { return op_; }
std::shared_ptr<Operator>& mut_op() { return op_; }
std::shared_ptr<const ParallelDesc> parallel_desc() const {
......@@ -27,9 +27,6 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
return parallel_desc_;
}
bool IsLossNode() const { return op_->IsLossOp(); }
bool IsChainMergeable() const { return op_->IsChainMergeable(); }
std::string VisualStr() const override { return op_->op_name(); }
private:
......@@ -49,18 +46,18 @@ class LogicalEdge final : public Edge<LogicalNode, LogicalEdge> {
class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(LogicalGraph);
LogicalGraph() = delete;
~LogicalGraph() = default;
LogicalGraph(const DLNetConf& dl_net_conf, const Placement& placement);
OF_SINGLETON(LogicalGraph);
const char* TypeName() const override { return "LogicalGraph"; }
const LogicalNode* GetProducerNode(const std::string& lbn);
private:
void NaiveBuildGraphStruct(const DLNetConf& dl_net_conf,
HashMap<LogicalEdge*, std::string>* edge2lbn,
LogicalGraph();
void NaiveBuildGraphStruct(HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2ibn);
void FillNodeWithParallelDesc(const Placement& placement);
void FillNodeWithParallelDesc();
struct CloneInfo {
std::shared_ptr<Operator> clone_op;
......@@ -73,6 +70,8 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
const HashMap<LogicalEdge*, std::string>& edge2lbn);
void AddOneCloneNode(const CloneInfo& clone_info,
const HashMap<LogicalEdge*, std::string>& edge2ibn);
HashMap<std::string, const LogicalNode*> lbn2producer_;
};
} // namespace oneflow
......
#include "oneflow/core/graph/loss_accumulate_comp_task_node.h"
#include "oneflow/core/graph/loss_accumulate_task_graph.h"
namespace oneflow {
void LossAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
if (chain_node()->op_vec().empty()) {
CompTaskNode* loss_task = static_cast<LossAccTaskGraph*>(gph)->loss_task();
auto loss_regst = loss_task->GetProducedRegstDesc("loss");
BindProducedRegstAndOutEdge(loss_regst, SoleOutEdge());
return;
}
NewProducedRegstDesc("loss_acc", 1, kMaxRegisterNum);
auto loss_regst = GetRelatedRegst(SoleInEdge());
auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = chain_node()->SoleOp();
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleIbn(), loss_regst);
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(), loss_acc_regst);
ConsumeRegstDesc("loss", loss_regst);
loss_acc_regst->CopyLbnFrom(loss_regst.get());
mut_exec_gph().UpdateSourceAndSink();
}
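// When the chain holds no op, the node is a pass-through: it just re-exports the
// loss regst produced by the original loss task on its out edge; otherwise a single
// accumulate exec node is built that reads the consumed loss regst and writes the
// newly produced loss_acc regst.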
void LossAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
if (!chain_node()->op_vec().empty()) {
auto loss_regst = GetConsumedRegstDesc("loss");
auto loss_acc_regst = GetProducedRegstDesc("loss_acc");
loss_acc_regst->CopyBlobDescFrom(loss_regst.get());
}
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
namespace oneflow {
class LossAccCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossAccCompTaskNode);
LossAccCompTaskNode() = default;
~LossAccCompTaskNode() = default;
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kLossAccCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<LossAccCompTaskNode>();
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMP_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/accumulate_compute_task_node.h"
namespace oneflow {
class LossAccCompTaskNode final : public AccCompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossAccCompTaskNode);
LossAccCompTaskNode() = default;
~LossAccCompTaskNode() = default;
TodoTaskType GetTaskType() const override { return TodoTaskType::kLossAcc; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/loss_accumulate_task_graph.h"
#include "oneflow/core/graph/loss_accumulate_comp_task_node.h"
namespace oneflow {
LossAccTaskGraph::LossAccTaskGraph(const std::string& name,
CompTaskNode* loss_task) {
mut_name() = name;
loss_task_ = loss_task;
BuildTaskGraph();
BuildExecAndEnrollLbn2Regsts();
}
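// BuildTaskGraph assembles a two-chain graph on the loss task's device: a "faker"
// chain with an empty op_vec stands in for the already-built loss task and exposes
// kPackedBlobName, and a loss-acc chain holding one accumulate op consumes it; the
// shared parallel conf is data-parallel over that single device.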
void LossAccTaskGraph::BuildTaskGraph() {
// loss acc op
OperatorConf op_conf;
op_conf.set_name("loss_acc_" + NewUniqueId());
op_conf.mutable_accumulate_conf();
auto loss_acc_op = OpMgr::Singleton()->AddOp(op_conf);
// parallel_desc
ParallelConf pr_conf;
pr_conf.set_policy(kDataParallel);
pr_conf.add_device_name(loss_task_->device_name());
auto pr_desc = std::make_shared<ParallelDesc>(pr_conf);
// faker chain
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* faker_chain = chain_gph->NewNode();
faker_chain->mut_op_vec() = {};
faker_chain->mut_parallel_desc() = pr_desc;
faker_chain->mut_output_lbns() = {kPackedBlobName};
// loss acc chain
ChainNode* loss_acc_chain = chain_gph->NewNode();
loss_acc_chain->mut_op_vec() = {loss_acc_op};
loss_acc_chain->mut_parallel_desc() = pr_desc;
loss_acc_chain->mut_input_lbns() = {kPackedBlobName};
//
Connect(faker_chain, chain_gph->NewEdge(), loss_acc_chain);
chain_gph->UpdateSourceAndSink();
chain_gph->ToDotWithAutoFilePath();
BuildFromChainGph<LossAccCompTaskNode>(std::move(chain_gph), false);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class LossAccTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(LossAccTaskGraph);
LossAccTaskGraph() = delete;
~LossAccTaskGraph() = default;
LossAccTaskGraph(const std::string& name, CompTaskNode* loss_task);
const char* TypeName() const override { return "LossAccTaskGraph"; }
CompTaskNode* loss_task() { return loss_task_; }
private:
void BuildTaskGraph();
CompTaskNode* loss_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_ACCUMULATE_TASK_GRAPH_H_
#include "oneflow/core/graph/loss_compute_task_node.h"
namespace oneflow {
void LossCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("loss", 1, kMaxRegisterNum);
ProduceRegst("in_diff", 1, kMaxRegisterNum);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_LOSS_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class LossCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossCompTaskNode);
LossCompTaskNode() = default;
~LossCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override { return TodoTaskType::kLoss; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/loss_record_comp_task_node.h"
#include "oneflow/core/graph/loss_record_task_graph.h"
namespace oneflow {
void LossRecordCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
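  // faker node: just forward the "loss_acc" regst produced by the
  // loss-acc task with the same parallel_id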
if (chain_node()->op_vec().empty()) {
auto loss_record_gph = static_cast<LossRecordTaskGraph*>(gph);
CompTaskNode* loss_acc_task =
loss_record_gph->GetLossAccCompTaskNodeFromParallelId(parallel_id());
auto loss_acc_regst = loss_acc_task->GetProducedRegstDesc("loss_acc");
BindProducedRegstAndOutEdge(loss_acc_regst, SoleOutEdge());
return;
}
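  // loss record node: consume the incoming "loss_acc" regst with the sole
  // loss_record op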
auto loss_acc_regst = GetRelatedRegst(SoleInEdge());
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = chain_node()->SoleOp();
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleIbn(), loss_acc_regst);
ConsumeRegstDesc("loss_acc", loss_acc_regst);
mut_exec_gph().UpdateSourceAndSink();
}
void LossRecordCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMP_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
namespace oneflow {
class LossRecordCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossRecordCompTaskNode);
LossRecordCompTaskNode() = default;
~LossRecordCompTaskNode() = default;
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
bool IsMeaningLess() const override {
return !GetConsumedRegstDesc("loss_acc");
}
TaskType task_type() const override { return kLossRecordCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<LossRecordCompTaskNode>();
}
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMP_TASK_NODE_H_
#include "oneflow/core/graph/loss_record_compute_task_node.h"
namespace oneflow {
void LossRecordCompTaskNode::ProduceAllRegstsAndBindEdges() {}
void LossRecordCompTaskNode::FixThrdLocId() {
set_thrd_loc_id(IDMgr::Singleton()->PersistenceThrdLocId());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class LossRecordCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(LossRecordCompTaskNode);
LossRecordCompTaskNode() = default;
~LossRecordCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override {
return TodoTaskType::kLossRecord;
}
  void FixThrdLocId() override;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_RECORD_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/loss_record_task_graph.h"
#include "oneflow/core/graph/loss_record_comp_task_node.h"
namespace oneflow {
LossRecordTaskGraph::LossRecordTaskGraph(
const std::string& name,
const std::vector<TaskNode*>& sorted_loss_acc_task) {
mut_name() = name;
BuildTaskGraph(sorted_loss_acc_task);
BuildExecAndEnrollLbn2Regsts();
}
void LossRecordTaskGraph::BuildTaskGraph(
const std::vector<TaskNode*>& sorted_loss_acc_task) {
// faker_pr_conf
ParallelConf faker_pr_conf;
faker_pr_conf.set_policy(kFakerLossRecord);
for (TaskNode* task : sorted_loss_acc_task) {
auto loss_acc_task = static_cast<CompTaskNode*>(task);
faker_pr_conf.add_device_name(loss_acc_task->device_name());
sorted_loss_acc_tasks_.push_back(loss_acc_task);
}
// faker chain
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* faker_chain = chain_gph->NewNode();
faker_chain->mut_op_vec() = {};
faker_chain->mut_parallel_desc().reset(new ParallelDesc(faker_pr_conf));
faker_chain->mut_output_lbns() = {kPackedBlobName};
// loss_record_pr_conf
ParallelConf loss_record_pr_conf;
loss_record_pr_conf.set_policy(kDataParallel);
loss_record_pr_conf.add_device_name(
IDMgr::Singleton()->MachineName4MachineId(0) + ":persistence");
// loss record op
OperatorConf op_conf;
op_conf.set_name("loss_record_" + NewUniqueId());
op_conf.mutable_loss_record_conf();
auto loss_record_op = OpMgr::Singleton()->AddOp(op_conf);
// loss record chain
ChainNode* loss_record_chain = chain_gph->NewNode();
loss_record_chain->mut_op_vec() = {loss_record_op};
loss_record_chain->mut_parallel_desc().reset(
new ParallelDesc(loss_record_pr_conf));
loss_record_chain->mut_input_lbns() = {kPackedBlobName};
//
Connect(faker_chain, chain_gph->NewEdge(), loss_record_chain);
chain_gph->UpdateSourceAndSink();
chain_gph->ToDotWithAutoFilePath();
BuildFromChainGph<LossRecordCompTaskNode>(std::move(chain_gph), false);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_LOSS_RECORD_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_LOSS_RECORD_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class LossRecordTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(LossRecordTaskGraph);
LossRecordTaskGraph() = delete;
~LossRecordTaskGraph() = default;
LossRecordTaskGraph(const std::string& name,
const std::vector<TaskNode*>& sorted_loss_acc_task);
const char* TypeName() const override { return "LossRecordTaskGraph"; }
CompTaskNode* GetLossAccCompTaskNodeFromParallelId(int64_t parallel_id) {
return sorted_loss_acc_tasks_.at(parallel_id);
}
private:
void BuildTaskGraph(const std::vector<TaskNode*>& sorted_loss_acc_task);
std::vector<CompTaskNode*> sorted_loss_acc_tasks_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOSS_RECORD_TASK_GRAPH_H_
#include "oneflow/core/graph/model_diff_accumulate_comp_task_node.h"
#include "oneflow/core/graph/model_diff_accumulate_task_graph.h"
namespace oneflow {
void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_diff_acc_gph = static_cast<MdDiffAccTaskGraph*>(gph);
fw_task_ = md_diff_acc_gph->GetFwTaskFromParallelId(parallel_id());
TaskNode* bp_task = fw_task_->GetBpNode();
std::shared_ptr<RegstDesc> model_diff_regst =
bp_task->GetProducedRegstDesc("model_diff");
// faker task node
if (chain_node()->op_vec().empty()) {
BindProducedRegstAndOutEdge(model_diff_regst, SoleOutEdge());
return;
}
// comp task node
NewProducedRegstDesc("model_diff_acc", 1, kMaxRegisterNum);
auto model_diff_acc_regst = GetProducedRegstDesc("model_diff_acc");
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = chain_node()->SoleOp();
if (in_edges().empty()) {
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleIbn(), model_diff_regst);
ConsumeRegstDesc("model_diff", model_diff_regst);
} else {
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleIbn(),
GetRelatedRegst(SoleInEdge()));
ConsumeRegstDesc("model_diff", GetRelatedRegst(SoleInEdge()));
}
model_diff_acc_regst->CopyLbnFrom(GetConsumedRegstDesc("model_diff").get());
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(),
model_diff_acc_regst);
mut_exec_gph().UpdateSourceAndSink();
}
void MdDiffAccCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
if (!chain_node()->op_vec().empty()) {
std::shared_ptr<RegstDesc> in_regst = GetConsumedRegstDesc("model_diff");
std::shared_ptr<RegstDesc> out_regst =
GetProducedRegstDesc("model_diff_acc");
out_regst->CopyBlobDescFrom(in_regst.get());
}
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
namespace oneflow {
class MdDiffAccCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdDiffAccCompTaskNode);
MdDiffAccCompTaskNode() = default;
~MdDiffAccCompTaskNode() = default;
void ToProto(TaskProto* proto, std::function<int64_t(const ChainNode*)>
MeaninglessTaskCnt4Chain) const override {
TaskNode::ToProto(proto, MeaninglessTaskCnt4Chain);
fw_task_->FillProtoWithParallelInfo(proto, MeaninglessTaskCnt4Chain);
}
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kMdDiffAccCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdDiffAccCompTaskNode>();
}
CompTaskNode* fw_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
#ifndef ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/accumulate_compute_task_node.h"
namespace oneflow {
class MdDiffAccCompTaskNode final : public AccCompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdDiffAccCompTaskNode);
MdDiffAccCompTaskNode() = default;
~MdDiffAccCompTaskNode() = default;
TodoTaskType GetTaskType() const override { return TodoTaskType::kMdDiffAcc; }
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/model_diff_accumulate_task_graph.h"
#include "oneflow/core/graph/model_diff_accumulate_comp_task_node.h"
namespace oneflow {
MdDiffAccTaskGraph::MdDiffAccTaskGraph(
const std::string& name, const ChainNode* data_chain,
const std::vector<CompTaskNode*>& sorted_fw_comptasks4data_chain) {
mut_name() = name;
BuildTaskGraph(data_chain);
for (CompTaskNode* fw_task : sorted_fw_comptasks4data_chain) {
CHECK(parallel_id2fw_task_.emplace(fw_task->parallel_id(), fw_task).second);
}
BuildExecAndEnrollLbn2Regsts();
}
void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
// Construct ModelDiffAccOp
OperatorConf op_conf;
op_conf.set_name("model_diff_acc_" + NewUniqueId());
op_conf.mutable_accumulate_conf();
auto model_diff_acc_op = OpMgr::Singleton()->AddOp(op_conf);
// ModelDiffAccChain
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* diff_acc_chain = chain_gph->NewNode();
diff_acc_chain->mut_op_vec() = {model_diff_acc_op};
auto parallel_desc4diff_acc =
new ParallelDesc(*(data_chain->parallel_desc()));
parallel_desc4diff_acc->mut_policy() = kModelParallel;
diff_acc_chain->mut_parallel_desc().reset(parallel_desc4diff_acc);
// FakerChain
if (data_chain->parallel_desc()->policy() == kDataParallel) {
ChainNode* faker_chain = chain_gph->NewNode();
faker_chain->mut_op_vec().clear();
auto parallel_desc4faker = new ParallelDesc(*(data_chain->parallel_desc()));
parallel_desc4faker->mut_policy() = kFakerMdUpdt;
faker_chain->mut_parallel_desc().reset(parallel_desc4faker);
faker_chain->mut_output_lbns() = {kPackedBlobName};
diff_acc_chain->mut_input_lbns() = {kPackedBlobName};
Connect(faker_chain, chain_gph->NewEdge(), diff_acc_chain);
}
//
chain_gph->UpdateSourceAndSink();
chain_gph->ToDotWithAutoFilePath();
BuildFromChainGph<MdDiffAccCompTaskNode>(std::move(chain_gph), false);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class MdDiffAccTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(MdDiffAccTaskGraph);
MdDiffAccTaskGraph() = delete;
~MdDiffAccTaskGraph() = default;
MdDiffAccTaskGraph(
const std::string& name, const ChainNode* data_chain,
const std::vector<CompTaskNode*>& sorted_fw_comptasks4data_chain);
CompTaskNode* GetFwTaskFromParallelId(int64_t parallel_id) const {
return parallel_id2fw_task_.at(parallel_id);
}
const char* TypeName() const override { return "MdDiffAccTaskGraph"; }
private:
void BuildTaskGraph(const ChainNode* data_chain);
HashMap<int64_t, CompTaskNode*> parallel_id2fw_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
namespace oneflow {
void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_save_gph = static_cast<MdSaveTaskGraph*>(gph);
CompTaskNode* updt_task = md_save_gph->update_task();
if (in_edges().empty()) {
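    // source (faker) node: expose the "model" regst produced by the update
    // task on the out edge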
BindProducedRegstAndOutEdge(updt_task->GetProducedRegstDesc("model"),
SoleOutEdge());
} else if (out_edges().empty()) {
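    // sink node: build a model_save op that consumes every lbn in the
    // incoming "model" regst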
ConsumeRegstDesc("model", GetRelatedRegst(SoleInEdge()));
OperatorConf op_conf;
op_conf.set_name("model_save_op" + updt_task->node_id_str());
op_conf.mutable_model_save_conf();
GetRelatedRegst(SoleInEdge())->ForEachLbn([&](const std::string& lbn) {
op_conf.mutable_model_save_conf()->add_lbns(lbn);
});
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = OpMgr::Singleton()->AddOp(op_conf);
for (const std::string& ibn : exec_node->op()->input_bns()) {
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
}
mut_exec_gph().UpdateSourceAndSink();
} else {
UNEXPECTED_RUN();
}
}
void MdSaveCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/comp_task_node.h"
namespace oneflow {
class MdSaveCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdSaveCompTaskNode);
MdSaveCompTaskNode() = default;
~MdSaveCompTaskNode() = default;
void ToProto(TaskProto* proto, std::function<int64_t(const ChainNode*)>
MeaninglessTaskCnt4Chain) const override {
TaskNode::ToProto(proto, MeaninglessTaskCnt4Chain);
fw_task_->FillProtoWithParallelInfo(proto, MeaninglessTaskCnt4Chain);
}
void set_fw_task(CompTaskNode* fw_task) { fw_task_ = fw_task; }
CompTaskNode* fw_task() { return fw_task_; }
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
bool IsMeaningLess() const override { return !GetConsumedRegstDesc("model"); }
TaskType task_type() const override { return kMdSaveCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdSaveCompTaskNode>();
}
CompTaskNode* fw_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/model_save_compute_task_node.h"
namespace oneflow {
void MdSaveCompTaskNode::ProduceAllRegstsAndBindEdges() {}
void MdSaveCompTaskNode::FixThrdLocId() {
set_thrd_loc_id(IDMgr::Singleton()->PersistenceThrdLocId());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class MdSaveCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdSaveCompTaskNode);
MdSaveCompTaskNode() = default;
~MdSaveCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override { return TodoTaskType::kMdSave; }
void FixThrdLocId() override;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
namespace oneflow {
MdSaveTaskGraph::MdSaveTaskGraph(const std::string& name,
CompTaskNode* update_task) {
mut_name() = name;
update_task_ = update_task;
BuildTaskGraph();
BuildExecAndEnrollLbn2Regsts();
}
void MdSaveTaskGraph::BuildTaskGraph() {
auto chain_gph = of_make_unique<ChainGraph>();
// faker
ChainNode* faker_chain = chain_gph->NewNode();
ParallelConf faker_pr_conf;
faker_pr_conf.set_policy(kDataParallel);
faker_pr_conf.add_device_name(update_task_->device_name());
faker_chain->mut_parallel_desc().reset(new ParallelDesc(faker_pr_conf));
faker_chain->mut_output_lbns() = {kPackedBlobName};
// save
ChainNode* save_chain = chain_gph->NewNode();
std::string machine_name =
GetMachineNameFromDeviceName(update_task_->device_name());
ParallelConf save_pr_conf;
save_pr_conf.set_policy(kDataParallel);
save_pr_conf.add_device_name(machine_name + ":persistence");
save_chain->mut_parallel_desc().reset(new ParallelDesc(save_pr_conf));
save_chain->mut_input_lbns() = {kPackedBlobName};
//
Connect(faker_chain, chain_gph->NewEdge(), save_chain);
chain_gph->UpdateSourceAndSink();
chain_gph->ToDotWithAutoFilePath();
BuildFromChainGph<MdSaveCompTaskNode>(std::move(chain_gph), false);
ForEachNode([this](TaskNode* node) {
auto model_save_comp_task_node = dynamic_cast<MdSaveCompTaskNode*>(node);
if (model_save_comp_task_node != nullptr) {
auto model_update_comp_task_node =
static_cast<MdUpdtCompTaskNode*>(update_task_);
model_save_comp_task_node->set_fw_task(
model_update_comp_task_node->fw_task());
}
});
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class MdSaveTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(MdSaveTaskGraph);
MdSaveTaskGraph() = delete;
~MdSaveTaskGraph() = default;
MdSaveTaskGraph(const std::string& name, CompTaskNode* update_task);
CompTaskNode* update_task() const { return update_task_; }
const char* TypeName() const override { return "MdSaveTaskGraph"; }
private:
void BuildTaskGraph();
CompTaskNode* update_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
#include "oneflow/core/graph/model_update_comp_task_node.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_updt_gph = static_cast<MdUpdtTaskGraph*>(gph);
CompTaskNode* diff_acc_task = md_updt_gph->diff_acc_task();
std::shared_ptr<RegstDesc> model_diff_acc_regst;
if (diff_acc_task != nullptr) {
model_diff_acc_regst =
diff_acc_task->GetProducedRegstDesc("model_diff_acc");
}
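  // the update task takes over the "model" and "model_tmp" regsts that were
  // originally produced by the forward task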
TakeOverRegstDesc(fw_task_, "model");
TakeOverRegstDesc(fw_task_, "model_tmp");
auto model_regst = GetProducedRegstDesc("model");
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = chain_node()->SoleOp();
const std::string ibn = "model_diffs";
if (model_diff_acc_regst != nullptr) {
exec_node->BindBnInOpAndRegst(ibn, model_diff_acc_regst);
ConsumeRegstDesc(ibn, model_diff_acc_regst);
}
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(), model_regst);
auto data_tmp_regst = NewProducedRegstDesc("data_tmp", 1);
for (const std::string& dtbn : exec_node->op()->data_tmp_bns()) {
const std::string& lbn = exec_node->op()->Lbn4BnInOp(dtbn);
data_tmp_regst->EnrollLbn(lbn);
exec_node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
}
mut_exec_gph().UpdateSourceAndSink();
}
void MdUpdtCompTaskNode::InferBlobDescInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
ExecNode* exec_node = exec_gph().SoleNode();
auto model_diffs_regst = GetConsumedRegstDesc("model_diffs");
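  // the whole set of model diffs is treated as one packed blob; in predict
  // mode there is no diff regst, so an empty blob desc is used instead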
BlobDesc packed_blob_desc;
if (model_diffs_regst) {
packed_blob_desc = model_diffs_regst->CompPackedBlobDesc();
} else {
CHECK(JobDesc::Singleton()->is_predict());
packed_blob_desc =
BlobDesc(Shape(), JobDesc::Singleton()->default_data_type(), false);
}
exec_node->op()->InferBlobDesc4FwBlobs(
[&](const std::string& bn_in_op) -> BlobDesc* {
if (bn_in_op == "model_diffs") {
return &packed_blob_desc;
} else {
return exec_node->GetBlobDesc4BnInOpFunc()(bn_in_op);
}
},
kDataParallel, 0, 0);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/data_comp_task_node.h"
namespace oneflow {
class MdUpdtCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdUpdtCompTaskNode);
MdUpdtCompTaskNode() = default;
~MdUpdtCompTaskNode() = default;
void ToProto(TaskProto* proto, std::function<int64_t(const ChainNode*)>
MeaninglessTaskCnt4Chain) const override {
TaskNode::ToProto(proto, MeaninglessTaskCnt4Chain);
fw_task_->FillProtoWithParallelInfo(proto, MeaninglessTaskCnt4Chain);
int64_t related_save_task_id = -1;
for (const auto& pair : produced_regst_descs()) {
for (const TaskNode* consumer : pair.second->consumers()) {
if (dynamic_cast<const DataCompTaskNode*>(consumer) == nullptr) {
CHECK_EQ(related_save_task_id, -1);
related_save_task_id = consumer->task_id();
}
}
}
proto->set_related_save_task_id(related_save_task_id);
proto->set_random_seed(random_seed_);
}
void set_fw_task(CompTaskNode* fw_task) { fw_task_ = fw_task; }
CompTaskNode* fw_task() { return fw_task_; }
void set_random_seed(uint32_t val) { random_seed_ = val; }
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferBlobDescInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override { return kMdUpdtCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdUpdtCompTaskNode>();
}
CompTaskNode* fw_task_;
uint32_t random_seed_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
#include "oneflow/core/graph/model_update_compute_task_node.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
void MdUpdtCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("model_tmp", 1, 1);
ProduceRegst("model", 3, kMaxRegisterNum);
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class MdUpdtCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(MdUpdtCompTaskNode);
MdUpdtCompTaskNode() = default;
~MdUpdtCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
void set_random_seed(uint32_t val) { random_seed_ = val; }
TodoTaskType GetTaskType() const override { return TodoTaskType::kMdUpdt; }
private:
uint32_t random_seed_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
namespace oneflow {
MdUpdtTaskGraph::MdUpdtTaskGraph(const std::string& name, CompTaskNode* fw_task,
CompTaskNode* diff_acc_task,
uint32_t random_seed) {
mut_name() = name;
fw_task_ = fw_task;
diff_acc_task_ = diff_acc_task;
BuildTaskGraph(random_seed);
BuildExecAndEnrollLbn2Regsts();
}
void MdUpdtTaskGraph::BuildTaskGraph(uint32_t random_seed) {
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* updt_chain = chain_gph->NewNode();
ParallelConf updt_pr_conf;
updt_pr_conf.set_policy(kDataParallel);
updt_pr_conf.add_device_name(fw_task_->device_name());
updt_chain->mut_parallel_desc().reset(new ParallelDesc(updt_pr_conf));
updt_chain->mut_input_lbns() = {kPackedBlobName};
updt_chain->mut_op_vec() = {OpMgr::Singleton()->ModelUpdateOp()};
chain_gph->UpdateSourceAndSink();
chain_gph->ToDotWithAutoFilePath();
BuildFromChainGph<MdUpdtCompTaskNode>(std::move(chain_gph), false);
ForEachNode([this, random_seed](TaskNode* node) {
auto model_updt_comp_task_node = dynamic_cast<MdUpdtCompTaskNode*>(node);
if (model_updt_comp_task_node == nullptr) { return; }
model_updt_comp_task_node->set_fw_task(fw_task_);
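    // data-parallel update tasks reuse the seed passed in (presumably so the
    // replicas initialize an identical model); model-parallel tasks draw a
    // fresh seed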
ParallelPolicy this_policy =
fw_task_->chain_node()->parallel_desc()->policy();
if (this_policy == kDataParallel) {
model_updt_comp_task_node->set_random_seed(random_seed);
} else if (this_policy == kModelParallel) {
model_updt_comp_task_node->set_random_seed(NewRandomSeed());
} else {
UNEXPECTED_RUN();
}
});
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
#include "oneflow/core/graph/task_graph.h"
namespace oneflow {
class MdUpdtTaskGraph final : public TaskGraph {
public:
OF_DISALLOW_COPY_AND_MOVE(MdUpdtTaskGraph);
MdUpdtTaskGraph() = delete;
~MdUpdtTaskGraph() = default;
MdUpdtTaskGraph(const std::string& name, CompTaskNode* fw_task,
CompTaskNode* diff_acc_task, uint32_t random_seed);
CompTaskNode* fw_task() const { return fw_task_; }
CompTaskNode* diff_acc_task() const { return diff_acc_task_; }
const char* TypeName() const override { return "MdUpdtTaskGraph"; }
private:
void BuildTaskGraph(uint32_t random_seed);
CompTaskNode* fw_task_;
CompTaskNode* diff_acc_task_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
......@@ -40,7 +40,6 @@ class Edge {
virtual ~Edge() = default;
int64_t edge_id() const { return edge_id_; }
std::string edge_id_str() const { return std::to_string(edge_id_); }
NodeType* src_node() const { return src_node_; }
NodeType* dst_node() const { return dst_node_; }
......@@ -67,6 +66,7 @@ class Node {
int64_t node_id() const { return node_id_; }
std::string node_id_str() const { return std::to_string(node_id_); }
EdgeType* SoleInEdge() const {
CHECK_EQ(in_edges_.size(), 1);
return *(in_edges_.begin());
......@@ -79,6 +79,17 @@ class Node {
const std::unordered_set<EdgeType*>& in_edges() const { return in_edges_; }
const std::unordered_set<EdgeType*>& out_edges() const { return out_edges_; }
void ForEachNodeOnInEdge(std::function<void(NodeType*)> Handler) const {
for (EdgeType* edge : in_edges_) { Handler(edge->src_node()); }
}
void ForEachNodeOnOutEdge(std::function<void(NodeType*)> Handler) const {
for (EdgeType* edge : out_edges_) { Handler(edge->dst_node()); }
}
void ForEachNodeOnInOutEdge(std::function<void(NodeType*)> Handler) const {
ForEachNodeOnInEdge(Handler);
ForEachNodeOnOutEdge(Handler);
}
void DisconnectAllEdges() {
for (EdgeType* edge : in_edges_) { DisConnect(edge); }
for (EdgeType* edge : out_edges_) { DisConnect(edge); }
......
#include "oneflow/core/graph/out_boxing_task_node.h"
namespace oneflow {
void OutBoxingTaskNode::FwVirtualBuild() {
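  // group the out edges by the chain of their destination tasks, then build
  // one boxing path from the sorted in edges to each destination chain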
Chain2EdgesMap chain2sorted_out_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_out_edges, &TaskNode::out_edges,
&TaskEdge::dst_node, &TaskNode::SoleOutEdge);
ChainEdgesPair chain_sorted_in_edges;
chain_sorted_in_edges.first = chain_node();
chain_sorted_in_edges.second.assign(in_edges().begin(), in_edges().end());
FwSortEdgesInnerStage(&chain_sorted_in_edges.second, &TaskEdge::src_node,
&TaskNode::SoleInEdge);
for (const ChainEdgesPair& chain_sorted_out_edges : chain2sorted_out_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
}
mut_exec_gph().UpdateSourceAndSink();
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/boxing_task_node.h"
namespace oneflow {
class OutBoxingTaskNode final : public BoxingTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(OutBoxingTaskNode);
OutBoxingTaskNode() = default;
~OutBoxingTaskNode() = default;
private:
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<OutBoxingTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node);
}
void FwVirtualBuild() override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
#include "oneflow/core/graph/source_compute_task_node.h"
namespace oneflow {
void SourceCompTaskNode::ProduceAllRegstsAndBindEdges() {
ProduceRegst("out", 1, kMaxRegisterNum);
}
void SourceCompTaskNode::FixThrdLocId() {
set_thrd_loc_id(IDMgr::Singleton()->PersistenceThrdLocId());
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_SOURCE_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_SOURCE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/compute_task_node.h"
namespace oneflow {
class SourceCompTaskNode final : public CompTaskNode {
public:
OF_DISALLOW_COPY_AND_MOVE(SourceCompTaskNode);
SourceCompTaskNode() = default;
~SourceCompTaskNode() = default;
void ProduceAllRegstsAndBindEdges() override;
TodoTaskType GetTaskType() const override { return TodoTaskType::kSource; }
void FixThrdLocId() override;
private:
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_SOURCE_COMPUTE_TASK_NODE_H_
#include "oneflow/core/graph/stage_graph.h"
namespace oneflow {
StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
chain_gph_ = std::move(chain_gph);
HashMap<const ChainNode*, std::vector<StageNode*>> chain2stages;
// Construct Stage
chain_gph_->ConstForEachNode([&](const ChainNode* cur_chain) {
chain2stages[cur_chain] = {};
auto parallel_desc = cur_chain->parallel_desc();
int64_t range_idx = 0;
for (int64_t machine_id : parallel_desc->sorted_machine_ids()) {
StageNode* stage_node = NewNode();
stage_node->mut_machine_id() = machine_id;
stage_node->set_chain_node(cur_chain);
stage_node->mut_parallel_range().mut_begin() = range_idx;
size_t device_num =
parallel_desc->sorted_device_phy_ids(machine_id).size();
if (device_num == 0) {
device_num = 1; // persistence
}
range_idx += device_num;
stage_node->mut_parallel_range().mut_end() = range_idx;
chain2stages.at(cur_chain).push_back(stage_node);
}
});
// Connect Stage
chain_gph_->ConstForEachNode([&](const ChainNode* cur_chain) {
for (const ChainEdge* edge : cur_chain->out_edges()) {
const auto& cur_stages = chain2stages.at(cur_chain);
const auto& succ_stages = chain2stages.at(edge->dst_node());
for (StageNode* cur_stage_node : cur_stages) {
for (StageNode* succ_stage_node : succ_stages) {
Connect(cur_stage_node, NewEdge(), succ_stage_node);
}
}
}
});
// Post processing
UpdateSourceAndSink();
ToDotWithAutoFilePath();
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
#include "oneflow/core/common/range.h"
#include "oneflow/core/graph/chain_graph.h"
namespace oneflow {
class StageEdge;
class StageNode final : public Node<StageNode, StageEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(StageNode);
StageNode() = default;
~StageNode() = default;
std::string machine_id_str() const { return std::to_string(machine_id_); }
const int64_t& machine_id() const { return machine_id_; }
int64_t& mut_machine_id() { return machine_id_; }
const ChainNode* chain_node() const { return chain_node_; }
void set_chain_node(const ChainNode* new_chain_node) {
chain_node_ = new_chain_node;
}
const Range& parallel_range() const { return parallel_range_; }
Range& mut_parallel_range() { return parallel_range_; }
const std::vector<int64_t>& SortedDevicePhyIds() const {
return chain_node_->parallel_desc()->sorted_device_phy_ids(machine_id_);
}
std::string VisualStr() const override {
return machine_id_str() + "\\n" + chain_node_->VisualStr();
}
private:
const ChainNode* chain_node_;
int64_t machine_id_;
Range parallel_range_;
};
class StageEdge final : public Edge<StageNode, StageEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(StageEdge);
StageEdge() = default;
~StageEdge() = default;
private:
};
class StageGraph final : public Graph<StageNode, StageEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(StageGraph);
StageGraph() = delete;
~StageGraph() = default;
StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph);
const ChainGraph* chain_gph() const { return chain_gph_.get(); }
const char* TypeName() const override { return "StageGraph"; }
private:
std::unique_ptr<const ChainGraph> chain_gph_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
This diff has been collapsed.
#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/graph/stage_graph.h"
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
class TaskGraph : public Graph<TaskNode, TaskEdge> {
class TaskGraph final : public Graph<TaskNode, TaskEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
virtual ~TaskGraph() = default;
// Getters
const StageGraph* stage_gph() const { return stage_gph_.get(); }
const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); }
std::vector<CompTaskNode*> CompTasksInChain(const ChainNode*);
void InferBlobDescInProducedRegsts();
const std::string& name() const { return name_; }
protected:
TaskGraph() = default;
template<typename CompTaskNodeType>
void BuildFromChainGph(std::unique_ptr<ChainGraph>&& chain_gph, bool need_bp);
void BuildExecAndEnrollLbn2Regsts();
std::string& mut_name() { return name_; }
TaskGraph() = delete;
~TaskGraph() = default;
TaskGraph(std::unique_ptr<const ChainGraph>&& chain_gph);
void BldSubTskGphByBoxing(
const ChainNode* src_chain, const ChainNode* dst_chain,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_in_box,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_out_box);
void BldSubTskGphByOneToOne(
const ChainNode* src_chain, const ChainNode* dst_chain,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_in_box,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_out_box);
void BldSubTskGphBySelectOneSourceToSoleSink(
const ChainNode* src_chain, const ChainNode* dst_chain,
const std::vector<CompTaskNode*>& sorted_src_comp_tasks,
const std::vector<CompTaskNode*>& sorted_dst_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_in_box,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_out_box);
private:
template<typename CompTaskNodeType>
void BuildFromStageGph(bool need_bp);
template<typename TaskNodeType>
TaskNodeType* NewTaskNode() {
static_assert(std::is_base_of<TaskNode, TaskNodeType>::value, "");
TaskNodeType* ret = new TaskNodeType;
EnrollNode(ret);
return ret;
}
// Functions about Init
struct TaskNodesInStage {
std::vector<TaskNode*> comp_in_task_nodes;
std::vector<TaskNode*> comp_out_task_nodes;
BoxingTaskNode* in_boxing_task_node;
BoxingTaskNode* out_boxing_task_node;
};
using Stage2TaskNodesMap = HashMap<const StageNode*, TaskNodesInStage>;
template<typename TaskNodeType>
void InitCompTaskNodes(Stage2TaskNodesMap* stage2task_nodes);
template<typename TaskNodeType>
void Stage2DeviceCompTaskNodes(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage);
template<typename TaskNodeType>
void Stage2HostCompTaskNodes(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage);
void InitBoxingTaskNodes(Stage2TaskNodesMap* stage2task_nodes);
void InitInboxingTaskNode(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage);
void InitOutBoxingTaskNode(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage);
void ConnectBoxingTaskNodes(const Stage2TaskNodesMap* stage2task_nodes);
void GenerateRelatedBpNodes(std::vector<TaskNode*>* turning_node_vec);
void BackwardConnect(const std::vector<TaskNode*>& turning_node_vec);
void BuildBpStruct();
std::unique_ptr<const StageGraph> stage_gph_;
std::string name_;
TaskNode* AddCopyH2DTaskIfNotCpu(CompTaskNode*);
TaskNode* AddCopyD2HTaskIfNotCpu(CompTaskNode*);
void AddCopyCommNetTask(TaskNode* src, TaskNode* dst);
void BuildOutBoxingIfNeed(
const ChainNode*, const std::vector<CompTaskNode*>& sorted_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_out_box);
void BuildInBoxingIfNeed(
const ChainNode*, const std::vector<CompTaskNode*>& sorted_comp_tasks,
HashMap<const ChainNode*, std::vector<TaskNode*>>* chain2sorted_in_box);
void BuildStruct();
std::unique_ptr<const ChainGraph> chain_gph_;
};
} // namespace oneflow
......
(3 collapsed file diffs)
......@@ -21,6 +21,25 @@ class IDMgr final {
std::string MachineName4MachineId(int64_t machine_id) const {
return machine_id2machine_name_.at(machine_id);
}
int64_t GetThrdLocId(const std::string& name) const {
if (name == "persistence") {
return PersistenceThrdLocId();
} else if (name == "boxing") {
return BoxingThrdLocId();
} else {
UNEXPECTED_RUN();
}
}
  DeviceType GetDeviceTypeFromThrdLocId(int64_t thrd_loc_id) const {
if (thrd_loc_id < device_num_per_machine_) {
return JobDesc::Singleton()->resource().device_type();
} else {
return DeviceType::kCPU;
}
}
bool IsInherentThrd(int64_t thrd_loc_id) const {
return thrd_loc_id >= device_num_per_machine_;
}
int64_t ThrdLocId4DevPhyId(int64_t device_phy_id) const {
return device_phy_id;
}
......
(7 collapsed file diffs)
......@@ -8,8 +8,7 @@ void LossRecordKernel<T>::Forward(
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* loss_acc_blob = BnInOp2Blob("loss_acc");
T loss_mean = loss_acc_blob->dptr<T>()[0];
loss_mean /= JobDesc::Singleton()->piece_size()
* JobDesc::Singleton()->TotalMachineNum()
loss_mean /= JobDesc::Singleton()->ParallelPieceSize()
* JobDesc::Singleton()->piece_num_of_record_loss();
LOG(INFO) << "loss: " << loss_mean;
}
......
(15 collapsed file diffs)