提交 744c86af authored 作者: Frederic Bastien's avatar Frederic Bastien

BUG FIX: the stabilization optimization of the grad of log(erfc(x)) was applied…

BUG FIX: the stabilization optimization of the grad of log(erfc(x)) was applied too frequently! Add a test for that bug.
上级 60f286bc
...@@ -2611,6 +2611,7 @@ def local_grad_log_erfc_neg(node): ...@@ -2611,6 +2611,7 @@ def local_grad_log_erfc_neg(node):
if not node.inputs[1].owner or node.inputs[1].owner.op != T.erfc: if not node.inputs[1].owner or node.inputs[1].owner.op != T.erfc:
return False return False
erfc = node.inputs[1] erfc = node.inputs[1]
erfc_x = erfc.owner.inputs[0]
if not node.inputs[0].owner: if not node.inputs[0].owner:
return False return False
...@@ -2642,6 +2643,9 @@ def local_grad_log_erfc_neg(node): ...@@ -2642,6 +2643,9 @@ def local_grad_log_erfc_neg(node):
#We use that flag to avoid applying the optimization recursively #We use that flag to avoid applying the optimization recursively
return False return False
if erfc_x is not x:
return False
#we move the cst outside the div. #we move the cst outside the div.
true_div_no_mul = T.true_div(exp,erfc) true_div_no_mul = T.true_div(exp,erfc)
true_div_no_mul.owner.tag.local_grad_log_erfc_neg=True true_div_no_mul.owner.tag.local_grad_log_erfc_neg=True
......
...@@ -1765,6 +1765,7 @@ class T_local_erfc(unittest.TestCase): ...@@ -1765,6 +1765,7 @@ class T_local_erfc(unittest.TestCase):
val.remove(10) val.remove(10)
val = numpy.asarray(val) val = numpy.asarray(val)
x = T.vector() x = T.vector()
y = T.vector()
#there are some NaNs that will appear in the graph for the log of the negative values #there are some NaNs that will appear in the graph for the log of the negative values
mode = copy.copy(self.mode) mode = copy.copy(self.mode)
...@@ -1791,7 +1792,23 @@ class T_local_erfc(unittest.TestCase): ...@@ -1791,7 +1792,23 @@ class T_local_erfc(unittest.TestCase):
#test that we work without the mul #test that we work without the mul
f = theano.function([x],T.exp(T.neg(T.sqr(x)))/T.erfc(x), mode=mode) f = theano.function([x],T.exp(T.neg(T.sqr(x)))/T.erfc(x), mode=mode)
theano.printing.debugprint(f) #theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==21, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
assert all(numpy.isfinite(f(val)))
#test that the optimization is not applied if x!=y
f = theano.function([x,y],T.exp(T.neg(T.sqr(x)))/T.erfc(y), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==5, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
f(val,val-3)
#test that we work without the sqr and neg
f = theano.function([x],T.exp(T.mul(-1,x,x))/T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==21, len(f.maker.env.nodes) assert len(f.maker.env.nodes)==21, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes]) assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论