提交 3db127e9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid recreating Ops in `local_uint_constant_indices`

This prevents undoing the rewrite that introduces AdvancedSubtensor1
上级 a14cb2bd
...@@ -1805,7 +1805,8 @@ def local_join_subtensors(fgraph, node): ...@@ -1805,7 +1805,8 @@ def local_join_subtensors(fgraph, node):
def local_uint_constant_indices(fgraph, node): def local_uint_constant_indices(fgraph, node):
"""Convert constant indices to unsigned dtypes.""" """Convert constant indices to unsigned dtypes."""
if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1): op = node.op
if isinstance(op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
x, y, *indices = node.inputs x, y, *indices = node.inputs
else: else:
x, *indices = node.inputs x, *indices = node.inputs
...@@ -1864,21 +1865,18 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1864,21 +1865,18 @@ def local_uint_constant_indices(fgraph, node):
if not has_new_index: if not has_new_index:
return False return False
new_out = x[tuple(new_indices)] if isinstance(op, Subtensor | IncSubtensor):
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
if y is not None: props = op._props_dict()
new_out = inc_subtensor( props["idx_list"] = new_indices
new_out, op = type(op)(**props)
y, # Basic index Ops don't expect slices, but the respective start/step/stop
inplace=node.op.inplace, new_indices = get_slice_elements(new_indices)
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=getattr(node.op, "ignore_duplicates", False), new_args = (x, *new_indices) if y is None else (x, y, *new_indices)
) new_out = op(*new_args)
copy_stack_trace(node.outputs[0], new_out)
new_outs = new_out.owner.outputs return [new_out]
copy_stack_trace(node.outputs, new_outs)
return new_outs
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
......
...@@ -550,7 +550,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -550,7 +550,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
return res_shape return res_shape
def get_slice_elements(idxs: list, cond: Callable) -> list: def get_slice_elements(
idxs: list,
cond: Callable = lambda x: isinstance(x, Variable),
) -> list:
"""Extract slice elements conditional on a given predicate function. """Extract slice elements conditional on a given predicate function.
Parameters Parameters
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论