提交 de731fb1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify `Subtensor.infer_shape` for reversed slices

上级 dea9940a
......@@ -948,6 +948,9 @@ class Subtensor(COp):
out[0] = np.asarray(x.__getitem__(cdata))
def infer_shape(self, fgraph, node, shapes):
def _is_constant(const, x):
return isinstance(const, Constant) and const.data.item() == x
xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim
outshp = []
......@@ -961,10 +964,17 @@ class Subtensor(COp):
# If it is the default (None, None, None) slice, or a variant,
# the shape will be xl
if (
(idx.start in [None, 0])
and (idx.stop in [None, sys.maxsize])
and (idx.step is None or idx.step == 1)
(idx.start is None or _is_constant(idx.start, 0))
and (idx.stop is None or _is_constant(idx.stop, sys.maxsize))
and (idx.step is None or _is_constant(idx.step, 1))
):
outshp.append(xl)
elif (
(idx.start is None)
and (idx.stop is None)
and _is_constant(idx.step, -1)
):
# Reverse slice
outshp.append(xl)
else:
cnf = get_canonical_form_slice(idx, xl)[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论