提交 efc9d693 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

local_subtensor_make_vector: don't return make_vector when slice keeps only one item

上级 c2ede260
...@@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node):
something more general for constant ``*Subtensor*`` graphs (or perhaps something more general for constant ``*Subtensor*`` graphs (or perhaps
include this kind of work in the constant folding). include this kind of work in the constant folding).
""" """
if not isinstance(node.op, Subtensor | AdvancedSubtensor1):
return False
x = node.inputs[0] x = node.inputs[0]
if not (x.owner and isinstance(x.owner.op, MakeVector)): if not (x.owner and isinstance(x.owner.op, MakeVector)):
...@@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node):
const_slice = get_constant_idx( const_slice = get_constant_idx(
node.op.idx_list, node.inputs, allow_partial=False node.op.idx_list, node.inputs, allow_partial=False
)[0] )[0]
ret = make_vector_op(*x.owner.inputs[const_slice]) sliced_inputs = x.owner.inputs[const_slice]
if len(sliced_inputs) == 1:
ret = expand_dims(sliced_inputs[0], axis=0)
else:
ret = make_vector_op(*sliced_inputs)
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return [ret] return [ret]
except NotScalarConstantError: except NotScalarConstantError:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论