提交 3bde5122 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Added infix dot product to TensorVariable

上级 0427130d
...@@ -642,6 +642,8 @@ class _tensor_py_operators: ...@@ -642,6 +642,8 @@ class _tensor_py_operators:
return aet.math.dense_dot(left, right) return aet.math.dense_dot(left, right)
dot = __dot__ dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See `aesara.tensor.math.sum`.""" """See `aesara.tensor.math.sum`."""
......
...@@ -4,9 +4,11 @@ from numpy.testing import assert_equal, assert_string_equal ...@@ -4,9 +4,11 @@ from numpy.testing import assert_equal, assert_string_equal
import aesara import aesara
import tests.unittest_tools as utt import tests.unittest_tools as utt
from aesara.graph.basic import equal_computations
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import dot
from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor
from aesara.tensor.type import TensorType, dmatrix, iscalar, ivector, matrix from aesara.tensor.type import TensorType, dmatrix, dvector, iscalar, ivector, matrix
from aesara.tensor.type_other import MakeSlice from aesara.tensor.type_other import MakeSlice
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -49,6 +51,20 @@ def test_numpy_method(fct): ...@@ -49,6 +51,20 @@ def test_numpy_method(fct):
utt.assert_allclose(np.nan_to_num(f(data)), np.nan_to_num(fct(data))) utt.assert_allclose(np.nan_to_num(f(data)), np.nan_to_num(fct(data)))
def test_infix_dot_method():
X = dmatrix("X")
y = dvector("y")
res = X @ y
exp_res = X.dot(y)
assert equal_computations([res], [exp_res])
X_val = np.arange(2 * 3).reshape((2, 3))
res = X_val @ y
exp_res = dot(X_val, y)
assert equal_computations([res], [exp_res])
def test_empty_list_indexing(): def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []] ynp = np.zeros((2, 2))[:, []]
znp = np.zeros((2, 2))[:, ()] znp = np.zeros((2, 2))[:, ()]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论