提交 58fb8501 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix TensorVariable __rmatmul__

上级 326cb2e3
...@@ -652,7 +652,7 @@ class _tensor_py_operators: ...@@ -652,7 +652,7 @@ class _tensor_py_operators:
return at.math.matmul(left, right) return at.math.matmul(left, right)
def __rmatmul__(right, left): def __rmatmul__(right, left):
return at.math.matmul(right, left) return at.math.matmul(left, right)
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`.""" """See :func:`pytensor.tensor.math.sum`."""
......
...@@ -10,7 +10,7 @@ from pytensor.compile import DeepCopyOp ...@@ -10,7 +10,7 @@ from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor, constant from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape from pytensor.tensor.shape import Shape
...@@ -98,10 +98,15 @@ def test_infix_matmul_method(): ...@@ -98,10 +98,15 @@ def test_infix_matmul_method():
assert equal_computations([res], [exp_res]) assert equal_computations([res], [exp_res])
X_val = np.arange(2 * 3).reshape((2, 3)) X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val) @ y res = X_val @ y
exp_res = matmul(X_val, y) exp_res = matmul(X_val, y)
assert equal_computations([res], [exp_res]) assert equal_computations([res], [exp_res])
y_val = np.arange(3)
res = X @ y_val
exp_res = matmul(X, y_val)
assert equal_computations([res], [exp_res])
def test_empty_list_indexing(): def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []] ynp = np.zeros((2, 2))[:, []]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论