提交 de731fb1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify `Subtensor.infer_shape` for reversed slices

上级 dea9940a
...@@ -948,6 +948,9 @@ class Subtensor(COp): ...@@ -948,6 +948,9 @@ class Subtensor(COp):
out[0] = np.asarray(x.__getitem__(cdata)) out[0] = np.asarray(x.__getitem__(cdata))
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
def _is_constant(const, x):
return isinstance(const, Constant) and const.data.item() == x
xshp = shapes[0] xshp = shapes[0]
assert len(xshp) == node.inputs[0].ndim assert len(xshp) == node.inputs[0].ndim
outshp = [] outshp = []
...@@ -961,10 +964,17 @@ class Subtensor(COp): ...@@ -961,10 +964,17 @@ class Subtensor(COp):
# If it is the default (None, None, None) slice, or a variant, # If it is the default (None, None, None) slice, or a variant,
# the shape will be xl # the shape will be xl
if ( if (
(idx.start in [None, 0]) (idx.start is None or _is_constant(idx.start, 0))
and (idx.stop in [None, sys.maxsize]) and (idx.stop is None or _is_constant(idx.stop, sys.maxsize))
and (idx.step is None or idx.step == 1) and (idx.step is None or _is_constant(idx.step, 1))
):
outshp.append(xl)
elif (
(idx.start is None)
and (idx.stop is None)
and _is_constant(idx.step, -1)
): ):
# Reverse slice
outshp.append(xl) outshp.append(xl)
else: else:
cnf = get_canonical_form_slice(idx, xl)[0] cnf = get_canonical_form_slice(idx, xl)[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论