Commit b379b0f2 — authored by Brandon T. Willard, committed by Brandon T. Willard

Use broadcasted output shape in local_useless_switch optimization

Closes #270
Parent commit: 175e7843
......@@ -6163,8 +6163,7 @@ class TestLocalUselessSwitch:
def setup_method(self):
self.mode = mode_opt.excluding("constant_folding")
def test_const0(self):
def test_const_0(self):
for dtype1 in ["int32", "int64"]:
for dtype2 in ["int32", "int64"]:
x = tt.matrix("x", dtype=dtype1)
......@@ -6186,10 +6185,15 @@ class TestLocalUselessSwitch:
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
assert np.all(f(vx, vy) == vy)
np_res = np.where(0, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
def test_const1(self):
res_non_bool_np = np.where(np.ones(10), 0, 1)
non_bool_graph = tt.switch(np.ones(10), 0, 1)
non_bool_fn = function([], non_bool_graph, mode=self.mode)
assert np.array_equal(non_bool_fn(), res_non_bool_np)
def test_const_1(self):
for dtype1 in ["int32", "int64"]:
for dtype2 in ["int32", "int64"]:
x = tt.matrix("x", dtype=dtype1)
......@@ -6211,10 +6215,10 @@ class TestLocalUselessSwitch:
)
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
assert np.all(f(vx, vy) == vx)
np_res = np.where(1, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
def test_left_is_right(self):
for dtype1 in ["int32", "int64"]:
x = tt.matrix("x", dtype=dtype1)
varc = tt.matrix("varc", dtype=dtype1)
......@@ -6239,12 +6243,11 @@ class TestLocalUselessSwitch:
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert np.all(f1(vx) == vx)
assert np.all(f0(vx) == vx)
assert np.all(f2(vx, vc) == vx)
assert np.array_equal(f1(vx), vx)
assert np.array_equal(f0(vx), vx)
assert np.array_equal(f2(vx, vc), vx)
def test_shape_le_0(self):
for dtype1 in ["float32", "float64"]:
x = tt.matrix("x", dtype=dtype1)
z0 = tt.switch(tt.le(x.shape[0], 0), 0, x.shape[0])
......@@ -6259,84 +6262,63 @@ class TestLocalUselessSwitch:
assert f0(vx) == 0
assert f1(vx) == 5
def test_broadcast1(self):
def test_broadcasting_1(self):
# test switch(cst, matrix, row)
x = tt.matrix("x", dtype="int32")
y = tt.vector("y", dtype="int64")
z = tt.switch(1, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if isinstance(node.op, tt.Elemwise)
and not isinstance(node.op.scalar_op, scal.basic.Cast)
]
)
== 0
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Elemwise)
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, scal.basic.Cast)
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
vy = np.array([10, 11, 12], dtype="int64")
assert np.all(f(vx, vy) == vx)
np_res = np.where(1, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
z = tt.switch(0, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if isinstance(node.op, tt.Elemwise)
]
)
== 0
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
assert f.maker.fgraph.inputs[1] == f.maker.fgraph.outputs[0].owner.inputs[0]
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
vy = np.array([10, 11, 12], dtype="int64")
assert np.all(f(vx, vy) == vy)
np_res = np.where(0, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
def test_broadcast2(self):
def test_broadcasting_2(self):
# test switch(cst, vector, matrix)
# This case is not optimized for now.
x = tt.vector("x", dtype="int32")
y = tt.matrix("y", dtype="int64")
z = tt.switch(1, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if isinstance(node.op, tt.Elemwise)
and not isinstance(node.op.scalar_op, scal.basic.Cast)
]
)
== 0
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
vx = np.array([4, 5, 6], dtype="int32")
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64")
assert np.all(f(vx, vy) == vx)
np_res = np.where(1, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
z = tt.switch(0, x, y)
f = function([x, y], z, mode=self.mode)
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if isinstance(node.op, tt.Elemwise)
]
)
== 0
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
vx = np.array([4, 5, 6], dtype="int32")
vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64")
assert np.all(f(vx, vy) == vy)
np_res = np.where(0, vx, vy)
assert np.array_equal(f(vx, vy), np_res)
def test_broadcast3(self):
def test_broadcasting_3(self):
# test switch(matrix, same_vector, same_vector)
x = tt.matrix("x", dtype="int32")
......@@ -6346,16 +6328,9 @@ class TestLocalUselessSwitch:
vx = np.array([[0, 1], [1, 0]], dtype="int32")
vy = np.array([7, 8], dtype="int64")
utt.assert_allclose(f(vx, vy), np.where(vx, vy, vy))
assert (
len(
[
node.op
for node in f.maker.fgraph.toposort()
if isinstance(node.op, tt.Elemwise)
]
)
== 0
)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, tt.Alloc)
assert not any(node.op == tt.switch for node in f.maker.fgraph.toposort())
class TestLocalMergeSwitchSameCond:
......
......@@ -97,6 +97,7 @@ from theano.tensor.elemwise import (
ProdWithoutZeros,
Sum,
)
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.sort import TopKOp
from theano.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -4327,7 +4328,9 @@ def local_useless_switch(fgraph, node):
T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ts.Switch):
cond = tt.extract_constant(node.inputs[0], only_process_constants=True)
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number
):
......@@ -4336,37 +4339,18 @@ def local_useless_switch(fgraph, node):
else:
correct_out = node.inputs[1]
if correct_out.ndim != node.outputs[0].ndim:
# TODO: broadcast?
return False
if correct_out.dtype != node.outputs[0].dtype:
out = tt.cast(correct_out, node.outputs[0].dtype)
else:
out = correct_out
if out.type.broadcastable != node.outputs[0].type.broadcastable:
# We need to copy data to the new dimensions during execution
# We should not depend on node.outputs as this would
# make the new node depend on the old one that will
# get optimized again. So this create a cycle.
shps = []
for idx, (b1, b2), in enumerate(
zip(out.type.broadcastable, node.outputs[0].type.broadcastable)
):
if b1 == b2:
shps.append(out.shape[idx])
elif not node.inputs[1].type.broadcastable[idx]:
shps.append(node.inputs[1].shape[idx])
else:
shps.append(node.inputs[2].shape[idx])
out = alloc(out, *shps)
else:
out = out
out_shape = broadcast_shape(*node.inputs)
out = alloc(out, *out_shape)
# Copy over stacktrace from selected output to new output
copy_stack_trace(node.outputs + correct_out, out)
return [out]
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
# Note: No need to copy over stacktrace, because the input node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论