提交 22105f78 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Caglar

Repair some opt that need some digging for constant

上级 6f1afdb7
......@@ -3206,17 +3206,14 @@ def local_incsubtensor_of_zeros(node):
y = node.inputs[1]
replace = False
try:
if get_scalar_constant_value(y, only_process_constants=True) == 0:
replace = True
except NotScalarConstantError:
pass
if replace:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
return [x]
else:
return False
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
if get_scalar_constant_value(y, elemwise=False) == 0:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
return [x]
except NotScalarConstantError, e:
return
@register_canonicalize('local_setsubtensor_of_allocs')
......@@ -3232,22 +3229,20 @@ def local_setsubtensor_of_constants(node):
if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0]
y = node.inputs[1]
replace_x = None
replace_y = None
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
try:
replace_x = get_scalar_constant_value(x, only_process_constants=True)
replace_x = get_scalar_constant_value(x, elemwise=False)
except NotScalarConstantError:
pass
return
try:
replace_y = get_scalar_constant_value(y, only_process_constants=True)
replace_y = get_scalar_constant_value(y, elemwise=False)
except NotScalarConstantError:
pass
return
if (replace_x is not None and
replace_y is not None and
replace_x == replace_y):
if replace_x == replace_y:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
......@@ -3285,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node):
if idx is not idx2:
return
if (not inp.owner.op.set_instead_of_inc and
T.extract_constant(x, only_process_constants=True) != 0):
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
T.extract_constant(x, elemwise=False) != 0):
return
cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论