提交 94ec1975 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix the broadcastable flags in case of broadcastable input.

上级 63cb5dd9
......@@ -1523,7 +1523,9 @@ class AdvancedSubtensor1(Op):
raise TypeError('index must be vector')
if x_.type.ndim == 0:
raise TypeError('cannot index into a scalar')
return Apply(self, [x_, ilist_], [x_.type()])
bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype,
broadcastable=bcast)()])
def perform(self, node, inp, out_):
x, i = inp
......@@ -1742,8 +1744,9 @@ class AdvancedIncSubtensor1(Op):
'cannot %s x subtensor with ndim=%s'
' by y with ndim=%s to x subtensor with ndim=%s ' % (
opname, x_.type.ndim, y_.type.ndim))
return Apply(self, [x_, y_, ilist_], [x_.type()])
bcast = (ilist_.broadcastable[0],) + x_.broadcastable[1:]
return Apply(self, [x_, y_, ilist_], [TensorType(dtype=x.dtype,
broadcastable=bcast)()])
def perform(self, node, inp, out_):
# TODO opt to make this inplace
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论