From 25ee526fd843f72df0095348be3c0d8eeae3d313 Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Wed, 9 Jun 2021 13:10:42 +0800 Subject: [PATCH] Fix keyerror (#794) * fix key_error in pruner * add unit test for get_ratios_by_loss --- paddleslim/prune/sensitive.py | 1 + tests/test_sensitivity.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/paddleslim/prune/sensitive.py b/paddleslim/prune/sensitive.py index 33a9934b..dcddd6c3 100644 --- a/paddleslim/prune/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -209,4 +209,5 @@ def get_ratios_by_loss(sensitivities, loss): _logger.info(losses, ratio, (r1 - r0) / (l1 - l0), i) break + if i == 0: ratios[param] = 0.0 return ratios diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py index c3acf022..94857ee1 100644 --- a/tests/test_sensitivity.py +++ b/tests/test_sensitivity.py @@ -18,7 +18,7 @@ import numpy import paddle import paddle.fluid as fluid from static_case import StaticCase -from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities +from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities, get_ratios_by_loss from layers import conv_bn_layer @@ -107,9 +107,16 @@ class TestSensitivity(StaticCase): sensitivities_file="./sensitivities_file_2", pruned_ratios=[0.1, 0.2, 0.3, 0.4]) self.assertTrue(params_sens == origin_sens) - self.assertTrue(sens == origin_sens) + loss = 0.0 + ratios = get_ratios_by_loss(sens, loss) + self.assertTrue(len(ratios) == len(sens)) + + loss = min(list(sens.get('conv4_weights').values())) - 0.01 + ratios = get_ratios_by_loss(sens, loss) + self.assertTrue(len(ratios) == len(sens)) + if __name__ == '__main__': unittest.main() -- GitLab