提交 22105f78 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Caglar

Repair some opt that need some digging for constant

上级 6f1afdb7
...@@ -3206,17 +3206,14 @@ def local_incsubtensor_of_zeros(node): ...@@ -3206,17 +3206,14 @@ def local_incsubtensor_of_zeros(node):
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
try: try:
if get_scalar_constant_value(y, only_process_constants=True) == 0: # Don't use only_process_constants=True. We need to
replace = True # investigate Alloc of 0s but with non constant shape.
except NotScalarConstantError: if get_scalar_constant_value(y, elemwise=False) == 0:
pass # No need to copy over the stacktrace,
# because x should already have a stacktrace
if replace: return [x]
# No need to copy over the stacktrace, except NotScalarConstantError, e:
# because x should already have a stacktrace return
return [x]
else:
return False
@register_canonicalize('local_setsubtensor_of_allocs') @register_canonicalize('local_setsubtensor_of_allocs')
...@@ -3232,22 +3229,20 @@ def local_setsubtensor_of_constants(node): ...@@ -3232,22 +3229,20 @@ def local_setsubtensor_of_constants(node):
if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace_x = None
replace_y = None
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
try: try:
replace_x = get_scalar_constant_value(x, only_process_constants=True) replace_x = get_scalar_constant_value(x, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
pass return
try: try:
replace_y = get_scalar_constant_value(y, only_process_constants=True) replace_y = get_scalar_constant_value(y, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
pass return
if (replace_x is not None and if replace_x == replace_y:
replace_y is not None and
replace_x == replace_y):
# No need to copy over the stacktrace, # No need to copy over the stacktrace,
# because x should already have a stacktrace # because x should already have a stacktrace
...@@ -3285,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node): ...@@ -3285,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node):
if idx is not idx2: if idx is not idx2:
return return
if (not inp.owner.op.set_instead_of_inc and if (not inp.owner.op.set_instead_of_inc and
T.extract_constant(x, only_process_constants=True) != 0): # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
T.extract_constant(x, elemwise=False) != 0):
return return
cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))] cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0): if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论