提交 382d7e07 authored 作者: Xavier Glorot's avatar Xavier Glorot

Added try statement to switch sink optimizations not to raise constant checking errors

上级 8088138b
......@@ -1194,16 +1194,22 @@ def local_mul_switch_sink(node):
for idx , i in enumerate(node.inputs):
if i.owner and i.owner.op == T.switch:
switch = i.owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
return False
@register_canonicalize
......@@ -1224,14 +1230,20 @@ def local_div_switch_sink(node):
op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner
try:
if get_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
try:
if get_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论