提交 a76cd786 authored 作者: jsalvatier's avatar jsalvatier 提交者: John Salvatier

inferring the shape like that doesn't work

上级 a003a1c4
...@@ -7202,7 +7202,7 @@ class MakeSlice(Op): ...@@ -7202,7 +7202,7 @@ class MakeSlice(Op):
return Apply(self, return Apply(self,
map(as_int_none_variable,[slc.start, slc.stop, slc.step]), map(as_int_none_variable,[slc.start, slc.stop, slc.step]),
[Slice()]) [slicetype()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
out, = out_ out, = out_
...@@ -7210,10 +7210,10 @@ class MakeSlice(Op): ...@@ -7210,10 +7210,10 @@ class MakeSlice(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
make_slice = MakeSlice make_slice = MakeSlice()
class Slice(gof.Type): class SliceType(gof.Type):
def filter(self, x, strict=False, allow_downcast=None): def filter(self, x, strict=False, allow_downcast=None):
if isinstance(x, slice): if isinstance(x, slice):
...@@ -7223,7 +7223,7 @@ class Slice(gof.Type): ...@@ -7223,7 +7223,7 @@ class Slice(gof.Type):
def __str__(self): def __str__(self):
return "slice" return "slice"
class NoneTypeT(gof.Type): class NoneTypeT(gof.Type):
...@@ -7236,7 +7236,8 @@ class NoneTypeT(gof.Type): ...@@ -7236,7 +7236,8 @@ class NoneTypeT(gof.Type):
def __str__(self): def __str__(self):
return "None" return "None"
slicetype = SliceType()
NoneConst = Constant(NoneTypeT(), None, name = 'None') NoneConst = Constant(NoneTypeT(), None, name = 'None')
def adv_broadcastable(a, idx): def adv_broadcastable(a, idx):
...@@ -7251,7 +7252,7 @@ def adv_broadcastable(a, idx): ...@@ -7251,7 +7252,7 @@ def adv_broadcastable(a, idx):
if v is NoneConst: if v is NoneConst:
return None return None
if isinstance(v.type, Slice): if isinstance(v.type, SliceType):
return slice(None,None) return slice(None,None)
return numpy.zeros( (1,)* v.ndim, int) return numpy.zeros( (1,)* v.ndim, int)
...@@ -7296,18 +7297,6 @@ class AdvancedSubtensor(Op): ...@@ -7296,18 +7297,6 @@ class AdvancedSubtensor(Op):
return self.make_node(eval_points[0], *inputs[1:]).outputs return self.make_node(eval_points[0], *inputs[1:]).outputs
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
# Really special case
if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1:
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
if node.inputs[2].owner is None:
return [ind2shp]
else:
return [ind1shp]
# Default case, we don't know
return node.fgraph.shape_feature.default_infer_shape(node, ishapes) return node.fgraph.shape_feature.default_infer_shape(node, ishapes)
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论