提交 692a6da2 编写于 作者: TensorSense's avatar TensorSense

fix index

上级 a3fe1766
......@@ -94,6 +94,8 @@ def comp_class_vec(ouput_vec, index=None):
"""
if not index:
index = np.argmax(ouput_vec.cpu().data.numpy())
else:
index = np.array(index)
index = index[np.newaxis, np.newaxis]
index = torch.from_numpy(index)
one_hot = torch.zeros(1, 10).scatter_(1, index, 1)
......@@ -128,7 +130,8 @@ def gen_cam(feature_map, grads):
if __name__ == '__main__':
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_1.png")
path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_8.png")
path_img = "/Users/tingsongyu/Desktop/t.png"
path_net = os.path.join(BASE_DIR, "../../Data/", "net_params_72p.pkl")
output_dir = os.path.join(BASE_DIR, "../../Result/backward_hook_cam/")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册