提交 a76cd786 authored 作者: jsalvatier's avatar jsalvatier 提交者: John Salvatier

inferring the shape like that doesn't work

上级 a003a1c4
......@@ -7202,7 +7202,7 @@ class MakeSlice(Op):
return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]),
[Slice()])
[slicetype()])
def perform(self, node, inp, out_):
out, = out_
......@@ -7210,10 +7210,10 @@ class MakeSlice(Op):
def __str__(self):
return self.__class__.__name__
make_slice = MakeSlice
make_slice = MakeSlice()
class Slice(gof.Type):
class SliceType(gof.Type):
def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice):
......@@ -7223,7 +7223,7 @@ class Slice(gof.Type):
def __str__(self):
return "slice"
class NoneTypeT(gof.Type):
......@@ -7236,7 +7236,8 @@ class NoneTypeT(gof.Type):
def __str__(self):
return "None"
slicetype = SliceType()
NoneConst = Constant(NoneTypeT(), None, name = 'None')
def adv_broadcastable(a, idx):
......@@ -7251,7 +7252,7 @@ def adv_broadcastable(a, idx):
if v is NoneConst:
return None
if isinstance(v.type, Slice):
if isinstance(v.type, SliceType):
return slice(None,None)
return numpy.zeros( (1,)* v.ndim, int)
......@@ -7296,18 +7297,6 @@ class AdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, node, ishapes):
# Really special case
if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1:
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if node.inputs[2].owner is None:
return [ind2shp]
else:
return [ind1shp]
# Default case, we don't know
return node.fgraph.shape_feature.default_infer_shape(node, ishapes)
def perform(self, node, inputs, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论