提交 a0fa399a authored 作者: Saizheng Zhang's avatar Saizheng Zhang

opt:inc_subtensor:alloc new space when broadcastable

上级 6f85d17f
......@@ -1911,7 +1911,13 @@ def local_subtensor_inc_subtensor(node):
return
if x.owner.inputs[2:] == node.inputs[1:] and tuple(x.owner.op.idx_list) == tuple(node.op.idx_list):
return [x.owner.inputs[1]]
# if x[idx] and y have the same ndim (and shape), directly return y
if x.owner.inputs[0].ndim-(len(node.op.idx_list)-sum([isinstance(idx, slice) for idx in node.op.idx_list])) == x.owner.inputs[1].ndim:
return [x.owner.inputs[1]]
# else y is broadcastable, return alloc of broadcastable y
else:
x_subtensor = Subtensor(node.op.idx_list)(x.owner.inputs[0], *x.owner.inputs[2:])
return [T.alloc(x.owner.inputs[1], *x_subtensor.shape)]
else:
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论