提交 8741f7c6 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement vectorize for xtensor Dot

上级 3933c47c
...@@ -566,6 +566,9 @@ class Dot(XOp): ...@@ -566,6 +566,9 @@ class Dot(XOp):
out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims) out = xtensor(dtype=out_dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, y], [out]) return Apply(self, [x, y], [out])
def vectorize_node(self, node, *new_inputs, new_dim):
return self(*new_inputs, return_list=True)
def dot(x, y, dim: str | Sequence[str] | EllipsisType | None = None): def dot(x, y, dim: str | Sequence[str] | EllipsisType | None = None):
"""Generalized dot product for XTensorVariables. """Generalized dot product for XTensorVariables.
......
...@@ -353,3 +353,13 @@ def test_xelemwise_vectorize(): ...@@ -353,3 +353,13 @@ def test_xelemwise_vectorize():
check_vectorization([ab], [exp(ab)]) check_vectorization([ab], [exp(ab)])
check_vectorization([ab, bc], [ab + bc]) check_vectorization([ab, bc], [ab + bc])
def test_dot_vectorize():
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
check_vectorization([x, y], [x.dot(y)])
check_vectorization([x, y], [x.dot(y, dim=("a", "b"))])
check_vectorization([x, y], [x.dot(y, dim="c")])
check_vectorization([x, y], [x.dot(y, dim=...)])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论