未验证 提交 78a9870f 编写于 作者: J Jiangxinz 提交者: GitHub

fix bad super call (#33533)

上级 b4f82871
...@@ -82,7 +82,7 @@ class ProbabilityEntry(EntryAttr): ...@@ -82,7 +82,7 @@ class ProbabilityEntry(EntryAttr):
""" """
def __init__(self, probability): def __init__(self, probability):
super(EntryAttr, self).__init__() super(ProbabilityEntry, self).__init__()
if not isinstance(probability, float): if not isinstance(probability, float):
raise ValueError("probability must be a float in (0,1)") raise ValueError("probability must be a float in (0,1)")
...@@ -122,7 +122,7 @@ class CountFilterEntry(EntryAttr): ...@@ -122,7 +122,7 @@ class CountFilterEntry(EntryAttr):
""" """
def __init__(self, count_filter): def __init__(self, count_filter):
super(EntryAttr, self).__init__() super(CountFilterEntry, self).__init__()
if not isinstance(count_filter, int): if not isinstance(count_filter, int):
raise ValueError( raise ValueError(
......
...@@ -40,7 +40,7 @@ class EntryAttr(object): ...@@ -40,7 +40,7 @@ class EntryAttr(object):
class ProbabilityEntry(EntryAttr): class ProbabilityEntry(EntryAttr):
def __init__(self, probability): def __init__(self, probability):
super(EntryAttr, self).__init__() super(ProbabilityEntry, self).__init__()
if not isinstance(probability, float): if not isinstance(probability, float):
raise ValueError("probability must be a float in (0,1)") raise ValueError("probability must be a float in (0,1)")
...@@ -57,7 +57,7 @@ class ProbabilityEntry(EntryAttr): ...@@ -57,7 +57,7 @@ class ProbabilityEntry(EntryAttr):
class CountFilterEntry(EntryAttr): class CountFilterEntry(EntryAttr):
def __init__(self, count_filter): def __init__(self, count_filter):
super(EntryAttr, self).__init__() super(CountFilterEntry, self).__init__()
if not isinstance(count_filter, int): if not isinstance(count_filter, int):
raise ValueError( raise ValueError(
......
...@@ -591,7 +591,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -591,7 +591,7 @@ class GeneralRoleMaker(RoleMakerBase):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(RoleMakerBase, self).__init__() super(GeneralRoleMaker, self).__init__()
self._role_is_generated = False self._role_is_generated = False
self._hdfs_name = kwargs.get("hdfs_name", "") self._hdfs_name = kwargs.get("hdfs_name", "")
self._hdfs_ugi = kwargs.get("hdfs_ugi", "") self._hdfs_ugi = kwargs.get("hdfs_ugi", "")
......
...@@ -134,7 +134,7 @@ class TestMKLDNNWithValidPad(TestConv2DTransposeMKLDNNOp): ...@@ -134,7 +134,7 @@ class TestMKLDNNWithValidPad(TestConv2DTransposeMKLDNNOp):
class TestMKLDNNWithValidPad_NHWC(TestMKLDNNWithValidPad): class TestMKLDNNWithValidPad_NHWC(TestMKLDNNWithValidPad):
def init_test_case(self): def init_test_case(self):
super(TestMKLDNNWithValidPad, self).init_test_case() super(TestMKLDNNWithValidPad_NHWC, self).init_test_case()
self.data_format = "NHWC" self.data_format = "NHWC"
N, C, H, W = self.input_size N, C, H, W = self.input_size
self.input_size = [N, H, W, C] self.input_size = [N, H, W, C]
......
...@@ -160,7 +160,7 @@ class TestBackward(unittest.TestCase): ...@@ -160,7 +160,7 @@ class TestBackward(unittest.TestCase):
class SimpleNet(BackwardNet): class SimpleNet(BackwardNet):
def __init__(self): def __init__(self):
super(BackwardNet, self).__init__() super(SimpleNet, self).__init__()
self.stop_gradient_grad_vars = set([ self.stop_gradient_grad_vars = set([
u'x_no_grad@GRAD', u'x2_no_grad@GRAD', u'x3_no_grad@GRAD', u'x_no_grad@GRAD', u'x2_no_grad@GRAD', u'x3_no_grad@GRAD',
u'label_no_grad@GRAD' u'label_no_grad@GRAD'
...@@ -330,7 +330,7 @@ class TestAppendBackwardWithError(unittest.TestCase): ...@@ -330,7 +330,7 @@ class TestAppendBackwardWithError(unittest.TestCase):
# TODO(Aurelius84): add conditional network test # TODO(Aurelius84): add conditional network test
class ConditionalNet(BackwardNet): class ConditionalNet(BackwardNet):
def __init__(self): def __init__(self):
super(BackwardNet, self).__init__() super(ConditionalNet, self).__init__()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -35,7 +35,7 @@ class TestFleetMetric(unittest.TestCase): ...@@ -35,7 +35,7 @@ class TestFleetMetric(unittest.TestCase):
class FakeUtil(UtilBase): class FakeUtil(UtilBase):
def __init__(self, fake_fleet): def __init__(self, fake_fleet):
super(UtilBase, self).__init__() super(FakeUtil, self).__init__()
self.fleet = fake_fleet self.fleet = fake_fleet
def all_reduce(self, input, mode="sum", comm_world="worker"): def all_reduce(self, input, mode="sum", comm_world="worker"):
......
...@@ -199,7 +199,7 @@ class TestCloudRoleMaker2(unittest.TestCase): ...@@ -199,7 +199,7 @@ class TestCloudRoleMaker2(unittest.TestCase):
""" """
def __init__(self): def __init__(self):
super(Fleet, self).__init__() super(TmpFleet, self).__init__()
self._role_maker = None self._role_maker = None
def init_worker(self): def init_worker(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册