Commit a1abed83, authored by Frédéric Bastien, committed by GitHub

Merge pull request #4128 from caglar/fix_extract_constant

[ENH] faster opt by changing call to extract_constant and get_scalar_constant_value
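Background: `get_scalar_constant_value(v)` (and `extract_constant`, which wraps it) normally walks backwards through ops such as Alloc, DimShuffle, Rebroadcast and fill, trying to prove that `v` is a constant scalar. That traversal is wasted work in the many optimizers that only need values which are already literal constants, because constant folding normally collapses such chains into plain `Constant` nodes between passes. Passing `only_process_constants=True` skips the traversal and accepts only a `Constant` directly, which is what makes these call sites faster. A rough sketch of the behavioural difference (hypothetical snippet, not part of the commit; the exact set of ops traversed varies by Theano version):

    import theano.tensor as T
    from theano.tensor import get_scalar_constant_value, NotScalarConstantError

    c = T.constant(5)
    a = T.alloc(T.constant(0), 3)   # an Alloc of a constant 0

    get_scalar_constant_value(c, only_process_constants=True)   # 5: c is itself a Constant
    get_scalar_constant_value(a)                                # 0: found by walking into the Alloc
    try:
        get_scalar_constant_value(a, only_process_constants=True)
    except NotScalarConstantError:
        pass  # the cheap form refuses to walk through the Alloc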
@@ -413,6 +413,7 @@ log1msigm_to_softplus = gof.PatternSub(
    values_eq_approx=values_eq_approx_remove_inf,
    skip_identities_fn=_skip_mul_1)
log1pexp_to_softplus = gof.PatternSub(
    (tensor.log1p,
     (tensor.exp, 'x')),
@@ -420,12 +421,20 @@ log1pexp_to_softplus = gof.PatternSub(
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True)
+log1p_neg_sigmoid = gof.PatternSub(
+    (tensor.log1p,
+     (tensor.neg, (sigmoid, 'x'))),
+    (tensor.neg, (softplus, 'x')),
+    values_eq_approx=values_eq_approx_remove_inf,
+    allow_multiple_clients=True)
opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
+opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid')
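The new `log1p_neg_sigmoid` pattern rewrites log1p(-sigmoid(x)) into -softplus(x). The identity follows from 1 - sigmoid(x) = sigmoid(-x) and log(sigmoid(t)) = -softplus(-t), and it is registered as a stabilization because the left-hand side overflows to -inf once sigmoid(x) rounds to 1.0 in floating point, while -softplus(x) ~ -x stays finite. A quick NumPy check of the identity (illustrative only, not part of the commit):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def softplus(z):
        return np.log1p(np.exp(z))

    z = np.linspace(-4.0, 4.0, 9)
    assert np.allclose(np.log1p(-sigmoid(z)), -softplus(z))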
-def is_1pexp(t):
+def is_1pexp(t, only_process_constants=True):
    """
    Returns
@@ -437,8 +446,9 @@ def is_1pexp(t):
    """
    if t.owner and t.owner.op == tensor.add:
        scalars, scalar_inputs, nonconsts = \
-            opt.scalarconsts_rest(t.owner.inputs)
+            opt.scalarconsts_rest(t.owner.inputs,
+                                  only_process_constants=only_process_constants)
        # scalar_inputs are potentially dimshuffled and fill'd scalars
        if len(nonconsts) == 1:
            maybe_exp = nonconsts[0]
            if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
@@ -947,7 +957,7 @@ def local_inv_1_plus_exp(node):
    inv_arg = node.inputs[0]
    if inv_arg.owner and inv_arg.owner.op == tensor.add:
        scalars, scalar_inputs, nonconsts = \
-            opt.scalarconsts_rest(inv_arg.owner.inputs)
+            opt.scalarconsts_rest(inv_arg.owner.inputs, only_process_constants=True)
        # scalar_inputs are potentially dimshuffled and fill'd scalars
        if len(nonconsts) == 1:
            if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp:
@@ -356,7 +356,6 @@ class T_sigmoid_opts(unittest.TestCase):
        f = theano.function([x], s, mode=mode)
        assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
        topo = f.maker.fgraph.toposort()
-        assert len(topo) > 1
        assert not any([n.op == sigmoid for n in topo])
        ux_v = f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]])
@@ -467,15 +466,17 @@ class T_sigmoid_utils(unittest.TestCase):
        try:
            x = tensor.vector('x')
            exp = tensor.exp
-            assert is_1pexp(1 + exp(x)) == (False, x)
-            assert is_1pexp(exp(x) + 1) == (False, x)
-            for neg, exp_arg in imap(is_1pexp, [(1 + exp(-x)), (exp(-x) + 1)]):
+            assert is_1pexp(1 + exp(x), False) == (False, x)
+            assert is_1pexp(exp(x) + 1, False) == (False, x)
+            for neg, exp_arg in imap(lambda x:
+                                     is_1pexp(x, only_process_constants=False),
+                                     [(1 + exp(-x)), (exp(-x) + 1)]):
                assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x)
-            assert is_1pexp(1 - exp(x)) is None
-            assert is_1pexp(2 + exp(x)) is None
-            assert is_1pexp(exp(x) + 2) is None
-            assert is_1pexp(exp(x) - 1) is None
-            assert is_1pexp(-1 + exp(x)) is None
-            assert is_1pexp(1 + 2 * exp(x)) is None
+            assert is_1pexp(1 - exp(x), False) is None
+            assert is_1pexp(2 + exp(x), False) is None
+            assert is_1pexp(exp(x) + 2, False) is None
+            assert is_1pexp(exp(x) - 1, False) is None
+            assert is_1pexp(-1 + exp(x), False) is None
+            assert is_1pexp(1 + 2 * exp(x), False) is None
        finally:
            config.warn.identify_1pexp_bug = backup
@@ -126,7 +126,7 @@ def merge_broadcastables(broadcastables):
    return [all(bcast) for bcast in zip(*broadcastables)]

-def scalarconsts_rest(inputs):
+def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
    """Partition a list of variables into two kinds:
    scalar constants, and the rest."""
    consts = []
@@ -134,7 +134,8 @@ def scalarconsts_rest(inputs):
    nonconsts = []
    for i in inputs:
        try:
-            v = get_scalar_constant_value(i)
+            v = get_scalar_constant_value(i, elemwise=elemwise,
+                                          only_process_constants=only_process_constants)
            consts.append(v)
            origconsts.append(i)
        except NotScalarConstantError:
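Why the cheap check is safe inside the optimizer but not on freshly built graphs: a Python scalar added to a tensor typically enters the graph as a DimShuffle of a constant, which only the traversing form can see through; once constant folding has run, that DimShuffle has been folded into a plain Constant and the cheap check suffices. A hypothetical sketch (the exact graph shape depends on the Theano version):

    import theano.tensor as T
    from theano.tensor.opt import scalarconsts_rest

    x = T.vector('x')
    node = (1 + T.exp(x)).owner   # add(DimShuffle{x}(1), exp(x)) before any optimization

    consts, _, nonconsts = scalarconsts_rest(node.inputs)
    # traversing form: consts == [1], nonconsts == [exp(x)]

    consts, _, nonconsts = scalarconsts_rest(node.inputs, only_process_constants=True)
    # cheap form: the dimshuffled 1 is not yet a Constant, so both inputs land in nonconsts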
@@ -448,8 +449,9 @@ def register_uncanonicalize(lopt, *tags, **kwargs):
            return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs)
        return register
    else:
-        name = (kwargs and kwargs.pop('name')) or lopt.__name__
-        compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags)
+        name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
+        compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags,
+                                                 **kwargs)
        return lopt
@@ -459,8 +461,9 @@ def register_specialize_device(lopt, *tags, **kwargs):
            return register_specialize_device(inner_lopt, lopt, *tags, **kwargs)
        return register
    else:
-        name = (kwargs and kwargs.pop('name')) or lopt.__name__
-        compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags)
+        name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
+        compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags,
+                                                    **kwargs)
        return lopt
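Besides forwarding `**kwargs` on to `register`, the change from `kwargs.pop('name')` to `kwargs.pop('name', None)` fixes a latent bug in both helpers: previously, calling them with any keyword other than `name` (for example `final_opt=True`, as done for `topo_constant_folding` below) raised a KeyError, because `kwargs` was truthy but had no 'name' entry; and even when `name` was present, the remaining keywords were silently dropped. The failure mode on a plain dict:

    kwargs = {'final_opt': True}
    # old:  (kwargs and kwargs.pop('name')) or default        -> KeyError: 'name'
    # new:  (kwargs and kwargs.pop('name', None)) or default  -> falls back to the default
    name = (kwargs and kwargs.pop('name', None)) or 'default_name'
    assert name == 'default_name' and kwargs == {'final_opt': True}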
@@ -479,13 +482,13 @@ def local_0_dot_x(node):
    y = node.inputs[1]
    replace = False
    try:
-        if get_scalar_constant_value(x) == 0:
+        if get_scalar_constant_value(x, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass
    try:
-        if get_scalar_constant_value(y) == 0:
+        if get_scalar_constant_value(y, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass
@@ -1196,7 +1199,7 @@ class ShapeFeature(object):
                # But we never timed this speed optimization!
                self.lscalar_one.equals(merged_shape[i]) or
                self.lscalar_one.equals(
-                    T.extract_constant(merged_shape[i]))
+                    T.extract_constant(merged_shape[i], only_process_constants=True))
                for i in xrange(r.ndim)])
        self.shape_of[r] = tuple(merged_shape)
        for sv in self.shape_of[r]:
@@ -1893,7 +1896,7 @@ def local_subtensor_make_vector(node):
        if idx.ndim == 0:
            # if it is a constant we can do something with it
            try:
-                v = get_scalar_constant_value(idx)
+                v = get_scalar_constant_value(idx, only_process_constants=True)
                if isinstance(v, numpy.integer):
                    # Python 2.4 wants to index only with Python integers
                    v = int(v)
@@ -1998,7 +2001,7 @@ def local_useless_elemwise(node):
            len(node.inputs) == 2):
        if isinstance(node.inputs[0], T.TensorConstant):
-            const_val = T.extract_constant(node.inputs[0])
+            const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
            if not isinstance(const_val, Variable):
                if const_val == 0:
                    return zeros_like(node, 1)
@@ -2006,7 +2009,7 @@ def local_useless_elemwise(node):
                    return [node.inputs[1]]
        if isinstance(node.inputs[1], T.TensorConstant):
-            const_val = T.extract_constant(node.inputs[1])
+            const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
            if not isinstance(const_val, Variable):
                if const_val == 0:
                    return zeros_like(node, 0)
@@ -2017,7 +2020,7 @@ def local_useless_elemwise(node):
            len(node.inputs) == 2):
        if isinstance(node.inputs[0], T.TensorConstant):
-            const_val = T.extract_constant(node.inputs[0])
+            const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
            if not isinstance(const_val, Variable):
                if const_val == 0:
                    return [node.inputs[1]]
@@ -2025,7 +2028,7 @@ def local_useless_elemwise(node):
                    return ones_like(node, 1)
        if isinstance(node.inputs[1], T.TensorConstant):
-            const_val = T.extract_constant(node.inputs[1])
+            const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
            if not isinstance(const_val, Variable):
                if const_val == 0:
                    return [node.inputs[0]]
@@ -2317,7 +2320,8 @@ def local_upcast_elemwise_constant_inputs(node):
                else:
                    try:
                        # works only for scalars
-                        cval_i = get_scalar_constant_value(i, elemwise=False)
+                        cval_i = get_scalar_constant_value(i,
+                                                           only_process_constants=True)
                        if all(i.broadcastable):
                            new_inputs.append(T.shape_padleft(
                                T.cast(cval_i, output_dtype),
@@ -2372,7 +2376,8 @@ def local_useless_inc_subtensor(node):
    if node.op.set_instead_of_inc is False:
        # This is an IncSubtensor, so the init value must be zeros
        try:
-            c = get_scalar_constant_value(node.inputs[0])
+            c = get_scalar_constant_value(node.inputs[0],
+                                          only_process_constants=True)
            if c != 0:
                return
        except NotScalarConstantError:
@@ -2389,7 +2394,8 @@ def local_useless_inc_subtensor(node):
    # Put the constant inputs in the slice.
    idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
    if all(isinstance(e, slice) and e.start is None and
-           e.stop is None and (e.step is None or T.extract_constant(e.step) == -1)
+           e.stop is None and (e.step is None or T.extract_constant(e.step,
+               only_process_constants=True) == -1)
           for e in idx_cst):
        # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
        # based on run time shapes, so we must check they are the same.
@@ -2459,7 +2465,8 @@ def local_useless_slice(node):
        for s in slices[::-1]:
            # check if slice and then check slice indices
            if (isinstance(s, slice) and s.start is None and s.stop is None and
-                    (s.step is None or T.extract_constant(s.step) == 1)):
+                    (s.step is None or T.extract_constant(s.step,
+                        only_process_constants=True) == 1)):
                last_slice -= 1
            else:
                break
@@ -2515,7 +2522,8 @@ def local_useless_subtensor(node):
            if isinstance(idx.stop, (integer_types, numpy.integer)):
                length_pos_data = sys.maxsize
                try:
-                    length_pos_data = get_scalar_constant_value(length_pos)
+                    length_pos_data = get_scalar_constant_value(length_pos,
+                                                                only_process_constants=True)
                except NotScalarConstantError:
                    pass
@@ -2555,7 +2563,8 @@ def local_useless_subtensor(node):
    elif isinstance(node.op, AdvancedSubtensor1):
        # get length of the indexed tensor along the first axis
        try:
-            length = get_scalar_constant_value(shape_of[node.inputs[0]][0])
+            length = get_scalar_constant_value(shape_of[node.inputs[0]][0],
+                                               only_process_constants=True)
        except NotScalarConstantError:
            return False
@@ -2572,7 +2581,8 @@ def local_useless_subtensor(node):
            return False
        elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
            try:
-                start, stop, step = map(get_scalar_constant_value,
+                start, stop, step = map(lambda x: get_scalar_constant_value(x,
+                                            only_process_constants=True),
                                        idx.owner.inputs)
            except NotScalarConstantError:
                return False
@@ -3195,19 +3205,15 @@ def local_incsubtensor_of_zeros(node):
            not node.op.set_instead_of_inc):
        x = node.inputs[0]
        y = node.inputs[1]
-        replace = False
        try:
-            if get_scalar_constant_value(y) == 0:
-                replace = True
+            # Don't use only_process_constants=True. We need to
+            # investigate Alloc of 0s but with non constant shape.
+            if get_scalar_constant_value(y, elemwise=False) == 0:
+                # No need to copy over the stacktrace,
+                # because x should already have a stacktrace
+                return [x]
        except NotScalarConstantError:
            pass
-        if replace:
-            # No need to copy over the stacktrace,
-            # because x should already have a stacktrace
-            return [x]
-        else:
-            return False
+        return
@register_canonicalize('local_setsubtensor_of_allocs')
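The comment above marks the one family of call sites this PR deliberately keeps on the slow path: here `y` is frequently something like `T.zeros(shape)`, i.e. an Alloc of a constant 0 whose shape is symbolic. The Alloc node itself is not a Constant, so the cheap check would miss it and the optimization would silently stop firing. A hypothetical illustration:

    import theano.tensor as T
    from theano.tensor import get_scalar_constant_value, NotScalarConstantError

    n = T.iscalar('n')
    y = T.zeros((n,))   # Alloc(0, n): constant value, symbolic shape

    get_scalar_constant_value(y, elemwise=False)   # 0, found through the Alloc
    try:
        get_scalar_constant_value(y, only_process_constants=True)
    except NotScalarConstantError:
        pass  # the cheap form cannot see through the Alloc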
@@ -3223,22 +3229,20 @@ def local_setsubtensor_of_constants(node):
    if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
        x = node.inputs[0]
        y = node.inputs[1]
-        replace_x = None
-        replace_y = None
+        # Don't use only_process_constants=True. We need to
+        # investigate Alloc of 0s but with non constant shape.
        try:
-            replace_x = get_scalar_constant_value(x)
+            replace_x = get_scalar_constant_value(x, elemwise=False)
        except NotScalarConstantError:
-            pass
+            return
        try:
-            replace_y = get_scalar_constant_value(y)
+            replace_y = get_scalar_constant_value(y, elemwise=False)
        except NotScalarConstantError:
-            pass
+            return
-        if (replace_x is not None and
-                replace_y is not None and
-                replace_x == replace_y):
+        if replace_x == replace_y:
            # No need to copy over the stacktrace,
            # because x should already have a stacktrace
@@ -3276,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node):
    if idx is not idx2:
        return
    if (not inp.owner.op.set_instead_of_inc and
-            T.extract_constant(x) != 0):
+            # Don't use only_process_constants=True. We need to
+            # investigate Alloc of 0s but with non constant shape.
+            T.extract_constant(x, elemwise=False) != 0):
        return
    cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
    if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
@@ -3568,7 +3574,8 @@ def local_join_empty(node):
        return
    new_inputs = []
    try:
-        join_idx = get_scalar_constant_value(node.inputs[0])
+        join_idx = get_scalar_constant_value(node.inputs[0],
+                                             only_process_constants=True)
    except NotScalarConstantError:
        return
    for idx in xrange(1, len(node.inputs)):
@@ -3727,8 +3734,10 @@ def local_useless_switch(node):
    """
    if (isinstance(node.op, T.Elemwise) and
            isinstance(node.op.scalar_op, scalar.basic.Switch)):
-        cond = T.extract_constant(node.inputs[0], elemwise=False)
-        if type(cond) is numpy.ndarray and cond.ndim == 0:
+        cond = T.extract_constant(node.inputs[0],
+                                  only_process_constants=True)
+        if ((type(cond) is numpy.ndarray and cond.ndim == 0) or
+                isinstance(cond, numpy.number)):
            if cond == 0:
                correct_out = node.inputs[2]
            else:
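The broadened condition is presumably needed because, with only_process_constants=True, the extracted value can come back as a NumPy scalar (a numpy.number) taken directly from the constant, where the old path produced a 0-d ndarray.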
@@ -3775,8 +3784,8 @@ def local_useless_switch(node):
            isinstance(cond_var.owner.op.scalar_op, scalar.LE) and \
            cond_var.owner.inputs[0].owner and \
            isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and \
-            T.extract_constant(cond_var.owner.inputs[1]) == 0 and \
-            T.extract_constant(left) == 0 and \
+            T.extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 and \
+            T.extract_constant(left, only_process_constants=True) == 0 and \
            right is cond_var.owner.inputs[0]:
        assert right.type == node.outputs[0].type
        # No need to copy over stacktrace, because the right input node
@@ -3889,7 +3898,8 @@ def local_div_switch_sink(node):
        if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
            switch = node.inputs[0].owner
            try:
-                if get_scalar_constant_value(switch.inputs[1]) == 0.:
+                if get_scalar_constant_value(switch.inputs[1],
+                                             only_process_constants=True) == 0.:
                    fdiv = op(switch.inputs[2], node.inputs[1])
                    # Copy over stacktrace for elementwise division op
                    # from previous elementwise multiplication op.
@@ -3911,7 +3921,8 @@ def local_div_switch_sink(node):
            except NotScalarConstantError:
                pass
            try:
-                if get_scalar_constant_value(switch.inputs[2]) == 0.:
+                if get_scalar_constant_value(switch.inputs[2],
+                                             only_process_constants=True) == 0.:
                    fdiv = op(switch.inputs[1], node.inputs[1])
                    # Copy over stacktrace for elementwise division op
                    # from previous elementwise multiplication op.
@@ -3976,7 +3987,8 @@ def local_useless_tile(node):
    """
    if isinstance(node.op, T.Tile):
        try:
-            a = T.get_scalar_constant_value(node.inputs[1])
+            a = T.get_scalar_constant_value(node.inputs[1],
+                                            only_process_constants=True)
            if a == 1:
                try:
                    l = T.get_vector_length(node.inputs[1])
@@ -4159,7 +4171,8 @@ if 0:
        def tmp(thing):
            try:
-                return T.get_scalar_constant_value(thing)
+                return T.get_scalar_constant_value(thing,
+                                                   only_process_constants=True)
            except (TypeError, ValueError) as e:
                print(e, thing.owner.inputs[0])
                return None
@@ -5156,7 +5169,7 @@ def local_reduce_join(node):
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, T.Join)):
        join = node.inputs[0].owner
-        if T.extract_constant(join.inputs[0]) != 0:
+        if T.extract_constant(join.inputs[0], only_process_constants=True) != 0:
            return
        if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
@@ -5206,7 +5219,9 @@ def local_reduce_join(node):
        # We add the new check late so as not to emit an extra warning.
        try:
-            join_axis = get_scalar_constant_value(join.inputs[0])
+            join_axis = get_scalar_constant_value(join.inputs[0],
+                                                  only_process_constants=True)
            if join_axis != reduce_axis[0]:
                return
        except NotScalarConstantError:
@@ -5288,7 +5303,8 @@ def local_opt_alloc(node):
        if (node.op.axis is None or
                node.op.axis == tuple(range(input.ndim))):
            try:
-                val = get_scalar_constant_value(input)
+                val = get_scalar_constant_value(input,
+                                                only_process_constants=True)
                assert val.size == 1
                # check which type of op
                casted = T.mul(*shapes).astype(str(input.dtype))
@@ -5302,7 +5318,8 @@ def local_opt_alloc(node):
                pass
        else:
            try:
-                val = get_scalar_constant_value(input)
+                val = get_scalar_constant_value(input,
+                                                only_process_constants=True)
                assert val.size == 1
                val = val.reshape(1)[0]
                to_prod = [shapes[i] for i in xrange(len(shapes))
@@ -5746,7 +5763,8 @@ def local_abs_merge(node):
            inputs.append(i.owner.inputs[0])
        elif isinstance(i, Constant):
            try:
-                const = get_scalar_constant_value(i)
+                const = get_scalar_constant_value(i,
+                                                  only_process_constants=True)
            except NotScalarConstantError:
                return False
            if not (const >= 0).all():
@@ -5766,12 +5784,12 @@ def local_abs_merge(node):
@gof.local_optimizer([T.log])
def local_log1p(node):
    # log(1+x) -> log1p(x)
+    # log(1-x) -> log1p(-x)
    if node.op == T.log:
        log_arg, = node.inputs
        if log_arg.owner and log_arg.owner.op == T.add:
            scalars, scalar_inputs, nonconsts = scalarconsts_rest(
-                log_arg.owner.inputs)
+                log_arg.owner.inputs, only_process_constants=True)
            # scalar_inputs are potentially dimshuffled and fill'd scalars
            if scalars and numpy.allclose(numpy.sum(scalars), 1):
                if not nonconsts:
@@ -5782,6 +5800,13 @@ def local_log1p(node):
                return _fill_chain(T.log1p(T.add(*nonconsts)),
                                   scalar_inputs)
+        elif log_arg.owner and log_arg.owner.op == T.sub:
+            one = T.extract_constant(log_arg.owner.inputs[0],
+                                     only_process_constants=True)
+            if one != 1:
+                return
+            return [T.log1p(T.neg(log_arg.owner.inputs[1]))]
# TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize
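The new `sub` branch extends the rewrite from log(1 + x) -> log1p(x) to log(1 - x) -> log1p(-x). As with the `add` case this is a stability optimization: for |x| much smaller than 1, the subtraction 1 - x rounds in floating point before the log is taken, while log1p avoids forming 1 - x at all. An illustrative NumPy comparison (not part of the commit):

    import numpy as np

    x = 1e-12
    a = np.log(1 - x)    # 1 - x is rounded in float64 first; digits are lost
    b = np.log1p(-x)     # accurate to machine precision, ~ -1e-12
    # a and b agree to only about 4-5 significant digits here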
@@ -6017,7 +6042,6 @@ def constant_folding(node):
    required = thunk()
    assert not required  # a node whose inputs are all provided should always
    # return successfully
-    rval = []
    for output in node.outputs:
        assert compute_map[output][0], (output, storage_map[output][0])
@@ -6036,6 +6060,7 @@ def constant_folding(node):
topo_constant_folding = in2out(constant_folding, ignore_newtrees=True,
                               name="topo_constant_folding")
register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
+register_uncanonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
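Registering `topo_constant_folding` under 'uncanonicalize' as well (the call that relies on the `final_opt=True` keyword forwarding fixed above) helps keep the cheap `only_process_constants=True` checks effective throughout compilation: constant folding collapses chains such as DimShuffle(constant) or fill of constants into plain `Constant` nodes, so the optimizers changed in this PR can recognize them without walking the graph.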
@@ -6328,7 +6353,8 @@ def local_grad_log_erfc_neg(node):
        mul_neg = T.mul(*mul_inputs)
        try:
-            cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0])
+            cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0],
+                                             only_process_constants=True)
        except NotScalarConstantError:
            return False
@@ -6355,7 +6381,8 @@ def local_grad_log_erfc_neg(node):
            x = erfc_x
            try:
-                cst = get_scalar_constant_value(erfc_x.owner.inputs[0])
+                cst = get_scalar_constant_value(erfc_x.owner.inputs[0],
+                                                only_process_constants=True)
            except NotScalarConstantError:
                return False
            if cst2 != -cst * 2:
@@ -1635,8 +1635,8 @@ def test_log_add():
def test_local_useless_slice():
    # test a simple matrix
    x = tensor.matrix('x')
-    mode_unopt = compile.get_default_mode().excluding("local_useless_slice")
-    mode_opt = compile.get_default_mode().including("local_useless_slice")
+    mode_unopt = compile.get_default_mode().excluding("local_useless_slice", "local_mul_canonizer")
+    mode_opt = compile.get_default_mode().including("local_useless_slice").excluding("local_mul_canonizer")
    # test with and without the useless slice
    o = 2 * x[0, :]
@@ -2124,7 +2124,7 @@ class test_local_subtensor_lift(unittest.TestCase):
        f1 = function([x], newx[:2, :5], mode=mode_opt)
        # Check stacktrace was copied over correctly after opt was applied
        self.assertTrue(check_stack_trace(f1, ops_to_check=[
            Subtensor, tensor.Rebroadcast]))  # (whitespace-only change to this line)
        prog = f1.maker.fgraph.toposort()
        assert isinstance(prog[0].op, tensor.Subtensor)
        assert isinstance(prog[1].op, tensor.Rebroadcast)
@@ -2140,7 +2140,7 @@ class test_local_subtensor_lift(unittest.TestCase):
        f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
        # Check stacktrace was copied over correctly after opt was applied
        self.assertTrue(check_stack_trace(f2, ops_to_check=[
            Subtensor, tensor.Rebroadcast]))  # (whitespace-only change to this line)
        prog = f2.maker.fgraph.toposort()
        assert isinstance(prog[0].op, tensor.Subtensor)
        assert isinstance(prog[1].op, tensor.Rebroadcast)