提交 382d7e07 authored 作者: Xavier Glorot's avatar Xavier Glorot

Added try statement to switch sink optimizations not to raise constant checking errors

上级 8088138b
...@@ -1194,16 +1194,22 @@ def local_mul_switch_sink(node): ...@@ -1194,16 +1194,22 @@ def local_mul_switch_sink(node):
for idx , i in enumerate(node.inputs): for idx , i in enumerate(node.inputs):
if i.owner and i.owner.op == T.switch: if i.owner and i.owner.op == T.switch:
switch = i.owner switch = i.owner
if get_constant_value(switch.inputs[1]) == 0.: try:
listmul = node.inputs[:idx] + node.inputs[idx+1:] if get_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))] listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))]
return fct fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
if get_constant_value(switch.inputs[2]) == 0.: return fct
listmul = node.inputs[:idx] + node.inputs[idx+1:] except TypeError:
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)] pass
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan try:
return fct if get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
return False return False
@register_canonicalize @register_canonicalize
...@@ -1224,14 +1230,20 @@ def local_div_switch_sink(node): ...@@ -1224,14 +1230,20 @@ def local_div_switch_sink(node):
op = node.op op = node.op
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch: if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner switch = node.inputs[0].owner
if get_constant_value(switch.inputs[1]) == 0.: try:
fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))] if get_constant_value(switch.inputs[1]) == 0.:
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))]
return fct fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
if get_constant_value(switch.inputs[2]) == 0.: return fct
fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)] except TypeError:
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan pass
return fct try:
if get_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
except TypeError:
pass
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论