提交 48248330 authored 作者: Frederic Bastien's avatar Frederic Bastien

use the new DebugMode machanism to disable error when an stability optimization…

use the new DebugMode machanism to disable error when an stability optimization remove nan for the local_{mul,div}_switch_sink optimization. Updated test to use it.
上级 4d541b6a
...@@ -1196,10 +1196,12 @@ def local_mul_switch_sink(node): ...@@ -1196,10 +1196,12 @@ def local_mul_switch_sink(node):
if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.: if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:] listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))] fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.: if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:] listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)] fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
return False return False
...@@ -1223,9 +1225,11 @@ def local_div_switch_sink(node): ...@@ -1223,9 +1225,11 @@ def local_div_switch_sink(node):
switch = node.inputs[0].owner switch = node.inputs[0].owner
if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.: if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))] fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.: if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)] fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct return fct
return False return False
......
...@@ -1612,13 +1612,17 @@ class T_local_switch_sink(unittest.TestCase): ...@@ -1612,13 +1612,17 @@ class T_local_switch_sink(unittest.TestCase):
[[numpy.ones((4,)),numpy.zeros((4,)),numpy.ones((4,)),numpy.zeros((4,))]] + \ [[numpy.ones((4,)),numpy.zeros((4,)),numpy.ones((4,)),numpy.zeros((4,))]] + \
[[numpy.asarray(1.0),numpy.asarray(0.0),numpy.asarray(1.0),numpy.asarray(0.0)]] [[numpy.asarray(1.0),numpy.asarray(0.0),numpy.asarray(1.0),numpy.asarray(0.0)]]
self.mode = theano.compile.mode.get_default_mode().including('canonicalize','fast_run').excluding('gpu','fusion')
self.mode = copy.copy(self.mode)
self.mode.check_isfinite = False
def test_local_mul_switch_sink(self): def test_local_mul_switch_sink(self):
c = T.dscalar() c = T.dscalar()
idx = 0 idx = 0
for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]: for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]:
for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]: for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]:
y = T.mul(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0])) y = T.mul(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0]))
f = theano.function([condition[0],x[0],c],[y], mode='FAST_RUN') f = theano.function([condition[0],x[0],c],[y], mode=self.mode)
if type(condition[1]) is list: if type(condition[1]) is list:
for i in range(len(condition[1])): for i in range(len(condition[1])):
res= f(condition[1][i],x[1],-1) res= f(condition[1][i],x[1],-1)
...@@ -1634,7 +1638,7 @@ class T_local_switch_sink(unittest.TestCase): ...@@ -1634,7 +1638,7 @@ class T_local_switch_sink(unittest.TestCase):
for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]: for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]:
for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]: for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]:
y = T.true_div(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0])) y = T.true_div(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0]))
f = theano.function([condition[0],x[0],c],[y], mode='FAST_RUN') f = theano.function([condition[0],x[0],c],[y], mode=self.mode)
if type(condition[1]) is list: if type(condition[1]) is list:
for i in range(len(condition[1])): for i in range(len(condition[1])):
res= f(condition[1][i],x[1],-1) res= f(condition[1][i],x[1],-1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论