提交 22b39728 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add infer_shape method to Subtensor

上级 211e2a12
...@@ -2172,6 +2172,31 @@ class Subtensor(Op): ...@@ -2172,6 +2172,31 @@ class Subtensor(Op):
cdata = cdata[0] cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
def infer_shape(self, node, shapes):
xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim
outshp = []
padded = self.idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
i = 0
for idx, xl in zip(padded, xshp):
if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
if (idx.start is None or idx.start == 0)\
and (idx.stop is None or idx.stop == sys.maxint)\
and (idx.step is None or idx.step == 1):
outshp.append(xl)
else:
#No easy way to compute the shape
outshp.append(Shape_i(i)(node.outputs[0]))
i += 1
else:
# That dimension is dropped
pass
assert i == node.outputs[0].ndim
assert len(outshp) == node.outputs[0].ndim
return [outshp]
def grad(self, inputs, (gz,)): def grad(self, inputs, (gz,)):
x = inputs[0] x = inputs[0]
rest = inputs[1:] rest = inputs[1:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论