提交 4a99b673 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Elemwise dimension checks for Python implementations

上级 c6dbdccb
...@@ -736,30 +736,9 @@ second dimension ...@@ -736,30 +736,9 @@ second dimension
# should be disabled. # should be disabled.
super().perform(node, inputs, output_storage) super().perform(node, inputs, output_storage)
for dims in zip( for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
*[ if len(set(dim_shapes) - {1}) > 1:
list(zip(input.shape, sinput.type.broadcastable)) raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
for input, sinput in zip(inputs, node.inputs)
]
):
if max(d for d, b in dims) != 1 and (1, False) in dims:
# yes there may be more compact ways to write this code,
# but please maintain python 2.4 compatibility
# (no "x if c else y")
msg = []
assert len(inputs) == len(node.inputs)
for input, sinput in zip(inputs, node.inputs):
assert len(input.shape) == len(sinput.type.broadcastable)
msg2 = []
for d, b in zip(input.shape, sinput.type.broadcastable):
if b:
msg2 += ["*"]
else:
msg2 += [str(d)]
msg.append(f"({', '.join(msg2)})")
base_exc_str = f"Dimension mismatch; shapes are {', '.join(msg)}"
raise ValueError(base_exc_str)
# Determine the shape of outputs # Determine the shape of outputs
out_shape = [] out_shape = []
......
...@@ -57,24 +57,27 @@ class TestCallbacks: ...@@ -57,24 +57,27 @@ class TestCallbacks:
assert self.n_callbacks["IfElse"] == 2 assert self.n_callbacks["IfElse"] == 2
def test_c_thunks(): def test_use_c_thunks():
a = scalars("a") a_at = scalars("a")
b, c = vectors("bc") b_at = vectors("b")
a = np.array(0.0).astype(config.floatX)
b = np.array([2.0]).astype(config.floatX)
cases = [False] cases = [False]
if config.cxx: if config.cxx:
cases.append(True) cases.append(True)
for c_thunks in cases:
for use_c_thunks in cases:
f = function( f = function(
[a, b, c], [a_at, b_at],
ifelse(a, a * b, b * c), a_at * b_at,
mode=Mode( mode=Mode(
optimizer=None, linker=VMLinker(c_thunks=c_thunks, use_cloop=False) optimizer=None, linker=VMLinker(c_thunks=use_c_thunks, use_cloop=False)
), ),
) )
f(1, [2], [3, 2]) assert np.array_equal(a * b, f(a, b))
with pytest.raises(ValueError): assert any([hasattr(t, "cthunk") for t in f.fn.thunks]) == use_c_thunks
f(0, [2], [3, 4])
assert any([hasattr(t, "cthunk") for t in f.fn.thunks]) == c_thunks
@pytest.mark.skipif( @pytest.mark.skipif(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论