提交 48248330 authored 作者: Frederic Bastien's avatar Frederic Bastien

use the new DebugMode machanism to disable error when an stability optimization…

use the new DebugMode machanism to disable error when an stability optimization remove nan for the local_{mul,div}_switch_sink optimization. Updated test to use it.
上级 4d541b6a
......@@ -1196,10 +1196,12 @@ def local_mul_switch_sink(node):
if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],0,T.mul(*(listmul + [switch.inputs[2]])))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.:
listmul = node.inputs[:idx] + node.inputs[idx+1:]
fct = [T.switch(switch.inputs[0],T.mul(*(listmul + [switch.inputs[1]])),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
return False
......@@ -1223,9 +1225,11 @@ def local_div_switch_sink(node):
switch = node.inputs[0].owner
if isinstance(switch.inputs[1],Constant) and get_constant_value(switch.inputs[1]) == 0.:
fct = [T.switch(switch.inputs[0],0,op(switch.inputs[2],node.inputs[1]))]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
if isinstance(switch.inputs[2],Constant) and get_constant_value(switch.inputs[2]) == 0.:
fct = [T.switch(switch.inputs[0],op(switch.inputs[1],node.inputs[1]),0)]
fct[0].values_eq_approx = fct[0].type.values_eq_approx_remove_nan
return fct
return False
......
......@@ -1612,13 +1612,17 @@ class T_local_switch_sink(unittest.TestCase):
[[numpy.ones((4,)),numpy.zeros((4,)),numpy.ones((4,)),numpy.zeros((4,))]] + \
[[numpy.asarray(1.0),numpy.asarray(0.0),numpy.asarray(1.0),numpy.asarray(0.0)]]
self.mode = theano.compile.mode.get_default_mode().including('canonicalize','fast_run').excluding('gpu','fusion')
self.mode = copy.copy(self.mode)
self.mode.check_isfinite = False
def test_local_mul_switch_sink(self):
c = T.dscalar()
idx = 0
for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]:
for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]:
y = T.mul(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0]))
f = theano.function([condition[0],x[0],c],[y], mode='FAST_RUN')
f = theano.function([condition[0],x[0],c],[y], mode=self.mode)
if type(condition[1]) is list:
for i in range(len(condition[1])):
res= f(condition[1][i],x[1],-1)
......@@ -1634,7 +1638,7 @@ class T_local_switch_sink(unittest.TestCase):
for condition in [(T.dmatrix('cond'),self.condm),(T.dvector('cond'),self.condv),(T.dscalar('cond'),self.conds)]:
for x in [(T.dmatrix('x'),self.xm),(T.dvector('x'),self.xv),(T.dscalar('x'),self.xs)]:
y = T.true_div(T.switch(condition[0]>0,1.*x[0],0.*x[0]),T.switch(condition[0]>0,1.*x[0],T.log(c)*x[0]))
f = theano.function([condition[0],x[0],c],[y], mode='FAST_RUN')
f = theano.function([condition[0],x[0],c],[y], mode=self.mode)
if type(condition[1]) is list:
for i in range(len(condition[1])):
res= f(condition[1][i],x[1],-1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论