提交 5f42dd99 authored 作者: Frederic's avatar Frederic

Remove buildbot error in DebugMode. Mark opt as remove inf.

上级 76b2c090
......@@ -1407,9 +1407,11 @@ class PatternSub(LocalOptimizer):
def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
tracks=(), get_nodes=None,
values_eq_approx=None):
self.in_pattern = in_pattern
self.out_pattern = out_pattern
self.values_eq_approx = values_eq_approx
if isinstance(in_pattern, (list, tuple)):
self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict):
......@@ -1451,6 +1453,8 @@ class PatternSub(LocalOptimizer):
ret = self.transform(real_node, get_nodes=False)
if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret)
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return dict(izip(real_node.outputs, ret))
if node.op != self.op:
......@@ -1534,8 +1538,10 @@ class PatternSub(LocalOptimizer):
else:
return pattern.clone()
p = self.out_pattern
new = build(p, u)
return [new]
ret = build(p, u)
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return [ret]
else:
return False
......
......@@ -18,6 +18,7 @@ from theano.configparser import AddConfigVar, BoolParam
from theano.printing import pprint
from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt, NotScalarConstantError
from theano.tensor.type import values_eq_approx_remove_inf
############
......@@ -314,6 +315,9 @@ theano.compile.optdb['uncanonicalize'].register("local_hard_sigmoid",
class ScalarSoftplus(scalar.UnaryScalarOp):
"""
This help numerical stability.
"""
@staticmethod
def static_impl(x):
if x < -30.0:
......@@ -378,6 +382,7 @@ logsigm_to_softplus = gof.PatternSub(
(tensor.log, (sigmoid, 'x')),
(tensor.neg, (softplus, (tensor.neg, 'x'))),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1)
......@@ -403,12 +408,14 @@ log1msigm_to_softplus = gof.PatternSub(
(sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1)
log1pexp_to_softplus = gof.PatternSub(
(tensor.log1p,
(tensor.exp, 'x')),
(softplus, 'x'),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True)
opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
......
......@@ -4228,7 +4228,9 @@ def test_constant_get_stabilized():
"""
x2 = T.scalar()
y2 = T.log(1 + T.exp(x2))
f2 = theano.function([x2], y2)
mode = theano.compile.get_default_mode()
mode.check_isfinite = False
f2 = theano.function([x2], y2, mode=mode)
try:
assert len(f2.maker.fgraph.toposort()) == 1
assert f2.maker.fgraph.toposort()[0].op == \
......@@ -4237,14 +4239,14 @@ def test_constant_get_stabilized():
x = T.as_tensor_variable(800)
y = T.log(1 + T.exp(x))
f = theano.function([], y)
f = theano.function([], y, mode=mode)
assert len(f.maker.fgraph.toposort()) == 0
assert numpy.isinf(f())
# When this error is fixed, the following line should be ok.
assert f() == 800, f()
except (AssertionError, theano.compile.debugmode.InvalidValueError):
except AssertionError:
raise SkipTest('Theano optimizes constant before stabilization. '
'This breaks stabilization optimization in some '
'cases. See #504.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论