提交 ce7990b1 编写于 作者: R Reed

Fix import issue in training.py.

The issue was I used a module which wasn't imported in 2.2. The import line was only added after the 2.2 branch cut.
上级 908cb44d
......@@ -1367,7 +1367,7 @@ class TrainingTest(keras_parameterized.TestCase):
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
def test_gradients_are_none(self):
class DenseWithExtraWeight(layers_module.Dense):
class DenseWithExtraWeight(keras.layers.Dense):
def build(self, input_shape):
# Gradients w.r.t. extra_weights are None
......@@ -1377,9 +1377,9 @@ class TrainingTest(keras_parameterized.TestCase):
self.extra_weight_2 = self.add_weight('extra_weight_2', shape=(),
initializer='ones')
model = sequential.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
model = keras.models.Sequential([DenseWithExtraWeight(4, input_shape=(4,))])
# Test clipping can handle None gradients
opt = optimizer_v2.adam.Adam(clipnorm=1.0, clipvalue=1.0)
opt = keras.optimizer_v2.adam.Adam(clipnorm=1.0, clipvalue=1.0)
model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
inputs = np.random.normal(size=(64, 4))
targets = np.random.normal(size=(64, 4))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册