未验证 提交 12e873b9 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11160 from typhoonzero/fix_release_013_dist_bug

Fix release 013 dist bug
......@@ -33,10 +33,18 @@ ELSE()
SET(BUILD_CMD make HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin)
ENDIF()
# FIXME(wuyi): do not build zlib cares protobuf twice, find a way to build grpc with them
ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
URL "http://paddlepaddledeps.bj.bcebos.com/grpc.tar.xz"
# NOTE(wuyi):
# this package is generated by following steps:
# 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
# 2. submodule update --init
# 3. keep only zlib, cares, protobuf, boringssl under "third_party",
# checkout and clean other dirs under third_party
# 4. remove .git, and package the directory.
URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
PREFIX ${GRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......@@ -49,7 +57,6 @@ ExternalProject_Add(
INSTALL_COMMAND make prefix=${GRPC_INSTALL_DIR} install
)
# FIXME(typhoonzero): hack to get static lib path, try a better way like merge them.
ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION
"${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a")
......
......@@ -186,12 +186,17 @@ class DistributeTranspiler:
param_list = []
grad_list = []
param_grad_set = set()
for p, g in self.params_grads:
# skip parameter marked not trainable
if type(p) == Parameter and p.trainable == False:
continue
param_list.append(p)
grad_list.append(g)
if p.name not in param_grad_set:
param_list.append(p)
param_grad_set.add(p.name)
if g.name not in param_grad_set:
grad_list.append(g)
param_grad_set.add(g.name)
self._update_dist_lookup_table_vars(param_list, grad_list,
self.params_grads)
......@@ -802,6 +807,9 @@ class DistributeTranspiler:
if not block_map.has_key(varname):
block_map[varname] = []
block_map[varname].append((long(offset), long(size)))
# Do not remove this important debug message:
print("block map: %s" % block_map)
for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname)
if len(splited) == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册