提交 97be1af1 作者: Xavier Glorot

added optimization switch_sink for mul and div

上级 cf208ead
......@@ -1163,6 +1163,72 @@ def local_join_1(node):
return [tensors[0]]
###############
# Switch opts #
###############
@register_canonicalize
@gof.local_optimizer([T.mul])
def local_mul_switch_sink(node):
    """
    Sink a multiplication into a switch that has a constant-zero branch.

    The following rewrites are performed on the graph:
    T.mul(A,T.switch(cond,0,iff),B) --> T.switch(cond,0,T.mul(A,B,iff))
    T.mul(A,T.switch(cond,ift,0),B) --> T.switch(cond,T.mul(A,B,ift),0)
    where A and B stand for any number (possibly zero) of symbolic
    variables.

    This is useful because A and B may not be numerically stable and may
    give NaN or inf values in the cases where the switch returns 0.
    With this optimization T.grad(T.switch(...)) has the right behavior.

    Example:
      x -> f(x)
      x -> g(x)
      y = T.switch(cond,f(x),g(x))
      **without the optimization
      T.grad(y,x) -> grad(f(x),x) * grad(y,f(x)) + grad(g(x),x) * grad(y,g(x))
      **with the optimization
      T.grad(y,x) -> switch(cond,grad(f(x),x), 0) + switch(cond,0,grad(g(x),x))

    This will be particularly useful for the lazyif because an entire
    part of the graph is skipped.

    Only the first input (in order) that is a switch with a constant-zero
    branch is rewritten; for a given switch the "true" branch is tested
    before the "false" branch.

    Returns a one-element list holding the replacement variable, or False
    when the node is not a T.mul or no input matches.
    """
    if node.op != T.mul:
        return False

    for pos, operand in enumerate(node.inputs):
        # Only operands produced directly by a switch are candidates.
        if not (operand.owner and operand.owner.op == T.switch):
            continue

        switch_inputs = operand.owner.inputs
        true_branch = switch_inputs[1]
        false_branch = switch_inputs[2]

        # switch(cond, 0, iff): multiply the other factors into iff.
        if (isinstance(true_branch, Constant)
                and get_constant_value(true_branch) == 0.):
            others = node.inputs[:pos] + node.inputs[pos + 1:]
            product = T.mul(*(others + [false_branch]))
            return [T.switch(switch_inputs[0], 0, product)]

        # switch(cond, ift, 0): multiply the other factors into ift.
        if (isinstance(false_branch, Constant)
                and get_constant_value(false_branch) == 0.):
            others = node.inputs[:pos] + node.inputs[pos + 1:]
            product = T.mul(*(others + [true_branch]))
            return [T.switch(switch_inputs[0], product, 0)]

    return False
@register_canonicalize
@gof.local_optimizer([T.true_div, T.int_div, T.floor_div])
def local_div_switch_sink(node):
    """
    Sink a division into a switch numerator that has a constant-zero branch.

    The following rewrites are performed on the graph, where ``div`` is
    whichever of T.true_div, T.int_div or T.floor_div the node applies:
    div(T.switch(cond,0,iff),A) --> T.switch(cond,0,div(iff,A))
    div(T.switch(cond,ift,0),A) --> T.switch(cond,div(ift,A),0)
    A being a symbolic variable.

    This is useful because A may not be numerically stable and may give
    NaN or inf values in the cases where the switch returns 0.
    See local_mul_switch_sink for more details.

    Only the numerator (first input) is inspected; a switch in the
    denominator is left untouched. The "true" branch of the switch is
    tested before the "false" branch.

    Returns a one-element list holding the replacement variable, or False
    when the node is not one of the handled division ops or the numerator
    does not match.
    """
    # The tracked-op list in the decorator must stay in sync with this
    # guard, otherwise int_div / floor_div nodes would never be dispatched
    # to this optimizer.
    if (node.op != T.true_div and node.op != T.int_div
            and node.op != T.floor_div):
        return False

    # Rebuild the division with the same op that was matched.
    op = node.op
    numerator = node.inputs[0]
    denominator = node.inputs[1]

    if numerator.owner and numerator.owner.op == T.switch:
        switch = numerator.owner
        # switch(cond, 0, iff) / A  -->  switch(cond, 0, iff / A)
        if (isinstance(switch.inputs[1], Constant)
                and get_constant_value(switch.inputs[1]) == 0.):
            return [T.switch(switch.inputs[0], 0,
                             op(switch.inputs[2], denominator))]
        # switch(cond, ift, 0) / A  -->  switch(cond, ift / A, 0)
        if (isinstance(switch.inputs[2], Constant)
                and get_constant_value(switch.inputs[2]) == 0.):
            return [T.switch(switch.inputs[0],
                             op(switch.inputs[1], denominator), 0)]
    return False
##################
# Reshape opts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论