提交 bc0deb3f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make local_useless_switch use the ShapeFeature's shapes

上级 8867a720
...@@ -2409,65 +2409,74 @@ def local_useless_switch(fgraph, node): ...@@ -2409,65 +2409,74 @@ def local_useless_switch(fgraph, node):
at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
""" """
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch): if not isinstance(node.op.scalar_op, aes.Switch):
return False
cond = extract_constant(node.inputs[0], only_process_constants=True) shape_feature: Optional[ShapeFeature] = getattr(fgraph, "shape_feature", None)
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( if shape_feature is None:
cond, (np.number, np.bool_) return False
):
if cond == 0:
correct_out = node.inputs[2]
else:
correct_out = node.inputs[1]
if correct_out.dtype != node.outputs[0].dtype: left = node.inputs[1]
out = cast(correct_out, node.outputs[0].dtype) right = node.inputs[2]
else: cond_var = node.inputs[0]
out = correct_out cond = extract_constant(cond_var, only_process_constants=True)
out_shape = broadcast_shape(*node.inputs) if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
out = alloc(out, *out_shape) cond, (np.number, np.bool_)
):
if cond == 0:
correct_out = right
else:
correct_out = left
# Copy over stacktrace from selected output to new output if correct_out.dtype != node.outputs[0].dtype:
copy_stack_trace(node.outputs + correct_out, out) out = cast(correct_out, node.outputs[0].dtype)
return [out] else:
out = correct_out
# if left is right -> left input_shapes = [
if node.inputs[1] is node.inputs[2]: tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim))
# Note: No need to copy over stacktrace, because the input node for inp in node.inputs
# already has its own stacktrace ]
if cond.type.is_super(node.inputs[1].type):
return [node.inputs[1]]
ret = fill(cond, node.inputs[1]) out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True)
# Copy over stacktrace from switch output and correct branch out = alloc(out, *out_shape)
copy_stack_trace(node.outputs + node.inputs[1], ret)
return [ret]
# This case happens with scan. # Copy over stacktrace from selected output to new output
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) copy_stack_trace(node.outputs + correct_out, out)
left = node.inputs[1] return [out]
right = node.inputs[2]
cond_var = node.inputs[0] # if left is right -> left
if ( if left == right:
cond_var.owner # Note: No need to copy over stacktrace, because the input node
and isinstance(cond_var.owner.op, Elemwise) # already has its own stacktrace
and isinstance(cond_var.owner.op.scalar_op, aes.LE) if cond.type.is_super(left.type):
and cond_var.owner.inputs[0].owner return [left]
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) ret = fill(cond, left)
== 0
and extract_constant(left, only_process_constants=True) == 0 # Copy over stacktrace from switch output and correct branch
and right is cond_var.owner.inputs[0] copy_stack_trace(node.outputs + left, ret)
): return [ret]
assert node.outputs[0].type.is_super(right.type)
# No need to copy over stacktrace, because the right input node # This case happens with scan.
# already has its own stacktrace # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
return [right] if (
return False cond_var.owner
return False and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, aes.LE)
and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0
and extract_constant(left, only_process_constants=True) == 0
and right == cond_var.owner.inputs[0]
):
assert node.outputs[0].type.is_super(right.type)
# No need to copy over stacktrace, because the right input node
# already has its own stacktrace
return [right]
@register_canonicalize @register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论