提交 58fb8501 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix TensorVariable __rmatmul__

上级 326cb2e3
......@@ -652,7 +652,7 @@ class _tensor_py_operators:
return at.math.matmul(left, right)
def __rmatmul__(right, left):
return at.math.matmul(right, left)
return at.math.matmul(left, right)
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`."""
......
......@@ -10,7 +10,7 @@ from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
......@@ -98,10 +98,15 @@ def test_infix_matmul_method():
assert equal_computations([res], [exp_res])
X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val) @ y
res = X_val @ y
exp_res = matmul(X_val, y)
assert equal_computations([res], [exp_res])
y_val = np.arange(3)
res = X @ y_val
exp_res = matmul(X, y_val)
assert equal_computations([res], [exp_res])
def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论