未验证 提交 25ee526f 编写于 作者: M minghaoBD 提交者: GitHub

Fix keyerror (#794)

* fix key_error in pruner

* add unit test for get_ratios_by_loss
上级 6947a1e5
......@@ -209,4 +209,5 @@ def get_ratios_by_loss(sensitivities, loss):
_logger.info(losses, ratio, (r1 - r0) / (l1 - l0), i)
break
if i == 0: ratios[param] = 0.0
return ratios
......@@ -18,7 +18,7 @@ import numpy
import paddle
import paddle.fluid as fluid
from static_case import StaticCase
from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities
from paddleslim.prune import sensitivity, merge_sensitive, load_sensitivities, get_ratios_by_loss
from layers import conv_bn_layer
......@@ -107,9 +107,16 @@ class TestSensitivity(StaticCase):
sensitivities_file="./sensitivities_file_2",
pruned_ratios=[0.1, 0.2, 0.3, 0.4])
self.assertTrue(params_sens == origin_sens)
self.assertTrue(sens == origin_sens)
loss = 0.0
ratios = get_ratios_by_loss(sens, loss)
self.assertTrue(len(ratios) == len(sens))
loss = min(list(sens.get('conv4_weights').values())) - 0.01
ratios = get_ratios_by_loss(sens, loss)
self.assertTrue(len(ratios) == len(sens))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册