未验证 提交 04997b00 编写于 作者: M Matthias Fey 提交者: GitHub

Merge pull request #102 from shi27feng/master

modify spmm to support multi-dimensional tensor
......@@ -15,13 +15,13 @@ def spmm(index, value, m, n, matrix):
:rtype: :class:`Tensor`
"""
assert n == matrix.size(0)
assert n == matrix.size(-2)
row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
out = matrix[col]
out = matrix.index_select(-2, col)
out = out * value.unsqueeze(-1)
out = scatter_add(out, row, dim=0, dim_size=m)
out = scatter_add(out, row, dim=-2, dim_size=m)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册