提交 c2e3dbb1 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Fixed local_useless_switch optimization for boolean conditions

上级 29032f34
...@@ -2536,7 +2536,7 @@ def local_useless_switch(fgraph, node): ...@@ -2536,7 +2536,7 @@ def local_useless_switch(fgraph, node):
cond = extract_constant(node.inputs[0], only_process_constants=True) cond = extract_constant(node.inputs[0], only_process_constants=True)
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number cond, (np.number, np.bool_)
): ):
if cond == 0: if cond == 0:
correct_out = node.inputs[2] correct_out = node.inputs[2]
......
...@@ -2325,42 +2325,16 @@ class TestLocalUselessSwitch: ...@@ -2325,42 +2325,16 @@ class TestLocalUselessSwitch:
def setup_method(self): def setup_method(self):
self.mode = mode_opt.excluding("constant_folding") self.mode = mode_opt.excluding("constant_folding")
def test_const_0(self): @pytest.mark.parametrize(
for dtype1 in ["int32", "int64"]: "cond",
for dtype2 in ["int32", "int64"]: [0, 1, np.array([True])],
x = matrix("x", dtype=dtype1) )
y = matrix("y", dtype=dtype2) def test_const(self, cond):
z = aet.switch(0, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, aes.basic.Switch)
)
]
)
== 0
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(0, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
res_non_bool_np = np.where(np.ones(10), 0, 1)
non_bool_graph = aet.switch(np.ones(10), 0, 1)
non_bool_fn = function([], non_bool_graph, mode=self.mode)
assert np.array_equal(non_bool_fn(), res_non_bool_np)
def test_const_1(self):
for dtype1 in ["int32", "int64"]: for dtype1 in ["int32", "int64"]:
for dtype2 in ["int32", "int64"]: for dtype2 in ["int32", "int64"]:
x = matrix("x", dtype=dtype1) x = matrix("x", dtype=dtype1)
y = matrix("y", dtype=dtype2) y = matrix("y", dtype=dtype2)
z = aet.switch(1, x, y) z = aet.switch(cond, x, y)
f = function([x, y], z, mode=self.mode) f = function([x, y], z, mode=self.mode)
assert ( assert (
len( len(
...@@ -2377,7 +2351,7 @@ class TestLocalUselessSwitch: ...@@ -2377,7 +2351,7 @@ class TestLocalUselessSwitch:
) )
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
np_res = np.where(1, vx, vy) np_res = np.where(cond, vx, vy)
assert np.array_equal(f(vx, vy), np_res) assert np.array_equal(f(vx, vy), np_res)
def test_left_is_right(self): def test_left_is_right(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论