提交 82823dec authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove switches when both branches are equivalent constants with different dtype

上级 aad78d57
......@@ -958,6 +958,25 @@ def local_sum_make_vector(fgraph, node):
return [element_sum]
def equivalent_up_to_constant_casting(a, b) -> bool:
    """Return True if `a` and `b` are equivalent up to constant casting.

    Two inputs are considered equivalent when either

    * ``a == b`` holds, or
    * both are ``TensorConstant`` instances with the same static type shape,
      *different* dtypes, and element-wise equal data.

    Parameters
    ----------
    a, b
        The two objects to compare.

    Returns
    -------
    bool
        True when the inputs compare equal, or when they are tensor constants
        that differ in dtype but hold the same values; False otherwise.
    """
    if a == b:
        return True

    # From here on, equivalence is decided from the data values, ignoring
    # dtype. That is only possible when both inputs are tensor constants.
    if not (isinstance(a, TensorConstant) and isinstance(b, TensorConstant)):
        return False

    if a.type.shape != b.type.shape:
        return False

    # Only the dtype-mismatch case is handled below: constants that share a
    # dtype are left to the ``a == b`` check above, so they are not
    # re-compared here.
    if a.type.dtype == b.type.dtype:
        return False

    # Comparing large constant arrays is expensive. The signature's ``sum``
    # (noted as cached for TensorConstants) is a cheap pre-filter that rules
    # out most non-matching candidates before touching the full data.
    sums_match = a.signature().sum == b.signature().sum
    if not sums_match:
        return False

    return np.array_equal(a.data, b.data)
@register_useless("shape_unsafe")
@register_canonicalize("fast_compile", "shape_unsafe")
@register_specialize("shape_unsafe")
......@@ -1004,17 +1023,19 @@ def local_useless_switch(fgraph, node):
return [out]
# if left is right -> left
if left == right:
# Note: No need to copy over stacktrace, because the input node
# already has its own stacktrace
if equivalent_up_to_constant_casting(left, right):
if left.type.broadcastable == out_bcast:
out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype:
left = cast(left, out_dtype)
copy_stack_trace(node.outputs + left, left)
# When not casting, the other inputs of the switch aren't needed in the traceback
return [left]
ret = broadcast_arrays(left, cond)[0]
# Copy over stacktrace from switch output and correct branch
copy_stack_trace(node.outputs + left, ret)
return [ret]
else:
ret = broadcast_arrays(left, cond)[0]
copy_stack_trace(node.outputs + left, ret)
return [ret]
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
......
......@@ -26,6 +26,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
as_tensor,
cast,
join,
tile,
......@@ -983,6 +984,21 @@ class TestLocalUselessSwitch:
assert np.array_equal(f0(vx), vx)
assert np.array_equal(f2(vx, vc), vx)
def test_left_is_right_constant(self):
    """Check switch removal when both branches are constants.

    A switch whose branches hold the same value in different dtypes is
    expected to rewrite to the int64 constant; a switch whose branches hold
    different values is expected to be left as is.
    """
    cond = scalar("cond", dtype=bool)
    zero_i8 = as_tensor(np.int8(0))
    zero_i64 = as_tensor(np.int64(0))
    one_i8 = as_tensor(np.int8(1))

    # Same value, different dtype: the result is the int64 zero regardless
    # of which branch carries which dtype.
    for true_branch, false_branch in (
        (zero_i8, zero_i64),
        (zero_i64, zero_i8),
    ):
        switched = pt.switch(cond, true_branch, false_branch)
        assert equal_computations([rewrite_graph(switched)], [zero_i64])

    # Different values: the rewritten graph must match the original switch.
    untouched = pt.switch(cond, one_i8, zero_i8)
    assert equal_computations([rewrite_graph(untouched)], [untouched])
@pytest.mark.parametrize(
"dtype1",
["float32", "float64"],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论