提交 56584f9c 作者: Razvan Pascanu

Implemented infer_shape for Subtensor (which was not complete before)

上级 9c07bd65
......@@ -2852,7 +2852,8 @@ class Subtensor(Op):
xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim
outshp = []
padded = self.idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
padded = actual_idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
i = 0
shape_i = node.env.shape_feature.shape_i
for idx, xl in zip(padded, xshp):
......@@ -2861,11 +2862,23 @@ class Subtensor(Op):
# the shape will be xl
if (idx.start is None or idx.start == 0)\
and (idx.stop is None or idx.stop == sys.maxint)\
and (idx.step is None or idx.step == 1):
and (idx.step is None or abs(idx.step) == 1):
outshp.append(xl)
else:
# Not implemented yet
outshp.append(shape_i(i)(node.outputs[0]))
cnf = get_canonical_form_slice(idx, xl)
if cnf[0].stop not in [None, sys.maxint]:
length = cnf[0].stop
else:
length = xl
if cnf[0].start not in [None,0]:
length = length - cnf[0].start
length = switch(lt(length,0), 0, length)
if cnf[0].step not in [None, 1]:
# any more elegant way of doing this??
length = cast(
ceil(length / cast(cnf[0].step,'float32')),'int64')
outshp.append(length)
i += 1
else:
# That dimension is dropped
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论