提交 58ebb261 编写于 作者: M Megvii Engine Team

fix(imperative/tensor): fix ConstTensorCache

GitOrigin-RevId: 0767bcfa281dc10969c320c5717dabcb9b60c15f
上级 756c1eb7
......@@ -125,6 +125,7 @@ public:
size_t size;
BlobPtr blob;
Entry() = default;
Entry(const dt_byte* ptr, size_t size_, BlobPtr blob_)
: data(new dt_byte[size_]), size(size_), blob(blob_) {
memcpy(data.get(), ptr, size);
......@@ -136,6 +137,8 @@ public:
}
};
using KV = std::pair<uint64_t, Entry>;
bool check(const HostTensorND& hv) {
auto&& layout = hv.layout();
auto&& span = layout.span();
......@@ -190,7 +193,7 @@ public:
}
std::mutex mtx;
size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
const size_t hwm = 1024, lwm = 512, max_bytes = TensorShape::MAX_NDIM * 8, window = 65536;
private:
void maybe_collect_g0() {
......@@ -200,25 +203,37 @@ private:
}
}
void maybe_collect_g1() {
if (g1.size() <= hwm) return;
if (g1.size() < hwm) return;
using KV = std::pair<uint64_t, Entry>;
std::vector<KV> tmp;
tmp.reserve(g1.size());
tmp.clear();
for (auto&& kv : g1) {
tmp.emplace_back(kv.first, std::move(kv.second));
}
std::nth_element(tmp.begin(), tmp.begin() + lwm, tmp.end(), [](const KV& lhs, const KV& rhs) {
return lhs.second.hitcnt > rhs.second.hitcnt;
});
tmp.resize(lwm);
g1.clear();
for (auto&& kv : tmp) {
kv.second.hitcnt = 0;
g1.emplace(std::move(kv));
}
}
// g0: records blobs which have been seen at least once (within a window)
// g0b: backup of g0
// g1: records the most frequently used blobs which have been seen at least
// twice. When `g1.size() == hwm`, it will be refreshed and only the top
// `lhw` frequently used blobs will be kept.
std::unordered_set<uint64_t> g0, g0b;
std::unordered_map<uint64_t, Entry> g1;
std::vector<KV> tmp;
public:
ConstTensorCache() {
g0.reserve(window), g0b.reserve(window);
g1.reserve(hwm), tmp.reserve(hwm);
}
};
struct MultiCNConstTensorCache : CompNodeDepedentObject {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册