io_worker.cpp 12.2 KB
Newer Older
C
chengtbf 已提交
1
#include "oneflow/core/comm_network/iocp/io_worker.h"
C
chengtbf 已提交
2
#include "oneflow/core/actor/actor_message_bus.h"
C
chengtbf 已提交
3
#include "oneflow/core/comm_network/iocp/iocp_comm_network.h"
4
#include "oneflow/core/control/ctrl_client.h"
C
chengtbf 已提交
5
#include "oneflow/core/job/job_desc.h"
6
#include "oneflow/core/job/resource.pb.h"
C
chengtbf 已提交
7
#include "oneflow/core/job/runtime_context.h"
C
chengtbf 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28

#ifdef PLATFORM_WINDOWS

namespace oneflow {

namespace {

// Builds the IPv4 socket address of machine `machine_id` for the given `port`.
// The textual address is taken from the job's resource description; the
// PCHECK fails when inet_pton cannot parse it as an IPv4 address.
sockaddr_in GetSockAddr(int64_t machine_id, uint16_t port) {
  const Machine& machine = JobDesc::Singleton()->resource().machine(machine_id);
  const std::string& addr = machine.addr();
  sockaddr_in sa;
  // Zero the whole struct first so that the sin_zero padding does not carry
  // stack garbage into bind()/connect().
  memset(&sa, 0, sizeof(sa));
  sa.sin_family = AF_INET;
  sa.sin_port = htons(port);
  PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1);
  return sa;
}

int64_t GetMachineId(const sockaddr_in& sa) {
  char addr[INET_ADDRSTRLEN];
  memset(addr, '\0', sizeof(addr));
  PCHECK(inet_ntop(AF_INET, (void*)(&(sa.sin_addr)), addr, INET_ADDRSTRLEN));
29 30
  for (int64_t i = 0; i < JobDesc::Singleton()->TotalMachineNum(); ++i) {
    if (JobDesc::Singleton()->resource().machine(i).addr() == addr) {
C
chengtbf 已提交
31 32 33 34 35 36 37 38
      return i;
    }
  }
  UNEXPECTED_RUN();
}

}  // namespace

C
chengtbf 已提交
39 40 41 42
// Prepares everything the worker threads need; the threads themselves are
// only created later by Start().
// Steps: read the job configuration, load Winsock, create the completion port,
// connect to all peers (InitSockets), post the first receive on every peer
// socket, and create one send queue plus one mutex per machine.
IOWorker::IOWorker() {
  total_machine_num_ = JobDesc::Singleton()->TotalMachineNum();
  this_machine_id_ = RuntimeCtx::Singleton()->this_machine_id();
  num_of_concurrent_threads_ = JobDesc::Singleton()->CommNetIOWorkerNum();

  // Winsock has to be loaded before any socket call below.
  WSADATA wsd;
  PCHECK(WSAStartup(MAKEWORD(2, 2), &wsd) == 0) << "Unable to load Winsock.\n";

  // A fresh completion port that is not yet associated with any handle; its
  // concurrency value is num_of_concurrent_threads_.
  completion_port_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0,
                                            num_of_concurrent_threads_);
  PCHECK(completion_port_ != NULL)
      << "CreateIoCompletionPort failed. Error:" << GetLastError() << "\n";

  InitSockets();

  // One receive IOData per machine (the slot of this machine stays unused),
  // then start receiving from every remote machine.
  machine_id2io_data_recv_.clear();
  FOR_RANGE(int64_t, machine_id, 0, total_machine_num_) {
    machine_id2io_data_recv_.push_back(new IOData);
  }
  FOR_RANGE(int64_t, machine_id, 0, total_machine_num_) {
    if (machine_id == this_machine_id_) { continue; }
    PostNewWSARecv2Socket(machine_id);
  }

  // One pending-send queue and one guarding mutex per machine.
  machine_id2io_data_send_que_.clear();
  FOR_RANGE(int64_t, machine_id, 0, total_machine_num_) {
    machine_id2io_data_send_que_.push_back(std::queue<IOData*>());
  }
  machine_id2send_que_mtx_ = std::vector<std::mutex>(total_machine_num_);
}

// Closes the socket of every remote machine, frees the per-machine receive
// IOData objects and unloads Winsock.
// NOTE(review): completion_port_ and any IOData still sitting in the send
// queues are not released here - confirm whether that is intentional.
IOWorker::~IOWorker() {
  int64_t i = 0;
  while (i < total_machine_num_) {
    const bool is_remote = (i != this_machine_id_);
    if (is_remote) { PCHECK(closesocket(machine_id2socket_[i]) == 0); }
    delete machine_id2io_data_recv_[i];
    ++i;
  }
  WSACleanup();
}

83 84
void IOWorker::PostSendMsgRequest(int64_t dst_machine_id,
                                  SocketMsg socket_msg) {
C
chengtbf 已提交
85 86 87 88
  SOCKET s = machine_id2socket_[dst_machine_id];
  IOData* io_data_ptr = new IOData;
  memset(&(io_data_ptr->overlapped), 0, sizeof(OVERLAPPED));
  io_data_ptr->socket_msg = socket_msg;
C
chengtbf 已提交
89
  io_data_ptr->IO_type = IOType::kFirstSendMsgHead;
90 91
  io_data_ptr->data_buff.buf =
      reinterpret_cast<char*>(&(io_data_ptr->socket_msg));
C
chengtbf 已提交
92 93 94
  io_data_ptr->data_buff.len = sizeof(SocketMsg);
  io_data_ptr->target_machine_id = dst_machine_id;
  io_data_ptr->target_socket_fd = s;
95 96 97
  io_data_ptr->flags = 0;
  PostQueuedCompletionStatus(completion_port_, 0, s,
                             reinterpret_cast<LPOVERLAPPED>(io_data_ptr));
C
chengtbf 已提交
98 99
}

C
chengtbf 已提交
100
void IOWorker::Start() {
101 102 103 104 105
  for (size_t i = 0; i < num_of_concurrent_threads_; ++i) {
    HANDLE worker_thread_handle =
        CreateThread(NULL, 0, IOWorker::StartThreadProc, this, 0, NULL);
    PCHECK(worker_thread_handle != NULL)
        << "Create Thread Handle failed. Error:" << GetLastError() << "\n";
C
chengtbf 已提交
106 107 108
    CloseHandle(worker_thread_handle);
  }
}
C
chengtbf 已提交
109 110

void IOWorker::Stop() {
C
chengtbf 已提交
111 112 113 114 115 116 117 118 119
  for (size_t i = 0; i < num_of_concurrent_threads_; ++i) {
    IOData* stop_io_data = new IOData;
    stop_io_data->IO_type = IOType::kStop;
    memset(&(stop_io_data->overlapped), 0, sizeof(OVERLAPPED));
    ResetIODataBuff(stop_io_data);
    PostQueuedCompletionStatus(completion_port_, 0, i,
                               reinterpret_cast<LPOVERLAPPED>(stop_io_data));
    LOG(INFO) << "Post stop request " << i << " to IOCP\n";
  }
C
chengtbf 已提交
120 121 122
}

void IOWorker::InitSockets() {
123
  machine_id2socket_.clear();
C
chengtbf 已提交
124 125 126
  machine_id2socket_.assign(total_machine_num_, -1);
  // listen
  SOCKET listen_socket = socket(AF_INET, SOCK_STREAM, 0);
127 128
  PCHECK(listen_socket != INVALID_SOCKET)
      << "socket failed with error:" << WSAGetLastError() << "\n";
C
chengtbf 已提交
129 130
  uint16_t this_listen_port = 1024;
  uint16_t listen_port_max = std::numeric_limits<uint16_t>::max();
131
  for (; this_listen_port < listen_port_max; ++this_listen_port) {
C
chengtbf 已提交
132 133
    sockaddr_in this_sockaddr = GetSockAddr(this_machine_id_, this_listen_port);
    int bind_result =
134 135 136
        bind(listen_socket, reinterpret_cast<sockaddr*>(&this_sockaddr),
             sizeof(this_sockaddr));
    if (bind_result == 0) {
C
chengtbf 已提交
137 138 139 140
      PCHECK(listen(listen_socket, total_machine_num_) == 0);
      CtrlClient::Singleton()->PushPort(this_listen_port);
      break;
    } else {
C
chengtbf 已提交
141 142
      PCHECK(WSAGetLastError() == WSAEACCES
             || WSAGetLastError() == WSAEADDRINUSE);
C
chengtbf 已提交
143 144 145 146
    }
  }
  CHECK_LT(this_listen_port, listen_port_max);
  // connect
147 148
  FOR_RANGE(int64_t, peer_machine_id, this_machine_id_ + 1,
            total_machine_num_) {
C
chengtbf 已提交
149 150 151 152 153 154 155
    uint16_t peer_port = CtrlClient::Singleton()->PullPort(peer_machine_id);
    sockaddr_in peer_sockaddr = GetSockAddr(peer_machine_id, peer_port);
    SOCKET s = socket(AF_INET, SOCK_STREAM, 0);
    PCHECK(connect(s, reinterpret_cast<sockaddr*>(&peer_sockaddr),
                   sizeof(peer_sockaddr))
           == 0);
    machine_id2socket_[peer_machine_id] = s;
156 157
    PCHECK(CreateIoCompletionPort((HANDLE)s, completion_port_, s, 0) != NULL)
        << "bind to completion port err:" << GetLastError() << "\n";
C
chengtbf 已提交
158 159 160 161 162 163 164
  }
  // accept
  FOR_RANGE(int64_t, idx, 0, this_machine_id_) {
    sockaddr_in peer_sockaddr;
    socklen_t len = sizeof(peer_sockaddr);
    SOCKET s = accept(listen_socket,
                      reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
165 166
    PCHECK(s != INVALID_SOCKET)
        << "socket accept error: " << WSAGetLastError() << "\n";
C
chengtbf 已提交
167 168
    int64_t peer_machine_id = GetMachineId(peer_sockaddr);
    machine_id2socket_[peer_machine_id] = s;
169 170
    PCHECK(CreateIoCompletionPort((HANDLE)s, completion_port_, s, 0) != NULL)
        << "bind to completion port err:" << GetLastError() << "\n";
C
chengtbf 已提交
171
  }
C
chengtbf 已提交
172
  PCHECK(closesocket(listen_socket) == 0);
C
chengtbf 已提交
173 174 175
  // useful log
  FOR_RANGE(int64_t, machine_id, 0, total_machine_num_) {
    LOG(INFO) << "machine " << machine_id << " sockfd "
176
              << machine_id2socket_[machine_id];
C
chengtbf 已提交
177 178
  }
}
C
chengtbf 已提交
179

180 181 182 183 184
void IOWorker::PostNewWSARecv2Socket(int64_t dst_machine_id) {
  SOCKET s = machine_id2socket_[dst_machine_id];
  IOData* io_data_ptr = machine_id2io_data_recv_[dst_machine_id];
  memset(&(io_data_ptr->overlapped), 0, sizeof(OVERLAPPED));
  io_data_ptr->IO_type = IOType::kRecvMsgHead;
C
chengtbf 已提交
185
  ResetIODataBuff(io_data_ptr);
186 187 188
  io_data_ptr->target_machine_id = dst_machine_id;
  io_data_ptr->target_socket_fd = s;
  io_data_ptr->flags = 0;
C
chengtbf 已提交
189
  WSARecvFromIOData(io_data_ptr);
190 191 192 193 194 195 196 197 198 199 200 201
}

// Body of every worker thread: dequeues completion packets from
// completion_port_ until a kStop request arrives, then returns 0.
//
// Every packet carries an IOData (passed in place of the OVERLAPPED pointer).
// After advancing its buffer by the number of transferred bytes, the packet is
// dispatched on IO_type:
//   kStop             - delete the request and leave the thread.
//   kRecvMsgHead/Body - run the matching "done" handler once the buffer is
//                       full, then post the next receive.
//   kFirstSendMsgHead - a request from PostSendMsgRequest; enqueue it.
//   kSendMsgHead/Body - run the matching "done" handler once everything was
//                       sent, otherwise send the remaining bytes.
DWORD IOWorker::ThreadProc() {
  while (true) {
    DWORD bytes_transferred = 0;
    // GetQueuedCompletionStatus writes the key through a PULONG_PTR.
    ULONG_PTR completion_key = 0;
    IOData* io_data_ptr = nullptr;
    // The result is a BOOL: test it against FALSE rather than against `true`.
    CHECK(GetQueuedCompletionStatus(
              completion_port_, &bytes_transferred, &completion_key,
              reinterpret_cast<LPOVERLAPPED*>(&io_data_ptr), INFINITE)
          != FALSE)
        << "GetQueuedCompletionStatus Error: " << GetLastError() << "\n";
    // data_buff.len is unsigned, so checking "len >= 0" after the subtraction
    // can never fail; validate before subtracting instead.
    CHECK_LE(bytes_transferred, io_data_ptr->data_buff.len);

    const bool is_recv = (io_data_ptr->IO_type == IOType::kRecvMsgHead
                          || io_data_ptr->IO_type == IOType::kRecvMsgBody);
    if (is_recv && bytes_transferred == 0 && io_data_ptr->data_buff.len > 0) {
      // A zero-byte completion of a non-empty stream receive means the peer
      // closed the connection. Posting another receive would complete
      // immediately again and spin this thread, so stop receiving from it.
      // The IOData stays owned by machine_id2io_data_recv_.
      LOG(INFO) << "machine " << io_data_ptr->target_machine_id
                << " closed the connection\n";
      continue;
    }

    io_data_ptr->data_buff.buf += bytes_transferred;
    io_data_ptr->data_buff.len -= bytes_transferred;
    switch (io_data_ptr->IO_type) {
      case IOType::kStop: {
        delete io_data_ptr;
        LOG(INFO) << "stop IOworker " << completion_key << " \n";
        return 0;
      }
      case IOType::kRecvMsgHead: {
        if (io_data_ptr->data_buff.len == 0) { OnRecvMsgHeadDone(io_data_ptr); }
        WSARecvFromIOData(io_data_ptr);
        break;
      }
      case IOType::kRecvMsgBody: {
        if (io_data_ptr->data_buff.len == 0) { OnRecvMsgBodyDone(io_data_ptr); }
        WSARecvFromIOData(io_data_ptr);
        break;
      }
      case IOType::kFirstSendMsgHead: {
        OnFirstSendMsgHead(io_data_ptr);
        break;
      }
      case IOType::kSendMsgHead: {
        if (io_data_ptr->data_buff.len == 0) {
          OnSendMsgHeadDone(io_data_ptr);
        } else {
          WSASendFromIOData(io_data_ptr);
        }
        break;
      }
      case IOType::kSendMsgBody: {
        if (io_data_ptr->data_buff.len == 0) {
          OnSendDone(io_data_ptr);
        } else {
          WSASendFromIOData(io_data_ptr);
        }
        break;
      }
      default: UNEXPECTED_RUN()
    }
  }
  return 0;
}

C
chengtbf 已提交
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
// Called when a complete message header has been received into
// io_data_ptr->socket_msg. Depending on the message type:
//   kActor        - forward the actor message to ActorMsgBus and expect the
//                   next header.
//   kRequsetRead  - a body follows; redirect the buffer to the memory
//                   described by the token's read_machine_mem_desc_.
//   kRequestWrite - answer the sender with a kRequsetRead message carrying
//                   the same token and expect the next header.
void IOWorker::OnRecvMsgHeadDone(IOData* io_data_ptr) {
  const SocketMsgType received_type = io_data_ptr->socket_msg.msg_type;
  if (received_type == SocketMsgType::kActor) {
    ActorMsgBus::Singleton()->SendMsg(io_data_ptr->socket_msg.actor_msg);
    ResetIODataBuff(io_data_ptr);
  } else if (received_type == SocketMsgType::kRequsetRead) {
    auto mem_desc_ptr = static_cast<const SocketMemDesc*>(
        io_data_ptr->socket_msg.socket_token.read_machine_mem_desc_);
    io_data_ptr->IO_type = IOType::kRecvMsgBody;
    io_data_ptr->data_buff.len = mem_desc_ptr->byte_size;
    io_data_ptr->data_buff.buf =
        reinterpret_cast<char*>(mem_desc_ptr->mem_ptr);
  } else if (received_type == SocketMsgType::kRequestWrite) {
    SocketMsg read_request;
    read_request.msg_type = SocketMsgType::kRequsetRead;
    read_request.socket_token = io_data_ptr->socket_msg.socket_token;
    PostSendMsgRequest(io_data_ptr->target_machine_id, read_request);
    ResetIODataBuff(io_data_ptr);
  } else {
    UNEXPECTED_RUN()
  }
}

C
chengtbf 已提交
275 276 277 278 279 280
// Called when the body that follows a kRequsetRead header has been received
// completely: reports the finished read to IOCPCommNet and switches the
// IOData back to expecting a message header.
void IOWorker::OnRecvMsgBodyDone(IOData* io_data_ptr) {
  CHECK(io_data_ptr->socket_msg.msg_type == SocketMsgType::kRequsetRead);
  auto& token = io_data_ptr->socket_msg.socket_token;
  IOCPCommNet::Singleton()->ReadDone(token.read_done_id);
  io_data_ptr->IO_type = IOType::kRecvMsgHead;
  ResetIODataBuff(io_data_ptr);
}

C
chengtbf 已提交
283 284 285 286 287 288 289 290 291 292
// Handles a freshly posted send request: appends it to the send queue of its
// target machine and, if it is the only entry, starts sending it right away.
// Otherwise OnSendDone starts it after the preceding entries have finished.
// The queue is accessed under the target machine's mutex.
void IOWorker::OnFirstSendMsgHead(IOData* io_data_ptr) {
  const auto dst_machine_id = io_data_ptr->target_machine_id;
  std::lock_guard<std::mutex> guard(machine_id2send_que_mtx_[dst_machine_id]);
  std::queue<IOData*>& pending_sends =
      machine_id2io_data_send_que_[dst_machine_id];
  io_data_ptr->IO_type = IOType::kSendMsgHead;
  pending_sends.push(io_data_ptr);
  const bool is_only_entry = (pending_sends.size() == 1);
  if (is_only_entry) { WSASendFromIOData(io_data_ptr); }
}

C
chengtbf 已提交
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
// Called when a message header has been sent completely.
// kActor and kRequestWrite messages consist of the header only, so the send
// is finished. A kRequsetRead header is followed by the memory described by
// the token's write_machine_mem_desc_, which is sent next as the body.
void IOWorker::OnSendMsgHeadDone(IOData* io_data_ptr) {
  switch (io_data_ptr->socket_msg.msg_type) {
    case SocketMsgType::kActor:
    case SocketMsgType::kRequestWrite: {
      OnSendDone(io_data_ptr);
      break;
    }
    case SocketMsgType::kRequsetRead: {
      const auto* body_desc = static_cast<const SocketMemDesc*>(
          io_data_ptr->socket_msg.socket_token.write_machine_mem_desc_);
      io_data_ptr->IO_type = IOType::kSendMsgBody;
      io_data_ptr->data_buff.len = body_desc->byte_size;
      io_data_ptr->data_buff.buf = reinterpret_cast<char*>(body_desc->mem_ptr);
      WSASendFromIOData(io_data_ptr);
      break;
    }
    default: UNEXPECTED_RUN()
  }
}

// Called when a message (header plus optional body) has been sent completely.
// Removes the finished IOData from the front of its machine's send queue,
// deletes it, and starts sending the next queued entry, if any.
// The queue is accessed under the target machine's mutex.
void IOWorker::OnSendDone(IOData* io_data_ptr) {
  // Read the id before the IOData is deleted below.
  const auto dst_machine_id = io_data_ptr->target_machine_id;
  std::lock_guard<std::mutex> guard(machine_id2send_que_mtx_[dst_machine_id]);
  std::queue<IOData*>& send_que = machine_id2io_data_send_que_[dst_machine_id];
  CHECK(io_data_ptr == send_que.front());
  send_que.pop();
  delete io_data_ptr;
  if (send_que.empty()) { return; }
  WSASendFromIOData(send_que.front());
}

C
chengtbf 已提交
335 336 337 338 339 340
// Points the I/O buffer of `io_data_ptr` at its own socket_msg member and sets
// the length to sizeof(SocketMsg), i.e. prepares it to transfer one message
// header.
void IOWorker::ResetIODataBuff(IOData* io_data_ptr) {
  auto* msg_head = &(io_data_ptr->socket_msg);
  io_data_ptr->data_buff.len = sizeof(SocketMsg);
  io_data_ptr->data_buff.buf = reinterpret_cast<char*>(msg_head);
}

C
chengtbf 已提交
341 342 343 344 345 346
// Posts an overlapped receive on io_data_ptr->target_socket_fd into
// io_data_ptr->data_buff. The IOData itself is passed as the OVERLAPPED
// pointer, so the completion packet hands it back to ThreadProc.
void IOWorker::WSARecvFromIOData(IOData* io_data_ptr) {
  const int result =
      WSARecv(io_data_ptr->target_socket_fd, &(io_data_ptr->data_buff), 1, NULL,
              &(io_data_ptr->flags),
              reinterpret_cast<LPOVERLAPPED>(io_data_ptr), NULL);
  if (result == SOCKET_ERROR) {
    // WSA_IO_PENDING only means the receive was queued. Any other error means
    // no completion packet will ever arrive, so fail loudly instead of
    // silently stalling this connection.
    const int err = WSAGetLastError();
    CHECK(err == WSA_IO_PENDING) << "WSARecv failed with error:" << err << "\n";
  }
}

C
chengtbf 已提交
347 348 349 350 351 352
// Posts an overlapped send of io_data_ptr->data_buff on
// io_data_ptr->target_socket_fd. The IOData itself is passed as the
// OVERLAPPED pointer, so the completion packet hands it back to ThreadProc.
void IOWorker::WSASendFromIOData(IOData* io_data_ptr) {
  const int result =
      WSASend(io_data_ptr->target_socket_fd, &(io_data_ptr->data_buff), 1, NULL,
              io_data_ptr->flags, reinterpret_cast<LPOVERLAPPED>(io_data_ptr),
              NULL);
  if (result == SOCKET_ERROR) {
    // WSA_IO_PENDING only means the send was queued. Any other error means no
    // completion packet will ever arrive and the send queue of this machine
    // would stall, so fail loudly.
    const int err = WSAGetLastError();
    CHECK(err == WSA_IO_PENDING) << "WSASend failed with error:" << err << "\n";
  }
}

C
chengtbf 已提交
353 354 355
}  // namespace oneflow

#endif  // PLATFORM_WINDOWS