提交 0a3e8fce authored 作者: Amjad Almahairi's avatar Amjad Almahairi

rename optimization and change dependencies

上级 a63cd50f
......@@ -89,7 +89,7 @@ _logger = logging.getLogger('theano.scan_module.scan_opt')
list_opt_slice = [tensor.opt.local_abs_merge,
tensor.opt.local_mul_switch_sink,
tensor.opt.local_upcast_elemwise_constant_inputs,
tensor.opt.local_remove_switch_const_cond,
tensor.opt.local_useless_switch,
tensor.opt.constant_folding]
......
......@@ -1609,7 +1609,7 @@ def local_useless_elemwise(node):
return [node.inputs[1]]
elif const_val == 0:
return zeros_like(node, 1)
if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1])
if const_val == 1:
......@@ -2439,7 +2439,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
"""
list_opt = [local_abs_merge, local_mul_switch_sink,
local_upcast_elemwise_constant_inputs,
local_remove_switch_const_cond, constant_folding]
local_useless_switch, constant_folding]
if type(slice1) is not slice:
raise ValueError(('First provided slice should actually be of type'
......@@ -3256,7 +3256,7 @@ def local_join_make_vector(node):
# Switch opts #
###############
@register_canonicalize('fast_compile')
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_useless_switch(node):
......@@ -3266,7 +3266,6 @@ def local_useless_switch(node):
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
if left is right -> left
if left equal right -> left
T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
"""
......@@ -3292,17 +3291,12 @@ def local_useless_switch(node):
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
return [node.inputs[1]]
# if left equal right -> left
if (T.extract_constant(node.inputs[1]) ==
T.extract_constant(node.inputs[2])):
if node.inputs[1].type == node.outputs[0].type:
return [node.inputs[1]]
if node.inputs[2].type == node.outputs[0].type:
return [node.inputs[2]]
# This case happen with scan.
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
left = node.inputs[1]
right = node.inputs[2]
if (cond.owner and
isinstance(cond.owner.op, T.Elemwise) and
isinstance(cond.owner.op.scalar_op, scalar.LE) and
......@@ -3315,7 +3309,6 @@ def local_useless_switch(node):
return [right]
return False
return False
local_remove_switch_const_cond = local_useless_switch
#@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论