提交 288cfd44 编写于 作者: R rusty1s

add bipartite flag

上级 c493caaf
......@@ -46,7 +46,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx) {
torch::Tensor idx, bool bipartite) {
CHECK_CPU(rowptr);
CHECK_CPU(col);
......@@ -131,9 +131,10 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
}
}
out_rowptr =
torch::cat({out_rowptr, torch::full({(int64_t)n_ids.size()},
out_col.numel(), rowptr.options())});
if (bipartite)
out_rowptr = torch::cat(
{out_rowptr, torch::full({(int64_t)n_ids.size()}, out_col.numel(),
rowptr.options())});
idx = torch::cat({idx, torch::from_blob(n_ids.data(), {(int64_t)n_ids.size()},
idx.options())});
......
......@@ -9,4 +9,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx);
torch::Tensor idx, bool bipartite);
......@@ -24,7 +24,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch::Tensor>
relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value,
torch::Tensor idx) {
torch::Tensor idx, bool bipartite) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
AT_ERROR("No CUDA version supported");
......@@ -32,7 +32,7 @@ relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return relabel_one_hop_cpu(rowptr, col, optional_value, idx);
return relabel_one_hop_cpu(rowptr, col, optional_value, idx, bipartite);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册