提交 df1176a3 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Refactor local_useless_subtensor to allow earlier exiting of the function

上级 b9bfe165
...@@ -1947,15 +1947,17 @@ def local_useless_subtensor(node): ...@@ -1947,15 +1947,17 @@ def local_useless_subtensor(node):
# is not a useless subtensor # is not a useless subtensor
return False return False
length_pos_data = maxsize for pos, idx in enumerate(cdata):
length_pos = shape_of[node.inputs[0]][pos] length_pos = shape_of[node.inputs[0]][pos]
try:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if isinstance(idx.stop, (int, numpy.integer)): if isinstance(idx.stop, (int, numpy.integer)):
length_pos_data = maxsize
try:
length_pos_data = get_scalar_constant_value(length_pos)
except NotScalarConstantError:
pass
if idx.stop < length_pos_data: if idx.stop < length_pos_data:
return False return False
elif isinstance(idx.stop, gof.Variable): elif isinstance(idx.stop, gof.Variable):
...@@ -1996,10 +1998,10 @@ def local_useless_subtensor(node): ...@@ -1996,10 +1998,10 @@ def local_useless_subtensor(node):
length = get_scalar_constant_value(shape_of[node.inputs[0]][0]) length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
except NotScalarConstantError: except NotScalarConstantError:
return False return False
# get index (which must be a vector by definition) # get index (which must be a vector by definition)
idx = node.inputs[1] idx = node.inputs[1]
# `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
# this optimization # this optimization
if isinstance(idx, T.Constant): if isinstance(idx, T.Constant):
...@@ -2014,7 +2016,7 @@ def local_useless_subtensor(node): ...@@ -2014,7 +2016,7 @@ def local_useless_subtensor(node):
idx.owner.inputs) idx.owner.inputs)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if start != 0: if start != 0:
return False return False
if stop != length: if stop != length:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论