提交 b2c3e950 编写于 作者: T Thomas O'Malley 提交者: TensorFlower Gardener

Add back support for tracking variables from submetrics of a Metric.

PiperOrigin-RevId: 340120524
Change-Id: I76d5880c95aafe7a6a9f64b5be1dc23af464e988
上级 898ae24a
......@@ -315,6 +315,31 @@ class Metric(base_layer.Layer):
### End: For use by subclasses ###
@property
def trainable_weights(self):
# Overridden from Layer class to track submetric weights.
if self.trainable:
trainable_weights = self._trainable_weights
for m in self._metrics:
trainable_weights += m.trainable_weights
return self._dedup_weights(trainable_weights)
else:
return []
@property
def non_trainable_weights(self):
# Overridden from Layer class to track submetric weights.
if self.trainable:
non_trainable_weights = self._non_trainable_weights
for m in self._metrics:
non_trainable_weights += m.non_trainable_weights
else:
non_trainable_weights = (
self._non_trainable_weights + self._trainable_weights)
for m in self._metrics:
non_trainable_weights += m.weights
return self._dedup_weights(non_trainable_weights)
@property
def _trackable_saved_model_saver(self):
return metric_serialization.MetricSavedModelSaver(self)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册