提交 a1abed83 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #4128 from caglar/fix_extract_constant

[ENH] faster opt by changing call to extract_constant and get_scalar_constant_value
...@@ -413,6 +413,7 @@ log1msigm_to_softplus = gof.PatternSub( ...@@ -413,6 +413,7 @@ log1msigm_to_softplus = gof.PatternSub(
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1) skip_identities_fn=_skip_mul_1)
log1pexp_to_softplus = gof.PatternSub( log1pexp_to_softplus = gof.PatternSub(
(tensor.log1p, (tensor.log1p,
(tensor.exp, 'x')), (tensor.exp, 'x')),
...@@ -420,12 +421,20 @@ log1pexp_to_softplus = gof.PatternSub( ...@@ -420,12 +421,20 @@ log1pexp_to_softplus = gof.PatternSub(
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True) allow_multiple_clients=True)
# Graph rewrite: log1p(-sigmoid(x)) -> -softplus(x).
# Input pattern: log1p applied to neg(sigmoid(x)); output pattern: neg(softplus(x)).
# The same placeholder 'x' appears in both patterns.
# Mathematically log(1 - sigmoid(x)) == -log(1 + exp(x)) == -softplus(x).
log1p_neg_sigmoid = gof.PatternSub(
(tensor.log1p,
(tensor.neg, (sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True)
opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus') opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus') opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
# Register the rewrite as a stabilization optimization. The registered name
# equals the variable name, consistent with the sibling registrations (no
# stray trailing comma inside the string literal).
opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid')
def is_1pexp(t): def is_1pexp(t, only_process_constants=True):
""" """
Returns Returns
...@@ -437,8 +446,9 @@ def is_1pexp(t): ...@@ -437,8 +446,9 @@ def is_1pexp(t):
""" """
if t.owner and t.owner.op == tensor.add: if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(t.owner.inputs) opt.scalarconsts_rest(t.owner.inputs,
# scalar_inputs are potentially dimshuffled and fill'd scalars only_process_constants=only_process_constants)
# scalar_inputs are potentially dimshuffled and filled with scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
maybe_exp = nonconsts[0] maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == tensor.exp: if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
...@@ -947,7 +957,7 @@ def local_inv_1_plus_exp(node): ...@@ -947,7 +957,7 @@ def local_inv_1_plus_exp(node):
inv_arg = node.inputs[0] inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == tensor.add: if inv_arg.owner and inv_arg.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(inv_arg.owner.inputs) opt.scalarconsts_rest(inv_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp: if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp:
......
...@@ -356,7 +356,6 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -356,7 +356,6 @@ class T_sigmoid_opts(unittest.TestCase):
f = theano.function([x], s, mode=mode) f = theano.function([x], s, mode=mode)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace') assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) > 1
assert not any([n.op == sigmoid for n in topo]) assert not any([n.op == sigmoid for n in topo])
ux_v = f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]]) ux_v = f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]])
...@@ -467,15 +466,17 @@ class T_sigmoid_utils(unittest.TestCase): ...@@ -467,15 +466,17 @@ class T_sigmoid_utils(unittest.TestCase):
try: try:
x = tensor.vector('x') x = tensor.vector('x')
exp = tensor.exp exp = tensor.exp
assert is_1pexp(1 + exp(x)) == (False, x) assert is_1pexp(1 + exp(x), False) == (False, x)
assert is_1pexp(exp(x) + 1) == (False, x) assert is_1pexp(exp(x) + 1, False) == (False, x)
for neg, exp_arg in imap(is_1pexp, [(1 + exp(-x)), (exp(-x) + 1)]): for neg, exp_arg in imap(lambda x:
is_1pexp(x, only_process_constants=False),
[(1 + exp(-x)), (exp(-x) + 1)]):
assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x) assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp(x)) is None assert is_1pexp(1 - exp(x), False) is None
assert is_1pexp(2 + exp(x)) is None assert is_1pexp(2 + exp(x), False) is None
assert is_1pexp(exp(x) + 2) is None assert is_1pexp(exp(x) + 2, False) is None
assert is_1pexp(exp(x) - 1) is None assert is_1pexp(exp(x) - 1, False) is None
assert is_1pexp(-1 + exp(x)) is None assert is_1pexp(-1 + exp(x), False) is None
assert is_1pexp(1 + 2 * exp(x)) is None assert is_1pexp(1 + 2 * exp(x), False) is None
finally: finally:
config.warn.identify_1pexp_bug = backup config.warn.identify_1pexp_bug = backup
差异被折叠。
...@@ -1635,8 +1635,8 @@ def test_log_add(): ...@@ -1635,8 +1635,8 @@ def test_log_add():
def test_local_useless_slice(): def test_local_useless_slice():
# test a simple matrix # test a simple matrix
x = tensor.matrix('x') x = tensor.matrix('x')
mode_unopt = compile.get_default_mode().excluding("local_useless_slice") mode_unopt = compile.get_default_mode().excluding("local_useless_slice", "local_mul_canonizer")
mode_opt = compile.get_default_mode().including("local_useless_slice") mode_opt = compile.get_default_mode().including("local_useless_slice").excluding("local_mul_canonizer")
# test with and without the useless slice # test with and without the useless slice
o = 2 * x[0, :] o = 2 * x[0, :]
...@@ -2124,7 +2124,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -2124,7 +2124,7 @@ class test_local_subtensor_lift(unittest.TestCase):
f1 = function([x], newx[:2, :5], mode=mode_opt) f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(check_stack_trace(f1, ops_to_check=[ self.assertTrue(check_stack_trace(f1, ops_to_check=[
Subtensor, tensor.Rebroadcast])) Subtensor, tensor.Rebroadcast]))
prog = f1.maker.fgraph.toposort() prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, tensor.Subtensor) assert isinstance(prog[0].op, tensor.Subtensor)
assert isinstance(prog[1].op, tensor.Rebroadcast) assert isinstance(prog[1].op, tensor.Rebroadcast)
...@@ -2140,7 +2140,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -2140,7 +2140,7 @@ class test_local_subtensor_lift(unittest.TestCase):
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(check_stack_trace(f2, ops_to_check=[ self.assertTrue(check_stack_trace(f2, ops_to_check=[
Subtensor, tensor.Rebroadcast])) Subtensor, tensor.Rebroadcast]))
prog = f2.maker.fgraph.toposort() prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, tensor.Subtensor) assert isinstance(prog[0].op, tensor.Subtensor)
assert isinstance(prog[1].op, tensor.Rebroadcast) assert isinstance(prog[1].op, tensor.Rebroadcast)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论