提交 b4fff097 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Test wasn't actually covering local_0_dot_x rewrite

上级 2fc5cc4e
......@@ -147,11 +147,11 @@ def local_0_dot_x(fgraph, node):
x, y = node.inputs
if (
get_underlying_scalar_constant_value(
x, only_process_constants=True, raise_not_constant=False
x, only_process_constants=False, raise_not_constant=False
)
== 0
or get_underlying_scalar_constant_value(
y, only_process_constants=True, raise_not_constant=False
y, only_process_constants=False, raise_not_constant=False
)
== 0
):
......
......@@ -1448,11 +1448,14 @@ class TestSubtensorAllocRewrites:
not isinstance(n.op, Dot) for n in f.maker.fgraph.toposort()
)
# test that we don't remove shape errors
# test that we don't remove shape errors if we exclude shape_unsafe
f_safe = f = function(
[_e1[0], _e2[0]], o, mode=self.mode.excluding("shape_unsafe")
)
with pytest.raises((ValueError, AssertionError)):
f(_e1[1], _e2[2])
f_safe(_e1[1], _e2[2])
with pytest.raises((ValueError, AssertionError)):
f(_e1[2], _e2[1])
f_safe(_e1[2], _e2[1])
def test_local_IncSubtensor_serialize():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论