提交 241e4a67 authored 作者: Xavier Glorot's avatar Xavier Glorot

Removed -is constant- tests to switch sink optimizations

上级 9b5901c2
...@@ -1194,12 +1194,12 @@ def local_mul_switch_sink(node): ...@@ -1194,12 +1194,12 @@ def local_mul_switch_sink(node):
for idx , i in enumerate(node.inputs): for idx , i in enumerate(node.inputs):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.: if get_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:] listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))] fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.: if get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:] listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)] fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论