提交 5f42dd99 authored 作者: Frederic's avatar Frederic

Remove buildbot error in DebugMode. Mark opt as remove inf.

上级 76b2c090
...@@ -1407,9 +1407,11 @@ class PatternSub(LocalOptimizer): ...@@ -1407,9 +1407,11 @@ class PatternSub(LocalOptimizer):
def __init__(self, in_pattern, out_pattern, def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False, skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None): tracks=(), get_nodes=None,
values_eq_approx=None):
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
self.values_eq_approx = values_eq_approx
if isinstance(in_pattern, (list, tuple)): if isinstance(in_pattern, (list, tuple)):
self.op = self.in_pattern[0] self.op = self.in_pattern[0]
elif isinstance(in_pattern, dict): elif isinstance(in_pattern, dict):
...@@ -1451,6 +1453,8 @@ class PatternSub(LocalOptimizer): ...@@ -1451,6 +1453,8 @@ class PatternSub(LocalOptimizer):
ret = self.transform(real_node, get_nodes=False) ret = self.transform(real_node, get_nodes=False)
if ret is not False and ret is not None: if ret is not False and ret is not None:
assert len(real_node.outputs) == len(ret) assert len(real_node.outputs) == len(ret)
if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return dict(izip(real_node.outputs, ret)) return dict(izip(real_node.outputs, ret))
if node.op != self.op: if node.op != self.op:
...@@ -1534,8 +1538,10 @@ class PatternSub(LocalOptimizer): ...@@ -1534,8 +1538,10 @@ class PatternSub(LocalOptimizer):
else: else:
return pattern.clone() return pattern.clone()
p = self.out_pattern p = self.out_pattern
new = build(p, u) ret = build(p, u)
return [new] if self.values_eq_approx:
ret.tag.values_eq_approx = self.values_eq_approx
return [ret]
else: else:
return False return False
......
...@@ -18,6 +18,7 @@ from theano.configparser import AddConfigVar, BoolParam ...@@ -18,6 +18,7 @@ from theano.configparser import AddConfigVar, BoolParam
from theano.printing import pprint from theano.printing import pprint
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.tensor import elemwise, opt, NotScalarConstantError from theano.tensor import elemwise, opt, NotScalarConstantError
from theano.tensor.type import values_eq_approx_remove_inf
############ ############
...@@ -314,6 +315,9 @@ theano.compile.optdb['uncanonicalize'].register("local_hard_sigmoid", ...@@ -314,6 +315,9 @@ theano.compile.optdb['uncanonicalize'].register("local_hard_sigmoid",
class ScalarSoftplus(scalar.UnaryScalarOp): class ScalarSoftplus(scalar.UnaryScalarOp):
"""
This help numerical stability.
"""
@staticmethod @staticmethod
def static_impl(x): def static_impl(x):
if x < -30.0: if x < -30.0:
...@@ -378,6 +382,7 @@ logsigm_to_softplus = gof.PatternSub( ...@@ -378,6 +382,7 @@ logsigm_to_softplus = gof.PatternSub(
(tensor.log, (sigmoid, 'x')), (tensor.log, (sigmoid, 'x')),
(tensor.neg, (softplus, (tensor.neg, 'x'))), (tensor.neg, (softplus, (tensor.neg, 'x'))),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1) skip_identities_fn=_skip_mul_1)
...@@ -403,12 +408,14 @@ log1msigm_to_softplus = gof.PatternSub( ...@@ -403,12 +408,14 @@ log1msigm_to_softplus = gof.PatternSub(
(sigmoid, 'x'))), (sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')), (tensor.neg, (softplus, 'x')),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1) skip_identities_fn=_skip_mul_1)
log1pexp_to_softplus = gof.PatternSub( log1pexp_to_softplus = gof.PatternSub(
(tensor.log1p, (tensor.log1p,
(tensor.exp, 'x')), (tensor.exp, 'x')),
(softplus, 'x'), (softplus, 'x'),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True) allow_multiple_clients=True)
opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus') opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
......
...@@ -4228,7 +4228,9 @@ def test_constant_get_stabilized(): ...@@ -4228,7 +4228,9 @@ def test_constant_get_stabilized():
""" """
x2 = T.scalar() x2 = T.scalar()
y2 = T.log(1 + T.exp(x2)) y2 = T.log(1 + T.exp(x2))
f2 = theano.function([x2], y2) mode = theano.compile.get_default_mode()
mode.check_isfinite = False
f2 = theano.function([x2], y2, mode=mode)
try: try:
assert len(f2.maker.fgraph.toposort()) == 1 assert len(f2.maker.fgraph.toposort()) == 1
assert f2.maker.fgraph.toposort()[0].op == \ assert f2.maker.fgraph.toposort()[0].op == \
...@@ -4237,14 +4239,14 @@ def test_constant_get_stabilized(): ...@@ -4237,14 +4239,14 @@ def test_constant_get_stabilized():
x = T.as_tensor_variable(800) x = T.as_tensor_variable(800)
y = T.log(1 + T.exp(x)) y = T.log(1 + T.exp(x))
f = theano.function([], y) f = theano.function([], y, mode=mode)
assert len(f.maker.fgraph.toposort()) == 0 assert len(f.maker.fgraph.toposort()) == 0
assert numpy.isinf(f()) assert numpy.isinf(f())
# When this error is fixed, the following line should be ok. # When this error is fixed, the following line should be ok.
assert f() == 800, f() assert f() == 800, f()
except (AssertionError, theano.compile.debugmode.InvalidValueError): except AssertionError:
raise SkipTest('Theano optimizes constant before stabilization. ' raise SkipTest('Theano optimizes constant before stabilization. '
'This breaks stabilization optimization in some ' 'This breaks stabilization optimization in some '
'cases. See #504.') 'cases. See #504.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论