提交 4a99b673 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Elemwise dimension checks for Python implementations

上级 c6dbdccb
......@@ -736,30 +736,9 @@ second dimension
# should be disabled.
super().perform(node, inputs, output_storage)
for dims in zip(
*[
list(zip(input.shape, sinput.type.broadcastable))
for input, sinput in zip(inputs, node.inputs)
]
):
if max(d for d, b in dims) != 1 and (1, False) in dims:
# yes there may be more compact ways to write this code,
# but please maintain python 2.4 compatibility
# (no "x if c else y")
msg = []
assert len(inputs) == len(node.inputs)
for input, sinput in zip(inputs, node.inputs):
assert len(input.shape) == len(sinput.type.broadcastable)
msg2 = []
for d, b in zip(input.shape, sinput.type.broadcastable):
if b:
msg2 += ["*"]
else:
msg2 += [str(d)]
msg.append(f"({', '.join(msg2)})")
base_exc_str = f"Dimension mismatch; shapes are {', '.join(msg)}"
raise ValueError(base_exc_str)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
# Determine the shape of outputs
out_shape = []
......
......@@ -57,24 +57,27 @@ class TestCallbacks:
assert self.n_callbacks["IfElse"] == 2
def test_c_thunks():
a = scalars("a")
b, c = vectors("bc")
def test_use_c_thunks():
a_at = scalars("a")
b_at = vectors("b")
a = np.array(0.0).astype(config.floatX)
b = np.array([2.0]).astype(config.floatX)
cases = [False]
if config.cxx:
cases.append(True)
for c_thunks in cases:
for use_c_thunks in cases:
f = function(
[a, b, c],
ifelse(a, a * b, b * c),
[a_at, b_at],
a_at * b_at,
mode=Mode(
optimizer=None, linker=VMLinker(c_thunks=c_thunks, use_cloop=False)
optimizer=None, linker=VMLinker(c_thunks=use_c_thunks, use_cloop=False)
),
)
f(1, [2], [3, 2])
with pytest.raises(ValueError):
f(0, [2], [3, 4])
assert any([hasattr(t, "cthunk") for t in f.fn.thunks]) == c_thunks
assert np.array_equal(a * b, f(a, b))
assert any([hasattr(t, "cthunk") for t in f.fn.thunks]) == use_c_thunks
@pytest.mark.skipif(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论