From b2c3e950580cec60df1fae4382b20624f8977e4f Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Sun, 1 Nov 2020 10:23:14 -0800 Subject: [PATCH] Add back support for tracking variables from submetrics of a Metric. PiperOrigin-RevId: 340120524 Change-Id: I76d5880c95aafe7a6a9f64b5be1dc23af464e988 --- tensorflow/python/keras/metrics.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index 1bf08197e2d..f26b797db64 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -315,6 +315,31 @@ class Metric(base_layer.Layer): ### End: For use by subclasses ### + @property + def trainable_weights(self): + # Overridden from Layer class to track submetric weights. + if self.trainable: + trainable_weights = self._trainable_weights + for m in self._metrics: + trainable_weights += m.trainable_weights + return self._dedup_weights(trainable_weights) + else: + return [] + + @property + def non_trainable_weights(self): + # Overridden from Layer class to track submetric weights. + if self.trainable: + non_trainable_weights = self._non_trainable_weights + for m in self._metrics: + non_trainable_weights += m.non_trainable_weights + else: + non_trainable_weights = ( + self._non_trainable_weights + self._trainable_weights) + for m in self._metrics: + non_trainable_weights += m.weights + return self._dedup_weights(non_trainable_weights) + @property def _trackable_saved_model_saver(self): return metric_serialization.MetricSavedModelSaver(self) -- GitLab