Commit 8d837c78 authored by rusty1s

relabel one hop

Parent f0609836
@@ -41,3 +41,102 @@ std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
return std::make_tuple(out_col, out_idx);
}
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
if (optional_value.has_value()) {
CHECK_CPU(optional_value.value());
CHECK_INPUT(optional_value.value().dim() == 1);
}
CHECK_CPU(idx);
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>();
std::vector<int64_t> n_ids;
std::unordered_map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator it;
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0;
int64_t v, w, c, row_start, row_end, offset = 0;
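  // Pass 1: give each seed node a local id and accumulate its out-degree so
  // that out_rowptr describes the sampled subgraph.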
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
    n_id_map[v] = i; // map the global node id to its new local index
offset += rowptr_data[v + 1] - rowptr_data[v];
out_rowptr_data[i + 1] = offset;
}
auto out_col = torch::empty(offset, col.options());
auto out_col_data = out_col.data_ptr<int64_t>();
torch::optional<torch::Tensor> out_value = torch::nullopt;
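  // Pass 2: rewrite column indices into the new id space; a neighbor seen for
  // the first time gets the next free id after the seeds. The loop appears
  // twice so the value copy can be dispatched on the value scalar type.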
if (optional_value.has_value()) {
out_value = torch::empty(offset, optional_value.value().options());
AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
auto value_data = optional_value.value().data_ptr<scalar_t>();
auto out_value_data = out_value.value().data_ptr<scalar_t>();
offset = 0;
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
it = n_id_map.find(w);
if (it == n_id_map.end()) {
c = idx.numel() + n_ids.size();
n_id_map[w] = c;
n_ids.push_back(w);
out_col_data[offset] = c;
} else {
out_col_data[offset] = it->second;
}
out_value_data[offset] = value_data[j];
offset++;
}
}
});
} else {
offset = 0;
for (int64_t i = 0; i < idx.numel(); i++) {
v = idx_data[i];
row_start = rowptr_data[v], row_end = rowptr_data[v + 1];
for (int64_t j = row_start; j < row_end; j++) {
w = col_data[j];
it = n_id_map.find(w);
if (it == n_id_map.end()) {
c = idx.numel() + n_ids.size();
n_id_map[w] = c;
n_ids.push_back(w);
out_col_data[offset] = c;
} else {
out_col_data[offset] = it->second;
}
offset++;
}
}
}
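  // Newly discovered nodes have no outgoing edges in the sampled subgraph, so
  // their appended rowptr entries all point to the end of out_col.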
out_rowptr =
torch::cat({out_rowptr, torch::full({(int64_t)n_ids.size()},
out_col.numel(), rowptr.options())});
  // from_blob wraps the local n_ids buffer without copying; passing
  // idx.options() keeps the int64 dtype, and cat copies the data before the
  // vector goes out of scope.
  idx = torch::cat({idx, torch::from_blob(n_ids.data(), {(int64_t)n_ids.size()},
                                          idx.options())});
return std::make_tuple(out_rowptr, out_col, out_value, idx);
}
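For intuition, here is a minimal usage sketch (a hypothetical driver, not part of this commit). It assumes the file above is compiled against libtorch and that the declarations from the header hunk below are reachable as relabel_cpu.h. On a 4-node CSR graph it samples the one-hop neighborhood of the seed nodes {2, 0}: seeds keep local ids 0..idx.numel()-1, newly discovered neighbors are appended behind them, and out_col already indexes into the returned idx.

#include <iostream>
#include <torch/torch.h>
#include "relabel_cpu.h" // hypothetical include path for the header below

int main() {
  // CSR graph on 4 nodes: 0->{1,2}, 1->{0}, 2->{3}, 3->{2}.
  auto rowptr = torch::tensor({0, 2, 3, 4, 5}, torch::kLong);
  auto col = torch::tensor({1, 2, 0, 3, 2}, torch::kLong);
  auto idx = torch::tensor({2, 0}, torch::kLong); // seed nodes

  auto [out_rowptr, out_col, out_value, n_id] =
      relabel_one_hop_cpu(rowptr, col, /*optional_value=*/torch::nullopt, idx);

  std::cout << out_rowptr << std::endl; // 0 1 3 3 3 (appended rows are empty)
  std::cout << out_col << std::endl;    // 2 3 0 (relabeled column indices)
  std::cout << n_id << std::endl;       // 2 0 3 1 (seeds first, new nodes after)
  return 0;
}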
@@ -4,3 +4,9 @@
std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
torch::Tensor idx);
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx);
@@ -20,5 +20,23 @@ std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_one_hop_cpu(rowptr, col, optional_value, idx);
}
}
static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::relabel", &relabel)
        .op("torch_sparse::relabel_one_hop", &relabel_one_hop);
@@ -18,6 +18,12 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
torch::Tensor idx);
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx);
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);