未验证 提交 4d95c8c7 编写于 作者: F Feiyu Chan 提交者: GitHub

avoid polluting logging's root logger (#32673)

avoid polluting logging's root logger
上级 109fdf14
......@@ -29,9 +29,12 @@ from paddle.fluid.framework import Program, Variable, name_scope, default_main_p
from paddle.fluid import layers
import logging
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
from functools import reduce
__all__ = ["ShardingOptimizer"]
......@@ -136,7 +139,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# FIXME (JZ-LIANG) deprecated hybrid_dp
if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
logging.warning(
logger.warning(
"[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
)
assert self.dp_degree >= 1
......@@ -174,7 +177,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
'accumulate_steps']
if self._gradient_merge_acc_step > 1:
logging.info("Gradient merge in [{}], acc step = [{}]".format(
logger.info("Gradient merge in [{}], acc step = [{}]".format(
self.gradient_merge_mode, self._gradient_merge_acc_step))
# optimize offload
......@@ -338,7 +341,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
# sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
if self.optimize_offload:
logging.info("Sharding with optimize offload !")
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
......@@ -641,15 +644,15 @@ class ShardingOptimizer(MetaOptimizerBase):
for varname in sorted(
var2broadcast_time, key=var2broadcast_time.get,
reverse=True):
logging.info("Sharding broadcast: [{}] times [{}]".format(
logger.info("Sharding broadcast: [{}] times [{}]".format(
var2broadcast_time[varname], varname))
for idx_ in range(len(self._segments)):
logging.info("segment [{}] :".format(idx_))
logging.info("start op: [{}] [{}]".format(block.ops[
logger.info("segment [{}] :".format(idx_))
logger.info("start op: [{}] [{}]".format(block.ops[
self._segments[idx_]._start_idx].desc.type(), block.ops[
self._segments[idx_]._start_idx].desc.input_arg_names(
)))
logging.info("end op: [{}] [{}]".format(block.ops[
logger.info("end op: [{}] [{}]".format(block.ops[
self._segments[idx_]._end_idx].desc.type(), block.ops[
self._segments[idx_]._end_idx].desc.input_arg_names()))
return
......@@ -1108,7 +1111,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_group_endpoints.append(self.global_endpoints[
dp_first_rank_idx + dp_offset * i])
assert self.current_endpoint in self.dp_group_endpoints
logging.info("Hybrid DP mode turn on !")
logger.info("Hybrid DP mode turn on !")
else:
self.dp_ring_id = -1
self.dp_rank = -1
......@@ -1119,40 +1122,40 @@ class ShardingOptimizer(MetaOptimizerBase):
# NOTE (JZ-LIANG) when use global ring for calc global norm and dp_degree > 1, the allreduce result should be devided by dp_degree
self.global_ring_id = 3
logging.info("global word size: {}".format(self.global_word_size))
logging.info("global rank: {}".format(self.global_rank))
logging.info("global endpoints: {}".format(self.global_endpoints))
logging.info("global ring id: {}".format(self.global_ring_id))
logging.info("#####" * 6)
logging.info("mp group size: {}".format(self.mp_degree))
logging.info("mp rank: {}".format(self.mp_rank))
logging.info("mp group id: {}".format(self.mp_group_id))
logging.info("mp group endpoints: {}".format(self.mp_group_endpoints))
logging.info("mp ring id: {}".format(self.mp_ring_id))
logging.info("#####" * 6)
logging.info("sharding group size: {}".format(self.sharding_degree))
logging.info("sharding rank: {}".format(self.sharding_rank))
logging.info("sharding group id: {}".format(self.sharding_group_id))
logging.info("sharding group endpoints: {}".format(
logger.info("global word size: {}".format(self.global_word_size))
logger.info("global rank: {}".format(self.global_rank))
logger.info("global endpoints: {}".format(self.global_endpoints))
logger.info("global ring id: {}".format(self.global_ring_id))
logger.info("#####" * 6)
logger.info("mp group size: {}".format(self.mp_degree))
logger.info("mp rank: {}".format(self.mp_rank))
logger.info("mp group id: {}".format(self.mp_group_id))
logger.info("mp group endpoints: {}".format(self.mp_group_endpoints))
logger.info("mp ring id: {}".format(self.mp_ring_id))
logger.info("#####" * 6)
logger.info("sharding group size: {}".format(self.sharding_degree))
logger.info("sharding rank: {}".format(self.sharding_rank))
logger.info("sharding group id: {}".format(self.sharding_group_id))
logger.info("sharding group endpoints: {}".format(
self.sharding_group_endpoints))
logging.info("sharding ring id: {}".format(self.sharding_ring_id))
logging.info("#####" * 6)
logging.info("pp group size: {}".format(self.pp_degree))
logging.info("pp rank: {}".format(self.pp_rank))
logging.info("pp group id: {}".format(self.pp_group_id))
logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
logging.info("pp ring id: {}".format(self.pp_ring_id))
logging.info("#####" * 6)
logging.info("pure dp group size: {}".format(self.dp_degree))
logging.info("pure dp rank: {}".format(self.dp_rank))
logging.info("pure dp group endpoints: {}".format(
logger.info("sharding ring id: {}".format(self.sharding_ring_id))
logger.info("#####" * 6)
logger.info("pp group size: {}".format(self.pp_degree))
logger.info("pp rank: {}".format(self.pp_rank))
logger.info("pp group id: {}".format(self.pp_group_id))
logger.info("pp group endpoints: {}".format(self.pp_group_endpoints))
logger.info("pp ring id: {}".format(self.pp_ring_id))
logger.info("#####" * 6)
logger.info("pure dp group size: {}".format(self.dp_degree))
logger.info("pure dp rank: {}".format(self.dp_rank))
logger.info("pure dp group endpoints: {}".format(
self.dp_group_endpoints))
logging.info("pure dp ring id: {}".format(self.dp_ring_id))
logging.info("#####" * 6)
logger.info("pure dp ring id: {}".format(self.dp_ring_id))
logger.info("#####" * 6)
return
......
......@@ -19,9 +19,12 @@ from paddle.fluid import framework
import contextlib
import logging
logging.basicConfig(
format='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
def detach_variable(inputs):
......@@ -40,7 +43,7 @@ def detach_variable(inputs):
def check_recompute_necessary(inputs):
if not any(input_.stop_gradient == False for input_ in inputs
if isinstance(input_, paddle.Tensor)):
logging.warn(
logger.warn(
"[Recompute]: None of the inputs to current recompute block need grad, "
"therefore there is NO need to recompute this block in backward !")
......
......@@ -34,9 +34,12 @@ __all__ = [
"graphviz"
]
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
persistable_vars_out_fn = "vars_persistable.log"
all_vars_out_fn = "vars_all.log"
......
......@@ -32,9 +32,12 @@ from ...fluid import core
from ...fluid.framework import OpProtoHolder
from ...sysconfig import get_include, get_lib
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger("utils.cpp_extension")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
OS_NAME = sys.platform
IS_WINDOWS = OS_NAME.startswith('win')
......@@ -1125,4 +1128,4 @@ def log_v(info, verbose=True):
Print log information on stdout.
"""
if verbose:
logging.info(info)
logger.info(info)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册