提交 7579a8cd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Elemwise tests for unknown/dynamic broadcasting case

上级 66d34ebe
...@@ -25,6 +25,7 @@ from aesara.tensor.type import ( ...@@ -25,6 +25,7 @@ from aesara.tensor.type import (
discrete_dtypes, discrete_dtypes,
matrix, matrix,
scalar, scalar,
vector,
vectors, vectors,
) )
from tests import unittest_tools from tests import unittest_tools
...@@ -644,6 +645,36 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -644,6 +645,36 @@ class TestElemwise(unittest_tools.InferShapeTester):
g = aesara.function([a, b, c, d, e, f], s, mode=Mode(linker="py")) g = aesara.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
g(*[np.zeros(2 ** 11, config.floatX) for i in range(6)]) g(*[np.zeros(2 ** 11, config.floatX) for i in range(6)])
def check_input_dimensions_match(self, mode):
"""Make sure that our input validation works correctly and doesn't
throw erroneous broadcast-based errors.
"""
x_v = matrix("x")
m_v = vector("m")
x = np.array([[-1.32720483], [0.23442016]]).astype(config.floatX)
m = np.array([0.0, 0.0]).astype(config.floatX)
z_v = x_v - m_v
f = aesara.function([x_v, m_v], z_v, mode=mode)
res = f(x, m)
assert np.array_equal(res, x - m)
def test_input_dimensions_match_python(self):
self.check_input_dimensions_match(Mode(linker="py"))
@pytest.mark.xfail(
reason="Elemwise C implementation does not broadcast parameters",
exception=ValueError,
)
@pytest.mark.skipif(
not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_input_dimensions_match_c(self):
self.check_input_dimensions_match(Mode(linker="c"))
def test_not_implemented_elemwise_grad(): def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op. # Regression test for unimplemented gradient in an Elemwise Op.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论