提交 ec2dab80 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix issue in local_subtensor_incsubtensor

The replacement was not correct when there were explicitly broadcastable dimensions, or if the dtype changed.
上级 ecfc65ec
...@@ -1910,14 +1910,22 @@ def local_subtensor_inc_subtensor(node): ...@@ -1910,14 +1910,22 @@ def local_subtensor_inc_subtensor(node):
if not x.owner.op.set_instead_of_inc: if not x.owner.op.set_instead_of_inc:
return return
if x.owner.inputs[2:] == node.inputs[1:] and tuple(x.owner.op.idx_list) == tuple(node.op.idx_list): if (x.owner.inputs[2:] == node.inputs[1:] and
# if x[idx] and y have the same ndim (and shape), directly return y tuple(x.owner.op.idx_list) == tuple(node.op.idx_list)):
if x.owner.inputs[0].ndim - (len(node.op.idx_list) - sum([isinstance(idx, slice) for idx in node.op.idx_list])) == x.owner.inputs[1].ndim: out = node.outputs[0]
return [x.owner.inputs[1]] y = x.owner.inputs[1]
# else y is broadcastable, return alloc of broadcastable y # If the dtypes differ, cast y into x.dtype
if x.dtype != y.dtype:
y = y.astype(x.dtype)
if out.type == y.type:
# if x[idx] and y have the same type, directly return y
return [y]
else: else:
# The difference is related to broadcasting pattern
assert out.broadcastable != y.broadcastable
# We have to alloc y to the shape of x[idx]
x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:]) x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
return [T.alloc(x.owner.inputs[1], *x_subtensor.shape)] return [T.alloc(y, *x_subtensor.shape)]
else: else:
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论