diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 1bf08197e2dc9975c81cd2faff5a5f3d845b1b5b..f26b797db6425f487d7b189b9ef76893bb6d3a1b 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -315,6 +315,31 @@ class Metric(base_layer.Layer): ### End: For use by subclasses ### + @property + def trainable_weights(self): + # Overridden from Layer class to track submetric weights. + if self.trainable: + trainable_weights = self._trainable_weights + for m in self._metrics: + trainable_weights += m.trainable_weights + return self._dedup_weights(trainable_weights) + else: + return [] + + @property + def non_trainable_weights(self): + # Overridden from Layer class to track submetric weights. + if self.trainable: + non_trainable_weights = self._non_trainable_weights + for m in self._metrics: + non_trainable_weights += m.non_trainable_weights + else: + non_trainable_weights = ( + self._non_trainable_weights + self._trainable_weights) + for m in self._metrics: + non_trainable_weights += m.weights + return self._dedup_weights(non_trainable_weights) + @property def _trackable_saved_model_saver(self): return metric_serialization.MetricSavedModelSaver(self)