提交 bc0deb3f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make local_useless_switch use the ShapeFeature's shapes

上级 8867a720
......@@ -2409,24 +2409,39 @@ def local_useless_switch(fgraph, node):
at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch):
if not isinstance(node.op.scalar_op, aes.Switch):
return False
shape_feature: Optional[ShapeFeature] = getattr(fgraph, "shape_feature", None)
if shape_feature is None:
return False
cond = extract_constant(node.inputs[0], only_process_constants=True)
left = node.inputs[1]
right = node.inputs[2]
cond_var = node.inputs[0]
cond = extract_constant(cond_var, only_process_constants=True)
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, (np.number, np.bool_)
):
if cond == 0:
correct_out = node.inputs[2]
correct_out = right
else:
correct_out = node.inputs[1]
correct_out = left
if correct_out.dtype != node.outputs[0].dtype:
out = cast(correct_out, node.outputs[0].dtype)
else:
out = correct_out
out_shape = broadcast_shape(*node.inputs)
input_shapes = [
tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
for inp in node.inputs
]
out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
out = alloc(out, *out_shape)
# Copy over stacktrace from selected output to new output
......@@ -2434,40 +2449,34 @@ def local_useless_switch(fgraph, node):
return [out]
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
if left == right:
# Note: No need to copy over stacktrace, because the input node
# already has its own stacktrace
if cond.type.is_super(node.inputs[1].type):
return [node.inputs[1]]
if cond.type.is_super(left.type):
return [left]
ret = fill(cond, node.inputs[1])
ret = fill(cond, left)
# Copy over stacktrace from switch output and correct branch
copy_stack_trace(node.outputs + node.inputs[1], ret)
copy_stack_trace(node.outputs + left, ret)
return [ret]
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
left = node.inputs[1]
right = node.inputs[2]
cond_var = node.inputs[0]
if (
cond_var.owner
and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, aes.LE)
and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True)
== 0
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0
and extract_constant(left, only_process_constants=True) == 0
and right is cond_var.owner.inputs[0]
and right == cond_var.owner.inputs[0]
):
assert node.outputs[0].type.is_super(right.type)
# No need to copy over stacktrace, because the right input node
# already has its own stacktrace
return [right]
return False
return False
@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论