提交 044e757c 编写于 作者: W willzhang4a58

fix bug: io worker exit


Former-commit-id: 79819900
上级 67a3fd6d
...@@ -15,7 +15,7 @@ train_conf { ...@@ -15,7 +15,7 @@ train_conf {
model_save_snapshots_path: "/willzhang/snapshots" model_save_snapshots_path: "/willzhang/snapshots"
num_of_batches_in_snapshot: 6 num_of_batches_in_snapshot: 6
staleness: 0 staleness: 0
total_batch_num: 24 total_batch_num: 15
default_fill_conf { default_fill_conf {
gaussian_conf { gaussian_conf {
mean: 0.0 mean: 0.0
......
...@@ -88,6 +88,7 @@ void CopyCommNetActor::Act() { ...@@ -88,6 +88,7 @@ void CopyCommNetActor::Act() {
[&](Regst* regst) { regst->set_piece_id(cur_piece_id); }); [&](Regst* regst) { regst->set_piece_id(cur_piece_id); });
AsyncSendRegstMsgToProducer(readable_regst, readable_it->second.producer); AsyncSendRegstMsgToProducer(readable_regst, readable_it->second.producer);
comm_net_device_ctx_->set_read_id(nullptr); comm_net_device_ctx_->set_read_id(nullptr);
DataCommNet::Singleton()->AddReadCallBackDone(read_id);
// Finish // Finish
piece_id2regst_ctx.erase(readable_it); piece_id2regst_ctx.erase(readable_it);
mut_num_of_read_empty() = piece_id2regst_ctx.empty(); mut_num_of_read_empty() = piece_id2regst_ctx.empty();
......
...@@ -25,6 +25,7 @@ class DataCommNet { ...@@ -25,6 +25,7 @@ class DataCommNet {
const void* dst_token) = 0; const void* dst_token) = 0;
virtual void AddReadCallBack(void* read_id, virtual void AddReadCallBack(void* read_id,
std::function<void()> callback) = 0; std::function<void()> callback) = 0;
virtual void AddReadCallBackDone(void* read_id) = 0;
// //
virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0; virtual void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) = 0;
......
...@@ -33,6 +33,11 @@ int64_t GetMachineId(const sockaddr_in& sa) { ...@@ -33,6 +33,11 @@ int64_t GetMachineId(const sockaddr_in& sa) {
} // namespace } // namespace
EpollDataCommNet::~EpollDataCommNet() { EpollDataCommNet::~EpollDataCommNet() {
for (size_t i = 0; i < pollers_.size(); ++i) {
LOG(INFO) << "IOWorker " << i << " finish";
pollers_[i]->Stop();
}
OF_BARRIER();
for (IOEventPoller* poller : pollers_) { delete poller; } for (IOEventPoller* poller : pollers_) { delete poller; }
for (auto& pair : sockfd2helper_) { delete pair.second; } for (auto& pair : sockfd2helper_) { delete pair.second; }
} }
...@@ -69,22 +74,87 @@ void EpollDataCommNet::RegisterMemoryDone() { ...@@ -69,22 +74,87 @@ void EpollDataCommNet::RegisterMemoryDone() {
void* EpollDataCommNet::Read(int64_t src_machine_id, const void* src_token, void* EpollDataCommNet::Read(int64_t src_machine_id, const void* src_token,
const void* dst_token) { const void* dst_token) {
auto callback_list = new CallBackList; // ReadContext
ReadContext* read_ctx = new ReadContext;
read_ctx->cbl.clear();
read_ctx->done_cnt = 0;
{
std::unique_lock<std::mutex> lck(undeleted_read_ctxs_mtx_);
CHECK(undeleted_read_ctxs_.insert(read_ctx).second);
}
// request write msg
SocketMsg msg; SocketMsg msg;
msg.msg_type = SocketMsgType::kRequestWrite; msg.msg_type = SocketMsgType::kRequestWrite;
msg.request_write_msg.src_token = src_token; msg.request_write_msg.src_token = src_token;
msg.request_write_msg.dst_machine_id = msg.request_write_msg.dst_machine_id =
RuntimeCtx::Singleton()->this_machine_id(); RuntimeCtx::Singleton()->this_machine_id();
msg.request_write_msg.dst_token = dst_token; msg.request_write_msg.dst_token = dst_token;
msg.request_write_msg.read_id = callback_list; msg.request_write_msg.read_id = read_ctx;
GetSocketHelper(src_machine_id)->AsyncWrite(msg); GetSocketHelper(src_machine_id)->AsyncWrite(msg);
return callback_list; return read_ctx;
} }
void EpollDataCommNet::AddReadCallBack(void* read_id, void EpollDataCommNet::AddReadCallBack(void* read_id,
std::function<void()> callback) { std::function<void()> callback) {
auto callback_list = static_cast<CallBackList*>(read_id); ReadContext* read_ctx = static_cast<ReadContext*>(read_id);
callback_list->push_back(callback); if (read_id) {
read_ctx->cbl.push_back(callback);
return;
}
CallBackContext* cb_ctx = new CallBackContext;
cb_ctx->callback = callback;
do {
std::unique_lock<std::mutex> read_ctxs_lck(undeleted_read_ctxs_mtx_);
if (undeleted_read_ctxs_.empty()) { break; }
cb_ctx->cnt = undeleted_read_ctxs_.size();
for (ReadContext* read_ctx : undeleted_read_ctxs_) {
std::unique_lock<std::mutex> cbl_lck(read_ctx->cbl_mtx);
read_ctx->cbl.push_back([cb_ctx]() { cb_ctx->DecreaseCnt(); });
}
return;
} while (0);
delete cb_ctx;
callback();
}
void EpollDataCommNet::AddReadCallBackDone(void* read_id) {
IncreaseDoneCnt(read_id);
}
void EpollDataCommNet::ReadDone(void* read_id) { IncreaseDoneCnt(read_id); }
void EpollDataCommNet::IncreaseDoneCnt(void* read_id) {
ReadContext* read_ctx = static_cast<ReadContext*>(read_id);
do {
std::unique_lock<std::mutex> lck(read_ctx->done_cnt_mtx);
read_ctx->done_cnt += 1;
if (read_ctx->done_cnt == 2) {
break;
} else {
return;
}
} while (0);
std::unique_lock<std::mutex> read_ctxs_lck(undeleted_read_ctxs_mtx_);
CHECK_EQ(undeleted_read_ctxs_.erase(read_ctx), 1);
{
std::unique_lock<std::mutex> cbl_lck(read_ctx->cbl_mtx);
for (std::function<void()>& callback : read_ctx->cbl) { callback(); }
}
delete read_ctx;
}
void EpollDataCommNet::CallBackContext::DecreaseCnt() {
do {
std::unique_lock<std::mutex> lck(cnt_mtx);
cnt -= 1;
if (cnt == 0) {
break;
} else {
return;
}
} while (0);
callback();
delete this;
} }
void EpollDataCommNet::SendActorMsg(int64_t dst_machine_id, void EpollDataCommNet::SendActorMsg(int64_t dst_machine_id,
......
...@@ -27,12 +27,15 @@ class EpollDataCommNet final : public DataCommNet { ...@@ -27,12 +27,15 @@ class EpollDataCommNet final : public DataCommNet {
void* Read(int64_t src_machine_id, const void* src_token, void* Read(int64_t src_machine_id, const void* src_token,
const void* dst_token) override; const void* dst_token) override;
void AddReadCallBack(void* read_id, std::function<void()> callback) override; void AddReadCallBack(void* read_id, std::function<void()> callback) override;
void AddReadCallBackDone(void* read_id) override;
void ReadDone(void* read_id);
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override; void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg); void SendSocketMsg(int64_t dst_machine_id, const SocketMsg& msg);
private: private:
EpollDataCommNet(); EpollDataCommNet();
void IncreaseDoneCnt(void* read_id);
void InitSockets(); void InitSockets();
SocketHelper* GetSocketHelper(int64_t machine_id); SocketHelper* GetSocketHelper(int64_t machine_id);
...@@ -40,6 +43,21 @@ class EpollDataCommNet final : public DataCommNet { ...@@ -40,6 +43,21 @@ class EpollDataCommNet final : public DataCommNet {
std::mutex mem_desc_mtx_; std::mutex mem_desc_mtx_;
std::list<SocketMemDesc*> mem_descs_; std::list<SocketMemDesc*> mem_descs_;
size_t unregister_mem_descs_cnt_; size_t unregister_mem_descs_cnt_;
// Read
struct ReadContext {
std::mutex cbl_mtx;
CallBackList cbl;
std::mutex done_cnt_mtx;
int8_t done_cnt;
};
struct CallBackContext {
void DecreaseCnt();
std::function<void()> callback;
std::mutex cnt_mtx;
int32_t cnt;
};
std::mutex undeleted_read_ctxs_mtx_;
HashSet<ReadContext*> undeleted_read_ctxs_;
// Socket // Socket
std::vector<IOEventPoller*> pollers_; std::vector<IOEventPoller*> pollers_;
std::vector<int> machine_id2sockfd_; std::vector<int> machine_id2sockfd_;
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#ifdef PLATFORM_POSIX #ifdef PLATFORM_POSIX
#include <sys/eventfd.h>
namespace oneflow { namespace oneflow {
const int IOEventPoller::max_event_num_ = 32; const int IOEventPoller::max_event_num_ = 32;
...@@ -9,15 +11,18 @@ const int IOEventPoller::max_event_num_ = 32; ...@@ -9,15 +11,18 @@ const int IOEventPoller::max_event_num_ = 32;
IOEventPoller::IOEventPoller() { IOEventPoller::IOEventPoller() {
epfd_ = epoll_create1(0); epfd_ = epoll_create1(0);
ep_events_ = new epoll_event[max_event_num_]; ep_events_ = new epoll_event[max_event_num_];
unclosed_fd_cnt_ = 0;
io_handlers_.clear(); io_handlers_.clear();
break_epoll_loop_fd_ = eventfd(0, 0);
PCHECK(break_epoll_loop_fd_ != -1);
AddFdWithOnlyReadHandler(break_epoll_loop_fd_,
[]() { LOG(INFO) << "Break Epoll Loop"; });
} }
IOEventPoller::~IOEventPoller() { IOEventPoller::~IOEventPoller() {
for (IOHandler* handler : io_handlers_) { PCHECK(close(handler->fd) == 0); } for (IOHandler* handler : io_handlers_) {
thread_.join(); PCHECK(close(handler->fd) == 0);
for (IOHandler* handler : io_handlers_) { delete handler; } delete handler;
CHECK_EQ(unclosed_fd_cnt_, 0); }
delete[] ep_events_; delete[] ep_events_;
PCHECK(close(epfd_) == 0); PCHECK(close(epfd_) == 0);
} }
...@@ -36,9 +41,14 @@ void IOEventPoller::Start() { ...@@ -36,9 +41,14 @@ void IOEventPoller::Start() {
thread_ = std::thread(&IOEventPoller::EpollLoop, this); thread_ = std::thread(&IOEventPoller::EpollLoop, this);
} }
void IOEventPoller::Stop() {
uint64_t break_epoll_loop_event = 1;
PCHECK(write(break_epoll_loop_fd_, &break_epoll_loop_event, 8) == 8);
thread_.join();
}
void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler, void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
std::function<void()>* write_handler) { std::function<void()>* write_handler) {
unclosed_fd_cnt_ += 1;
// Set Fd NONBLOCK // Set Fd NONBLOCK
int opt = fcntl(fd, F_GETFL); int opt = fcntl(fd, F_GETFL);
PCHECK(opt != -1); PCHECK(opt != -1);
...@@ -63,16 +73,20 @@ void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler, ...@@ -63,16 +73,20 @@ void IOEventPoller::AddFd(int fd, std::function<void()>* read_handler,
} }
void IOEventPoller::EpollLoop() { void IOEventPoller::EpollLoop() {
while (unclosed_fd_cnt_ > 0) { while (true) {
int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1); int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1);
PCHECK(event_num >= 0); PCHECK(event_num >= 0);
const epoll_event* cur_event = ep_events_; const epoll_event* cur_event = ep_events_;
for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) { for (int event_idx = 0; event_idx < event_num; ++event_idx, ++cur_event) {
auto io_handler = static_cast<IOHandler*>(cur_event->data.ptr); auto io_handler = static_cast<IOHandler*>(cur_event->data.ptr);
PCHECK(!(cur_event->events & EPOLLERR)) << "fd: " << io_handler->fd; PCHECK(!(cur_event->events & EPOLLERR)) << "fd: " << io_handler->fd;
if (io_handler->fd == break_epoll_loop_fd_) { return; }
if (cur_event->events & EPOLLIN) { if (cur_event->events & EPOLLIN) {
if (cur_event->events & EPOLLRDHUP) { unclosed_fd_cnt_ -= 1; } if (cur_event->events & EPOLLRDHUP) {
io_handler->read_handler(); LOG(FATAL) << "fd " << io_handler->fd << " closed by peer";
} else {
io_handler->read_handler();
}
} }
if (cur_event->events & EPOLLOUT) { io_handler->write_handler(); } if (cur_event->events & EPOLLOUT) { io_handler->write_handler(); }
} }
......
...@@ -18,6 +18,7 @@ class IOEventPoller final { ...@@ -18,6 +18,7 @@ class IOEventPoller final {
void AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler); void AddFdWithOnlyReadHandler(int fd, std::function<void()> read_handler);
void Start(); void Start();
void Stop();
private: private:
struct IOHandler { struct IOHandler {
...@@ -39,10 +40,9 @@ class IOEventPoller final { ...@@ -39,10 +40,9 @@ class IOEventPoller final {
int epfd_; int epfd_;
epoll_event* ep_events_; epoll_event* ep_events_;
int64_t unclosed_fd_cnt_;
std::forward_list<IOHandler*> io_handlers_; std::forward_list<IOHandler*> io_handlers_;
int break_epoll_loop_fd_;
std::thread thread_; std::thread thread_;
std::vector<int> fds_;
}; };
} // namespace oneflow } // namespace oneflow
......
...@@ -66,9 +66,7 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() { ...@@ -66,9 +66,7 @@ void SocketReadHelper::SetStatusWhenMsgHeadDone() {
void SocketReadHelper::SetStatusWhenMsgBodyDone() { void SocketReadHelper::SetStatusWhenMsgBodyDone() {
if (cur_msg_.msg_type == SocketMsgType::kRequestRead) { if (cur_msg_.msg_type == SocketMsgType::kRequestRead) {
auto cbl = static_cast<CallBackList*>(cur_msg_.request_read_msg.read_id); EpollDataCommNet::Singleton()->ReadDone(cur_msg_.request_read_msg.read_id);
for (std::function<void()>& callback : *cbl) { callback(); }
delete cbl;
} }
SwitchToMsgHeadReadHandle(); SwitchToMsgHeadReadHandle();
} }
......
...@@ -163,7 +163,7 @@ void TaskNode::ToProto( ...@@ -163,7 +163,7 @@ void TaskNode::ToProto(
std::string TaskNode::VisualStr() const { std::string TaskNode::VisualStr() const {
std::stringstream ss; std::stringstream ss;
ss << (is_fw_node_ ? "Fw" : "Bp"); ss << (is_fw_node_ ? "Fw" : "Bp");
ss << node_id_str() << "_"; ss << task_id_str() << "_";
return ss.str(); return ss.str();
} }
......
...@@ -70,6 +70,7 @@ void Runtime::Run(const Plan& plan, const std::string& this_machine_name) { ...@@ -70,6 +70,7 @@ void Runtime::Run(const Plan& plan, const std::string& this_machine_name) {
SendCmdMsg(mdupdt_tasks, ActorCmd::kSendInitialModel); SendCmdMsg(mdupdt_tasks, ActorCmd::kSendInitialModel);
SendCmdMsg(source_tasks, ActorCmd::kStart); SendCmdMsg(source_tasks, ActorCmd::kStart);
RuntimeCtx::Singleton()->mut_active_actor_cnt().WaitUntilCntEqualZero(); RuntimeCtx::Singleton()->mut_active_actor_cnt().WaitUntilCntEqualZero();
OF_BARRIER();
DeleteAllSingleton(); DeleteAllSingleton();
} }
......
...@@ -10,7 +10,7 @@ ThreadMgr::~ThreadMgr() { ...@@ -10,7 +10,7 @@ ThreadMgr::~ThreadMgr() {
ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread); ActorMsg msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
threads_[i]->GetMsgChannelPtr()->Send(msg); threads_[i]->GetMsgChannelPtr()->Send(msg);
threads_[i].reset(); threads_[i].reset();
LOG(INFO) << "thread " << i << " finish"; LOG(INFO) << "actor thread " << i << " finish";
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册