提交 0958ff90 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Revised version of infer_shape for theano that takes into consideration the

changes done to the canonical form of the slice.
上级 902d822b
......@@ -2835,31 +2835,21 @@ class Subtensor(Op):
assert len(xshp) == node.inputs[0].ndim
outshp = []
actual_idx_list = list(get_idx_list(node.inputs, self.idx_list))
padded = actual_idx_list + [slice(None, None, None)] * (len(xshp) - len(self.idx_list))
padded = ( actual_idx_list +
[slice(None, None, None)]*(len(xshp)-len(self.idx_list)))
i = 0
shape_i = node.env.shape_feature.shape_i
for idx, xl in zip(padded, xshp):
if isinstance(idx, slice):
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
if (idx.start is None or idx.start == 0)\
and (idx.stop is None or idx.stop == sys.maxint)\
and (idx.step is None or abs(idx.step) == 1):
if ( (idx.start in [None, 0])
and (idx.stop in [None, sys.maxint])
and (idx.step is None or idx.step == 1) ):
outshp.append(xl)
else:
cnf = get_canonical_form_slice(idx, xl)
if cnf[0].stop not in [None, sys.maxint]:
length = cnf[0].stop
else:
length = xl
if cnf[0].start not in [None,0]:
length = length - cnf[0].start
length = switch(lt(length,0), 0, length)
if cnf[0].step not in [None, 1]:
# any more elegant way of doing this??
length = cast(
ceil(length / cast(cnf[0].step,'float32')),'int64')
length = (cnf[0].stop - cnf[0].start -1)/cnf[0].step + 1
length = switch(lt(length,0), 0, length)
outshp.append(length)
i += 1
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论