提交 17aa9662 编写于 作者: C Cysu

Fix a bug in recursively sharing data/diff

上级 8c2de269
......@@ -1079,7 +1079,10 @@ void Net<Dtype>::MemoryOptimize() {
if (layers_[i]->is_sharing_data(i_top, i_bottom)) {
sharing_data = true;
const string& bottom_name = blob_names_[bottom_id_vecs_[i][i_bottom]];
idx = FindSlot(slots, bottom_name + "_data");
// The shared blob is guaranteed to be assign either an active
// slot, or an excluded slot (index == -1, for I/O blobs).
assert(slot_index.find(bottom_name + "_data") != slot_index.end());
idx = slot_index[bottom_name + "_data"];
LOG(INFO) << "top " << top_name
<< " shares data with bottom " << bottom_name
<< " slot " << idx;
......@@ -1093,11 +1096,11 @@ void Net<Dtype>::MemoryOptimize() {
LOG(INFO) << "top " << top_name << " acquires data slot " << idx;
}
} else {
slot_index[top_name + "_data"] = idx;
if (idx != -1) {
// idx == -1 means the top blob is (recursively) sharing data with an excluded bottom blob
// This makes this blob itself excluded from the optimization
slots[idx].IncRef();
slot_index[top_name + "_data"] = idx;
}
}
} else {
......@@ -1149,7 +1152,10 @@ void Net<Dtype>::MemoryOptimize() {
if(layers_[i]->is_sharing_diff(i_top, i_bottom)){
const string& top_name = blob_names_[layer_top_idx[i_top]];
sharing_diff = true;
idx = FindSlot(slots, top_name + "_diff");
// The shared blob is guaranteed to be assign either an active
// slot, or an excluded slot (index == -1, for I/O blobs).
assert(slot_index.find(top_name + "_diff") != slot_index.end());
idx = slot_index[top_name + "_diff"];
}
}
if (!sharing_diff) {
......@@ -1158,11 +1164,11 @@ void Net<Dtype>::MemoryOptimize() {
LOG(INFO) << "acquired slot for new blob";
}else{
LOG(INFO) << "sharing diff using slot "<<idx;
slot_index[bottom_name + "_diff"] = idx;
if(idx != -1) {
// idx == -1 means the bottom blob is (recursively) sharing diff with an excluded top blob
// This makes this blob itself excluded from the optimization
slots[idx].IncRef();
slot_index[bottom_name + "_diff"] = idx;
}
}
}else{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册