提交 b4fff097 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Test wasn't actually covering local_0_dot_x rewrite

上级 2fc5cc4e
...@@ -147,11 +147,11 @@ def local_0_dot_x(fgraph, node): ...@@ -147,11 +147,11 @@ def local_0_dot_x(fgraph, node):
x, y = node.inputs x, y = node.inputs
if ( if (
get_underlying_scalar_constant_value( get_underlying_scalar_constant_value(
x, only_process_constants=True, raise_not_constant=False x, only_process_constants=False, raise_not_constant=False
) )
== 0 == 0
or get_underlying_scalar_constant_value( or get_underlying_scalar_constant_value(
y, only_process_constants=True, raise_not_constant=False y, only_process_constants=False, raise_not_constant=False
) )
== 0 == 0
): ):
......
...@@ -1448,11 +1448,14 @@ class TestSubtensorAllocRewrites: ...@@ -1448,11 +1448,14 @@ class TestSubtensorAllocRewrites:
not isinstance(n.op, Dot) for n in f.maker.fgraph.toposort() not isinstance(n.op, Dot) for n in f.maker.fgraph.toposort()
) )
# test that we don't remove shape errors # test that we don't remove shape errors if we exclude shape_unsafe
f_safe = f = function(
[_e1[0], _e2[0]], o, mode=self.mode.excluding("shape_unsafe")
)
with pytest.raises((ValueError, AssertionError)): with pytest.raises((ValueError, AssertionError)):
f(_e1[1], _e2[2]) f_safe(_e1[1], _e2[2])
with pytest.raises((ValueError, AssertionError)): with pytest.raises((ValueError, AssertionError)):
f(_e1[2], _e2[1]) f_safe(_e1[2], _e2[1])
def test_local_IncSubtensor_serialize(): def test_local_IncSubtensor_serialize():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论