提交 d9b62292 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix AdvancedSubtensor.infer_shape

上级 b8b1dcf0
...@@ -3304,9 +3304,21 @@ class AdvancedSubtensor(Op): ...@@ -3304,9 +3304,21 @@ class AdvancedSubtensor(Op):
% (x.ndim, ','.join(str(input.ndim) for input in inputs))) % (x.ndim, ','.join(str(input.ndim) for input in inputs)))
raise NotImplementedError('Advanced indexing of x with arguments (%s) not supported yet'\ raise NotImplementedError('Advanced indexing of x with arguments (%s) not supported yet'\
% ','.join(str(input) for input in inputs)) % ','.join(str(input) for input in inputs))
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
xshp, ind1shp, ind2shp = ishapes # Really special case
return [ind2shp] if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1:
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if node.inputs[2].owner is None:
return [ind2shp]
else:
return [ind1shp]
# Default case, we don't know
return node.env.shape_feature.default_infer_shape(node, ishapes)
def perform(self, node, inputs, (out,)): def perform(self, node, inputs, (out,)):
# TODO: in general, we need to re-pack the inputs into a valid index, just like # TODO: in general, we need to re-pack the inputs into a valid index, just like
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论