提交 3db127e9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid recreating Ops in `local_uint_constant_indices`

This prevents undoing the rewrite that introduces AdvancedSubtensor1
上级 a14cb2bd
......@@ -1805,7 +1805,8 @@ def local_join_subtensors(fgraph, node):
def local_uint_constant_indices(fgraph, node):
"""Convert constant indices to unsigned dtypes."""
if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
op = node.op
if isinstance(op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
x, y, *indices = node.inputs
else:
x, *indices = node.inputs
......@@ -1864,21 +1865,18 @@ def local_uint_constant_indices(fgraph, node):
if not has_new_index:
return False
new_out = x[tuple(new_indices)]
if y is not None:
new_out = inc_subtensor(
new_out,
y,
inplace=node.op.inplace,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=getattr(node.op, "ignore_duplicates", False),
)
new_outs = new_out.owner.outputs
copy_stack_trace(node.outputs, new_outs)
return new_outs
if isinstance(op, Subtensor | IncSubtensor):
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
props = op._props_dict()
props["idx_list"] = new_indices
op = type(op)(**props)
# Basic index Ops don't expect slices, but the respective start/step/stop
new_indices = get_slice_elements(new_indices)
new_args = (x, *new_indices) if y is None else (x, y, *new_indices)
new_out = op(*new_args)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]
@register_canonicalize("shape_unsafe")
......
......@@ -550,7 +550,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
return res_shape
def get_slice_elements(idxs: list, cond: Callable) -> list:
def get_slice_elements(
idxs: list,
cond: Callable = lambda x: isinstance(x, Variable),
) -> list:
"""Extract slice elements conditional on a given predicate function.
Parameters
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论