提交 0a3e8fce authored 作者: Amjad Almahairi's avatar Amjad Almahairi

rename optimization and change dependencies

上级 a63cd50f
...@@ -89,7 +89,7 @@ _logger = logging.getLogger('theano.scan_module.scan_opt') ...@@ -89,7 +89,7 @@ _logger = logging.getLogger('theano.scan_module.scan_opt')
list_opt_slice = [tensor.opt.local_abs_merge, list_opt_slice = [tensor.opt.local_abs_merge,
tensor.opt.local_mul_switch_sink, tensor.opt.local_mul_switch_sink,
tensor.opt.local_upcast_elemwise_constant_inputs, tensor.opt.local_upcast_elemwise_constant_inputs,
tensor.opt.local_remove_switch_const_cond, tensor.opt.local_useless_switch,
tensor.opt.constant_folding] tensor.opt.constant_folding]
......
...@@ -2439,7 +2439,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -2439,7 +2439,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
""" """
list_opt = [local_abs_merge, local_mul_switch_sink, list_opt = [local_abs_merge, local_mul_switch_sink,
local_upcast_elemwise_constant_inputs, local_upcast_elemwise_constant_inputs,
local_remove_switch_const_cond, constant_folding] local_useless_switch, constant_folding]
if type(slice1) is not slice: if type(slice1) is not slice:
raise ValueError(('First provided slice should actually be of type' raise ValueError(('First provided slice should actually be of type'
...@@ -3256,7 +3256,7 @@ def local_join_make_vector(node): ...@@ -3256,7 +3256,7 @@ def local_join_make_vector(node):
# Switch opts # # Switch opts #
############### ###############
@register_canonicalize('fast_compile') @register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_switch(node): def local_useless_switch(node):
...@@ -3266,7 +3266,6 @@ def local_useless_switch(node): ...@@ -3266,7 +3266,6 @@ def local_useless_switch(node):
if cond is constant and cond == 0: right if cond is constant and cond == 0: right
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
if left is right -> left if left is right -> left
if left equal right -> left
T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
""" """
...@@ -3292,17 +3291,12 @@ def local_useless_switch(node): ...@@ -3292,17 +3291,12 @@ def local_useless_switch(node):
# if left is right -> left # if left is right -> left
if node.inputs[1] is node.inputs[2]: if node.inputs[1] is node.inputs[2]:
return [node.inputs[1]] return [node.inputs[1]]
# if left equal right -> left
if (T.extract_constant(node.inputs[1]) == # This case happens with scan.
T.extract_constant(node.inputs[2])):
if node.inputs[1].type == node.outputs[0].type:
return [node.inputs[1]]
if node.inputs[2].type == node.outputs[0].type:
return [node.inputs[2]]
# This case happen with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
left = node.inputs[1] left = node.inputs[1]
right = node.inputs[2] right = node.inputs[2]
if (cond.owner and if (cond.owner and
isinstance(cond.owner.op, T.Elemwise) and isinstance(cond.owner.op, T.Elemwise) and
isinstance(cond.owner.op.scalar_op, scalar.LE) and isinstance(cond.owner.op.scalar_op, scalar.LE) and
...@@ -3315,7 +3309,6 @@ def local_useless_switch(node): ...@@ -3315,7 +3309,6 @@ def local_useless_switch(node):
return [right] return [right]
return False return False
return False return False
local_remove_switch_const_cond = local_useless_switch
#@register_canonicalize #@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论