Commit 03728d45 authored by Megvii Engine Team, committed by Xinran Xu

fix(mgb/build): fix multi-machine macro and add test_distributed

GitOrigin-RevId: cb1bfe8742f2d3ddf8133df5b7347429d76aebce
Parent d74a4ee9
@@ -55,10 +55,10 @@ add_custom_command(
 add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)
-set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
+set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
 
 if(MGE_WITH_DISTRIBUTED)
-    list(APPEND SRCS src/cpp/mm_handler.cpp src/cpp/zmq_rpc.cpp)
+    list(APPEND SRCS src/cpp/zmq_rpc.cpp)
 endif()
 
 include(UseSWIG)
......
@@ -65,12 +65,10 @@ class _config {
     static std::vector<std::pair<uint64_t, std::string>>
             dump_registered_oprs();
 
-#if MGB_ENABLE_OPR_MM
     static int create_mm_server(const std::string& server_addr, int port);
 
     static void group_barrier(const std::string& server_addr,
             int port, uint32_t size, uint32_t rank);
-#endif
 };
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -12,7 +12,7 @@
 #include "megbrain/exception.h"
 #include "megbrain_config.h"
 
-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM
 #include "zmq_rpc.h"
 #include <future>
@@ -242,17 +242,11 @@ int _config::create_mm_server(const std::string& server_addr, int port) {
             server_addr, port, std::make_unique<GroupServerProxy>());
 }
-#else
-int _config::create_mm_server(const std::string& server_addr, int port) {
-    mgb_throw(mgb::MegBrainError, "CUDA suppport disable at compile time");
-    return 0;
-}
-#endif
 
 /* ======================== Group Barrier ========================== */
 /*! see definition : src/cpp/megbrain_config.h.
  * Block until all ranks in the group reach this barrier
  */
 void _config::group_barrier(const std::string& server_addr,
         int port, uint32_t size, uint32_t rank) {
     mgb_assert(rank < size, "invalid rank %d", rank);
@@ -263,4 +257,18 @@ void _config::group_barrier(const std::string& server_addr,
     mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
 }
+#else
+int _config::create_mm_server(const std::string& server_addr, int port) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+    return 0;
+}
+
+void _config::group_barrier(const std::string& server_addr,
+        int port, uint32_t size, uint32_t rank) {
+    mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
+}
+#endif
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
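With these two hunks, create_mm_server and group_barrier are always defined: real implementations when MGB_ENABLE_OPR_MM is set, throwing stubs otherwise, so callers no longer need their own preprocessor guards. The sketch below illustrates how a caller sees both build flavors; the main() wrapper and the hard-coded address, port, group size, and rank are illustrative assumptions, not part of this diff.

```cpp
// Minimal caller-side sketch (assumes _config is visible at global scope, as
// declared in src/cpp/megbrain_config.h): the same code compiles with or
// without distributed support; only the runtime behavior differs.
#include <cstdio>
#include <string>

#include "megbrain/exception.h"
#include "megbrain_config.h"

int main() {
    const std::string addr = "127.0.0.1";  // illustrative address/port only
    try {
        // Starts the group server in a build with MGB_ENABLE_OPR_MM;
        // otherwise the stub above throws MegBrainError.
        _config::create_mm_server(addr, 3456);
        // Block until both ranks of a size-2 group reach the barrier (this is rank 0).
        _config::group_barrier(addr, 3456, 2, 0);
    } catch (const mgb::MegBrainError& err) {
        std::printf("distributed support unavailable: %s\n", err.what());
    }
    return 0;
}
```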
@@ -11,7 +11,7 @@
 #include "megbrain_build_config.h"
 
-#if MGB_CUDA
+#if MGB_ENABLE_OPR_MM
 #include "zmq_rpc.h"
......
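This macro fix is also what allows the CMake hunk above to list mm_handler.cpp unconditionally: the multi-machine sources guard their own contents with MGB_ENABLE_OPR_MM, so a non-distributed build still compiles them, just to an essentially empty translation unit. A schematic sketch of that self-guarding pattern follows; the body is elided and the comments are placeholders, not the file's actual contents.

```cpp
// Self-guarded translation unit: safe to add to SRCS for every build flavor.
#include "megbrain_build_config.h"

#if MGB_ENABLE_OPR_MM
#include "zmq_rpc.h"

// ... ZeroMQ-based multi-machine server/client handlers would live here ...

#endif  // MGB_ENABLE_OPR_MM
```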
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import subprocess
import sys

import numpy as np


def worker(master_ip, master_port, world_size, rank, dev, trace):
    import megengine.distributed as dist
    import megengine.functional as F
    from megengine import is_cuda_available
    from megengine import jit
    from megengine.module import Linear, Module
    from megengine.optimizer import SGD

    if not is_cuda_available():
        return

    class MLP(Module):
        def __init__(self):
            super().__init__()
            self.fc0 = Linear(3 * 224 * 224, 500)
            self.fc1 = Linear(500, 10)

        def forward(self, x):
            x = self.fc0(x)
            x = F.relu(x)
            x = self.fc1(x)
            return x

    dist.init_process_group(
        master_ip=master_ip, master_port=master_port, world_size=world_size, rank=rank, dev=dev
    )
    net = MLP()
    opt = SGD(net.parameters(requires_grad=True), lr=0.02)

    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
    label = np.random.randint(0, 10, size=(64,)).astype(np.int32)

    jit.trace.enabled = trace

    @jit.trace()
    def train_func(data, label):
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    for i in range(5):
        opt.zero_grad()
        loss = train_func(data, label)
        opt.step()


def start_workers(worker, world_size, trace=False):
    def run_subproc(rank):
        cmd = "from test.integration.test_distributed import worker\n"
        cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
            world_size, rank, rank, "True" if trace else "False"
        )
        cmd = ["python3", "-c", cmd]
        ret = subprocess.run(
            cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
        )
        assert ret.returncode == 0, "subprocess failed"

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=run_subproc, args=(rank,))
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
        assert p.exitcode == 0


def test_distributed():
    start_workers(worker, 2, trace=True)
    start_workers(worker, 2, trace=False)