diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py index 33a9934b307c253085b77c608de0040b96d8a9a2..dcddd6c3f0ba38c2049ccd520ffdea9f7cb79819 100644 --- a/paddleslim/prune/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -209,4 +209,5 @@ def get_ratios_by_loss(sensitivities, loss): _logger.info(losses, ratio, (r1 - r0) / (l1 - l0), i) break + if i == 0: ratios[param] = 0.0 return ratios diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index c3acf02210098d3a90ad73bcb470b8e91f052cd2..94857ee19a7c9bac857a3e250840365c9771b942 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -18,7 +18,7 @@ import numpy import paddle import paddle.fluid as fluid from static_case import StaticCase -from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities +from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities, get_ratios_by_loss from layers import conv_bn_layer @@ -107,9 +107,16 @@ class TestSensitivity(StaticCase): sensitivities_file="./sensitivities_file_2", pruned_ratios=[0.1, 0.2, 0.3, 0.4]) self.assertTrue(params_sens == origin_sens) - self.assertTrue(sens == origin_sens) + loss = 0.0 + ratios = get_ratios_by_loss(sens, loss) + self.assertTrue(len(ratios) == len(sens)) + + loss = min(list(sens.get('conv4_weights').values())) - 0.01 + ratios = get_ratios_by_loss(sens, loss) + self.assertTrue(len(ratios) == len(sens)) + if __name__ == '__main__': unittest.main()