提交 a1abed83 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #4128 from caglar/fix_extract_constant

[ENH] faster opt by changing call to extract_constant and get_scalar_constant_value
...@@ -413,6 +413,7 @@ log1msigm_to_softplus = gof.PatternSub( ...@@ -413,6 +413,7 @@ log1msigm_to_softplus = gof.PatternSub(
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1) skip_identities_fn=_skip_mul_1)
log1pexp_to_softplus = gof.PatternSub( log1pexp_to_softplus = gof.PatternSub(
(tensor.log1p, (tensor.log1p,
(tensor.exp, 'x')), (tensor.exp, 'x')),
...@@ -420,12 +421,20 @@ log1pexp_to_softplus = gof.PatternSub( ...@@ -420,12 +421,20 @@ log1pexp_to_softplus = gof.PatternSub(
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True) allow_multiple_clients=True)
log1p_neg_sigmoid = gof.PatternSub(
(tensor.log1p,
(tensor.neg, (sigmoid, 'x'))),
(tensor.neg, (softplus, 'x')),
values_eq_approx=values_eq_approx_remove_inf,
allow_multiple_clients=True)
opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus') opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus') opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid,')
def is_1pexp(t): def is_1pexp(t, only_process_constants=True):
""" """
Returns Returns
...@@ -437,8 +446,9 @@ def is_1pexp(t): ...@@ -437,8 +446,9 @@ def is_1pexp(t):
""" """
if t.owner and t.owner.op == tensor.add: if t.owner and t.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(t.owner.inputs) opt.scalarconsts_rest(t.owner.inputs,
# scalar_inputs are potentially dimshuffled and fill'd scalars only_process_constants=only_process_constants)
# scalar_inputs are potentially dimshuffled and filled with scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
maybe_exp = nonconsts[0] maybe_exp = nonconsts[0]
if maybe_exp.owner and maybe_exp.owner.op == tensor.exp: if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
...@@ -947,7 +957,7 @@ def local_inv_1_plus_exp(node): ...@@ -947,7 +957,7 @@ def local_inv_1_plus_exp(node):
inv_arg = node.inputs[0] inv_arg = node.inputs[0]
if inv_arg.owner and inv_arg.owner.op == tensor.add: if inv_arg.owner and inv_arg.owner.op == tensor.add:
scalars, scalar_inputs, nonconsts = \ scalars, scalar_inputs, nonconsts = \
opt.scalarconsts_rest(inv_arg.owner.inputs) opt.scalarconsts_rest(inv_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if len(nonconsts) == 1: if len(nonconsts) == 1:
if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp: if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp:
......
...@@ -356,7 +356,6 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -356,7 +356,6 @@ class T_sigmoid_opts(unittest.TestCase):
f = theano.function([x], s, mode=mode) f = theano.function([x], s, mode=mode)
assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace') assert hasattr(f.maker.fgraph.outputs[0].tag, 'trace')
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert len(topo) > 1
assert not any([n.op == sigmoid for n in topo]) assert not any([n.op == sigmoid for n in topo])
ux_v = f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]]) ux_v = f([[-50, -10, -4, -1, 0, 1, 4, 10, 50]])
...@@ -467,15 +466,17 @@ class T_sigmoid_utils(unittest.TestCase): ...@@ -467,15 +466,17 @@ class T_sigmoid_utils(unittest.TestCase):
try: try:
x = tensor.vector('x') x = tensor.vector('x')
exp = tensor.exp exp = tensor.exp
assert is_1pexp(1 + exp(x)) == (False, x) assert is_1pexp(1 + exp(x), False) == (False, x)
assert is_1pexp(exp(x) + 1) == (False, x) assert is_1pexp(exp(x) + 1, False) == (False, x)
for neg, exp_arg in imap(is_1pexp, [(1 + exp(-x)), (exp(-x) + 1)]): for neg, exp_arg in imap(lambda x:
is_1pexp(x, only_process_constants=False),
[(1 + exp(-x)), (exp(-x) + 1)]):
assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x) assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp(x)) is None assert is_1pexp(1 - exp(x), False) is None
assert is_1pexp(2 + exp(x)) is None assert is_1pexp(2 + exp(x), False) is None
assert is_1pexp(exp(x) + 2) is None assert is_1pexp(exp(x) + 2, False) is None
assert is_1pexp(exp(x) - 1) is None assert is_1pexp(exp(x) - 1, False) is None
assert is_1pexp(-1 + exp(x)) is None assert is_1pexp(-1 + exp(x), False) is None
assert is_1pexp(1 + 2 * exp(x)) is None assert is_1pexp(1 + 2 * exp(x), False) is None
finally: finally:
config.warn.identify_1pexp_bug = backup config.warn.identify_1pexp_bug = backup
...@@ -126,7 +126,7 @@ def merge_broadcastables(broadcastables): ...@@ -126,7 +126,7 @@ def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)] return [all(bcast) for bcast in zip(*broadcastables)]
def scalarconsts_rest(inputs): def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
"""Partition a list of variables into two kinds: """Partition a list of variables into two kinds:
scalar constants, and the rest.""" scalar constants, and the rest."""
consts = [] consts = []
...@@ -134,7 +134,8 @@ def scalarconsts_rest(inputs): ...@@ -134,7 +134,8 @@ def scalarconsts_rest(inputs):
nonconsts = [] nonconsts = []
for i in inputs: for i in inputs:
try: try:
v = get_scalar_constant_value(i) v = get_scalar_constant_value(i, elemwise=elemwise,
only_process_constants=only_process_constants)
consts.append(v) consts.append(v)
origconsts.append(i) origconsts.append(i)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -448,8 +449,9 @@ def register_uncanonicalize(lopt, *tags, **kwargs): ...@@ -448,8 +449,9 @@ def register_uncanonicalize(lopt, *tags, **kwargs):
return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs) return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags) compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags,
**kwargs)
return lopt return lopt
...@@ -459,8 +461,9 @@ def register_specialize_device(lopt, *tags, **kwargs): ...@@ -459,8 +461,9 @@ def register_specialize_device(lopt, *tags, **kwargs):
return register_specialize_device(inner_lopt, lopt, *tags, **kwargs) return register_specialize_device(inner_lopt, lopt, *tags, **kwargs)
return register return register
else: else:
name = (kwargs and kwargs.pop('name')) or lopt.__name__ name = (kwargs and kwargs.pop('name', None)) or lopt.__name__
compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags) compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags,
**kwargs)
return lopt return lopt
...@@ -479,13 +482,13 @@ def local_0_dot_x(node): ...@@ -479,13 +482,13 @@ def local_0_dot_x(node):
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
try: try:
if get_scalar_constant_value(x) == 0: if get_scalar_constant_value(x, only_process_constants=True) == 0:
replace = True replace = True
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if get_scalar_constant_value(y) == 0: if get_scalar_constant_value(y, only_process_constants=True) == 0:
replace = True replace = True
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -1196,7 +1199,7 @@ class ShapeFeature(object): ...@@ -1196,7 +1199,7 @@ class ShapeFeature(object):
# But we never timed this speed optimization! # But we never timed this speed optimization!
self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i]) or
self.lscalar_one.equals( self.lscalar_one.equals(
T.extract_constant(merged_shape[i])) T.extract_constant(merged_shape[i], only_process_constants=True))
for i in xrange(r.ndim)]) for i in xrange(r.ndim)])
self.shape_of[r] = tuple(merged_shape) self.shape_of[r] = tuple(merged_shape)
for sv in self.shape_of[r]: for sv in self.shape_of[r]:
...@@ -1893,7 +1896,7 @@ def local_subtensor_make_vector(node): ...@@ -1893,7 +1896,7 @@ def local_subtensor_make_vector(node):
if idx.ndim == 0: if idx.ndim == 0:
# if it is a constant we can do something with it # if it is a constant we can do something with it
try: try:
v = get_scalar_constant_value(idx) v = get_scalar_constant_value(idx, only_process_constants=True)
if isinstance(v, numpy.integer): if isinstance(v, numpy.integer):
# Python 2.4 wants to index only with Python integers # Python 2.4 wants to index only with Python integers
v = int(v) v = int(v)
...@@ -1998,7 +2001,7 @@ def local_useless_elemwise(node): ...@@ -1998,7 +2001,7 @@ def local_useless_elemwise(node):
len(node.inputs) == 2): len(node.inputs) == 2):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 1) return zeros_like(node, 1)
...@@ -2006,7 +2009,7 @@ def local_useless_elemwise(node): ...@@ -2006,7 +2009,7 @@ def local_useless_elemwise(node):
return [node.inputs[1]] return [node.inputs[1]]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return zeros_like(node, 0) return zeros_like(node, 0)
...@@ -2017,7 +2020,7 @@ def local_useless_elemwise(node): ...@@ -2017,7 +2020,7 @@ def local_useless_elemwise(node):
len(node.inputs) == 2): len(node.inputs) == 2):
if isinstance(node.inputs[0], T.TensorConstant): if isinstance(node.inputs[0], T.TensorConstant):
const_val = T.extract_constant(node.inputs[0]) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1]]
...@@ -2025,7 +2028,7 @@ def local_useless_elemwise(node): ...@@ -2025,7 +2028,7 @@ def local_useless_elemwise(node):
return ones_like(node, 1) return ones_like(node, 1)
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1]) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0]]
...@@ -2317,7 +2320,8 @@ def local_upcast_elemwise_constant_inputs(node): ...@@ -2317,7 +2320,8 @@ def local_upcast_elemwise_constant_inputs(node):
else: else:
try: try:
# works only for scalars # works only for scalars
cval_i = get_scalar_constant_value(i, elemwise=False) cval_i = get_scalar_constant_value(i,
only_process_constants=True)
if all(i.broadcastable): if all(i.broadcastable):
new_inputs.append(T.shape_padleft( new_inputs.append(T.shape_padleft(
T.cast(cval_i, output_dtype), T.cast(cval_i, output_dtype),
...@@ -2372,7 +2376,8 @@ def local_useless_inc_subtensor(node): ...@@ -2372,7 +2376,8 @@ def local_useless_inc_subtensor(node):
if node.op.set_instead_of_inc is False: if node.op.set_instead_of_inc is False:
# This is an IncSubtensor, so the init value must be zeros # This is an IncSubtensor, so the init value must be zeros
try: try:
c = get_scalar_constant_value(node.inputs[0]) c = get_scalar_constant_value(node.inputs[0],
only_process_constants=True)
if c != 0: if c != 0:
return return
except NotScalarConstantError: except NotScalarConstantError:
...@@ -2389,7 +2394,8 @@ def local_useless_inc_subtensor(node): ...@@ -2389,7 +2394,8 @@ def local_useless_inc_subtensor(node):
# Put the constant inputs in the slice. # Put the constant inputs in the slice.
idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list) idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list)
if all(isinstance(e, slice) and e.start is None and if all(isinstance(e, slice) and e.start is None and
e.stop is None and (e.step is None or T.extract_constant(e.step) == -1) e.stop is None and (e.step is None or T.extract_constant(e.step,
only_process_constants=True) == -1)
for e in idx_cst): for e in idx_cst):
# IncSubtensor broadcast node.inputs[1] on node.inputs[0] # IncSubtensor broadcast node.inputs[1] on node.inputs[0]
# based on run time shapes, so we must check they are the same. # based on run time shapes, so we must check they are the same.
...@@ -2459,7 +2465,8 @@ def local_useless_slice(node): ...@@ -2459,7 +2465,8 @@ def local_useless_slice(node):
for s in slices[::-1]: for s in slices[::-1]:
# check if slice and then check slice indices # check if slice and then check slice indices
if (isinstance(s, slice) and s.start is None and s.stop is None and if (isinstance(s, slice) and s.start is None and s.stop is None and
(s.step is None or T.extract_constant(s.step) == 1)): (s.step is None or T.extract_constant(s.step,
only_process_constants=True) == 1)):
last_slice -= 1 last_slice -= 1
else: else:
break break
...@@ -2515,7 +2522,8 @@ def local_useless_subtensor(node): ...@@ -2515,7 +2522,8 @@ def local_useless_subtensor(node):
if isinstance(idx.stop, (integer_types, numpy.integer)): if isinstance(idx.stop, (integer_types, numpy.integer)):
length_pos_data = sys.maxsize length_pos_data = sys.maxsize
try: try:
length_pos_data = get_scalar_constant_value(length_pos) length_pos_data = get_scalar_constant_value(length_pos,
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2555,7 +2563,8 @@ def local_useless_subtensor(node): ...@@ -2555,7 +2563,8 @@ def local_useless_subtensor(node):
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
# get length of the indexed tensor along the first axis # get length of the indexed tensor along the first axis
try: try:
length = get_scalar_constant_value(shape_of[node.inputs[0]][0]) length = get_scalar_constant_value(shape_of[node.inputs[0]][0],
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -2572,7 +2581,8 @@ def local_useless_subtensor(node): ...@@ -2572,7 +2581,8 @@ def local_useless_subtensor(node):
return False return False
elif idx.owner is not None and isinstance(idx.owner.op, T.ARange): elif idx.owner is not None and isinstance(idx.owner.op, T.ARange):
try: try:
start, stop, step = map(get_scalar_constant_value, start, stop, step = map(lambda x: get_scalar_constant_value(x,
only_process_constants=True),
idx.owner.inputs) idx.owner.inputs)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -3195,19 +3205,15 @@ def local_incsubtensor_of_zeros(node): ...@@ -3195,19 +3205,15 @@ def local_incsubtensor_of_zeros(node):
not node.op.set_instead_of_inc): not node.op.set_instead_of_inc):
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False
try: try:
if get_scalar_constant_value(y) == 0: # Don't use only_process_constants=True. We need to
replace = True # investigate Alloc of 0s but with non constant shape.
if get_scalar_constant_value(y, elemwise=False) == 0:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
return [x]
except NotScalarConstantError: except NotScalarConstantError:
pass return
if replace:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
return [x]
else:
return False
@register_canonicalize('local_setsubtensor_of_allocs') @register_canonicalize('local_setsubtensor_of_allocs')
...@@ -3223,22 +3229,20 @@ def local_setsubtensor_of_constants(node): ...@@ -3223,22 +3229,20 @@ def local_setsubtensor_of_constants(node):
if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc:
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace_x = None
replace_y = None
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
try: try:
replace_x = get_scalar_constant_value(x) replace_x = get_scalar_constant_value(x, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
pass return
try: try:
replace_y = get_scalar_constant_value(y) replace_y = get_scalar_constant_value(y, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
pass return
if (replace_x is not None and if replace_x == replace_y:
replace_y is not None and
replace_x == replace_y):
# No need to copy over the stacktrace, # No need to copy over the stacktrace,
# because x should already have a stacktrace # because x should already have a stacktrace
...@@ -3276,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node): ...@@ -3276,7 +3280,9 @@ def local_adv_sub1_adv_inc_sub1(node):
if idx is not idx2: if idx is not idx2:
return return
if (not inp.owner.op.set_instead_of_inc and if (not inp.owner.op.set_instead_of_inc and
T.extract_constant(x) != 0): # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
T.extract_constant(x, elemwise=False) != 0):
return return
cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))] cond = [T.all(T.and_(T.lt(idx, x.shape[0]), T.ge(idx, -x.shape[0])))]
if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0): if not node.fgraph.shape_feature.same_shape(idx, y, 0, 0):
...@@ -3568,7 +3574,8 @@ def local_join_empty(node): ...@@ -3568,7 +3574,8 @@ def local_join_empty(node):
return return
new_inputs = [] new_inputs = []
try: try:
join_idx = get_scalar_constant_value(node.inputs[0]) join_idx = get_scalar_constant_value(node.inputs[0],
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return return
for idx in xrange(1, len(node.inputs)): for idx in xrange(1, len(node.inputs)):
...@@ -3727,8 +3734,10 @@ def local_useless_switch(node): ...@@ -3727,8 +3734,10 @@ def local_useless_switch(node):
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0], elemwise=False) cond = T.extract_constant(node.inputs[0],
if type(cond) is numpy.ndarray and cond.ndim == 0: only_process_constants=True)
if ((type(cond) is numpy.ndarray and cond.ndim == 0) or
isinstance(cond, numpy.number)):
if cond == 0: if cond == 0:
correct_out = node.inputs[2] correct_out = node.inputs[2]
else: else:
...@@ -3775,8 +3784,8 @@ def local_useless_switch(node): ...@@ -3775,8 +3784,8 @@ def local_useless_switch(node):
isinstance(cond_var.owner.op.scalar_op, scalar.LE) and \ isinstance(cond_var.owner.op.scalar_op, scalar.LE) and \
cond_var.owner.inputs[0].owner and \ cond_var.owner.inputs[0].owner and \
isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and \ isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and \
T.extract_constant(cond_var.owner.inputs[1]) == 0 and \ T.extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 and \
T.extract_constant(left) == 0 and \ T.extract_constant(left, only_process_constants=True) == 0 and \
right is cond_var.owner.inputs[0]: right is cond_var.owner.inputs[0]:
assert right.type == node.outputs[0].type assert right.type == node.outputs[0].type
# No need to copy over stacktrace, because the right input node # No need to copy over stacktrace, because the right input node
...@@ -3889,7 +3898,8 @@ def local_div_switch_sink(node): ...@@ -3889,7 +3898,8 @@ def local_div_switch_sink(node):
if node.inputs[0].owner and node.inputs[0].owner.op == T.switch: if node.inputs[0].owner and node.inputs[0].owner.op == T.switch:
switch = node.inputs[0].owner switch = node.inputs[0].owner
try: try:
if get_scalar_constant_value(switch.inputs[1]) == 0.: if get_scalar_constant_value(switch.inputs[1],
only_process_constants=True) == 0.:
fdiv = op(switch.inputs[2], node.inputs[1]) fdiv = op(switch.inputs[2], node.inputs[1])
# Copy over stacktrace for elementwise division op # Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op. # from previous elementwise multiplication op.
...@@ -3911,7 +3921,8 @@ def local_div_switch_sink(node): ...@@ -3911,7 +3921,8 @@ def local_div_switch_sink(node):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if get_scalar_constant_value(switch.inputs[2]) == 0.: if get_scalar_constant_value(switch.inputs[2],
only_process_constants=True) == 0.:
fdiv = op(switch.inputs[1], node.inputs[1]) fdiv = op(switch.inputs[1], node.inputs[1])
# Copy over stacktrace for elementwise division op # Copy over stacktrace for elementwise division op
# from previous elementwise multiplication op. # from previous elementwise multiplication op.
...@@ -3976,7 +3987,8 @@ def local_useless_tile(node): ...@@ -3976,7 +3987,8 @@ def local_useless_tile(node):
""" """
if isinstance(node.op, T.Tile): if isinstance(node.op, T.Tile):
try: try:
a = T.get_scalar_constant_value(node.inputs[1]) a = T.get_scalar_constant_value(node.inputs[1],
only_process_constants=True)
if a == 1: if a == 1:
try: try:
l = T.get_vector_length(node.inputs[1]) l = T.get_vector_length(node.inputs[1])
...@@ -4159,7 +4171,8 @@ if 0: ...@@ -4159,7 +4171,8 @@ if 0:
def tmp(thing): def tmp(thing):
try: try:
return T.get_scalar_constant_value(thing) return T.get_scalar_constant_value(thing,
only_process_constants=True)
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
print(e, thing.owner.inputs[0]) print(e, thing.owner.inputs[0])
return None return None
...@@ -5156,7 +5169,7 @@ def local_reduce_join(node): ...@@ -5156,7 +5169,7 @@ def local_reduce_join(node):
node.inputs[0].owner and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, T.Join)): isinstance(node.inputs[0].owner.op, T.Join)):
join = node.inputs[0].owner join = node.inputs[0].owner
if T.extract_constant(join.inputs[0]) != 0: if T.extract_constant(join.inputs[0], only_process_constants=True) != 0:
return return
if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)): if isinstance(node.op.scalar_op, (scalar.Maximum, scalar.Minimum)):
...@@ -5206,7 +5219,9 @@ def local_reduce_join(node): ...@@ -5206,7 +5219,9 @@ def local_reduce_join(node):
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
join_axis = get_scalar_constant_value(join.inputs[0]) join_axis = get_scalar_constant_value(join.inputs[0],
only_process_constants=True)
if join_axis != reduce_axis[0]: if join_axis != reduce_axis[0]:
return return
except NotScalarConstantError: except NotScalarConstantError:
...@@ -5288,7 +5303,8 @@ def local_opt_alloc(node): ...@@ -5288,7 +5303,8 @@ def local_opt_alloc(node):
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
try: try:
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1 assert val.size == 1
# check which type of op # check which type of op
casted = T.mul(*shapes).astype(str(input.dtype)) casted = T.mul(*shapes).astype(str(input.dtype))
...@@ -5302,7 +5318,8 @@ def local_opt_alloc(node): ...@@ -5302,7 +5318,8 @@ def local_opt_alloc(node):
pass pass
else: else:
try: try:
val = get_scalar_constant_value(input) val = get_scalar_constant_value(input,
only_process_constants=True)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
...@@ -5746,7 +5763,8 @@ def local_abs_merge(node): ...@@ -5746,7 +5763,8 @@ def local_abs_merge(node):
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
elif isinstance(i, Constant): elif isinstance(i, Constant):
try: try:
const = get_scalar_constant_value(i) const = get_scalar_constant_value(i,
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if not (const >= 0).all(): if not (const >= 0).all():
...@@ -5766,12 +5784,12 @@ def local_abs_merge(node): ...@@ -5766,12 +5784,12 @@ def local_abs_merge(node):
@gof.local_optimizer([T.log]) @gof.local_optimizer([T.log])
def local_log1p(node): def local_log1p(node):
# log(1+x) -> log1p(x) # log(1+x) -> log1p(x)
# log(1-x) -> log1p(-x)
if node.op == T.log: if node.op == T.log:
log_arg, = node.inputs log_arg, = node.inputs
if log_arg.owner and log_arg.owner.op == T.add: if log_arg.owner and log_arg.owner.op == T.add:
scalars, scalar_inputs, nonconsts = scalarconsts_rest( scalars, scalar_inputs, nonconsts = scalarconsts_rest(
log_arg.owner.inputs) log_arg.owner.inputs, only_process_constants=True)
# scalar_inputs are potentially dimshuffled and fill'd scalars # scalar_inputs are potentially dimshuffled and fill'd scalars
if scalars and numpy.allclose(numpy.sum(scalars), 1): if scalars and numpy.allclose(numpy.sum(scalars), 1):
if not nonconsts: if not nonconsts:
...@@ -5782,6 +5800,13 @@ def local_log1p(node): ...@@ -5782,6 +5800,13 @@ def local_log1p(node):
return _fill_chain(T.log1p(T.add(*nonconsts)), return _fill_chain(T.log1p(T.add(*nonconsts)),
scalar_inputs) scalar_inputs)
elif log_arg.owner and log_arg.owner.op == T.sub:
one = T.extract_constant(log_arg.owner.inputs[0],
only_process_constants=True)
if one != 1:
return
return [T.log1p(T.neg(log_arg.owner.inputs[1]))]
# TODO: in canonicalize, change log10 and log2 -> log # TODO: in canonicalize, change log10 and log2 -> log
@register_stabilize @register_stabilize
...@@ -6017,7 +6042,6 @@ def constant_folding(node): ...@@ -6017,7 +6042,6 @@ def constant_folding(node):
required = thunk() required = thunk()
assert not required # a node whose inputs are all provided should always assert not required # a node whose inputs are all provided should always
# return successfully # return successfully
rval = [] rval = []
for output in node.outputs: for output in node.outputs:
assert compute_map[output][0], (output, storage_map[output][0]) assert compute_map[output][0], (output, storage_map[output][0])
...@@ -6036,6 +6060,7 @@ def constant_folding(node): ...@@ -6036,6 +6060,7 @@ def constant_folding(node):
topo_constant_folding = in2out(constant_folding, ignore_newtrees=True, topo_constant_folding = in2out(constant_folding, ignore_newtrees=True,
name="topo_constant_folding") name="topo_constant_folding")
register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True) register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_uncanonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True) register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True) register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
...@@ -6328,7 +6353,8 @@ def local_grad_log_erfc_neg(node): ...@@ -6328,7 +6353,8 @@ def local_grad_log_erfc_neg(node):
mul_neg = T.mul(*mul_inputs) mul_neg = T.mul(*mul_inputs)
try: try:
cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0]) cst2 = get_scalar_constant_value(mul_neg.owner.inputs[0],
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -6355,7 +6381,8 @@ def local_grad_log_erfc_neg(node): ...@@ -6355,7 +6381,8 @@ def local_grad_log_erfc_neg(node):
x = erfc_x x = erfc_x
try: try:
cst = get_scalar_constant_value(erfc_x.owner.inputs[0]) cst = get_scalar_constant_value(erfc_x.owner.inputs[0],
only_process_constants=True)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if cst2 != -cst * 2: if cst2 != -cst * 2:
......
...@@ -1635,8 +1635,8 @@ def test_log_add(): ...@@ -1635,8 +1635,8 @@ def test_log_add():
def test_local_useless_slice(): def test_local_useless_slice():
# test a simple matrix # test a simple matrix
x = tensor.matrix('x') x = tensor.matrix('x')
mode_unopt = compile.get_default_mode().excluding("local_useless_slice") mode_unopt = compile.get_default_mode().excluding("local_useless_slice", "local_mul_canonizer")
mode_opt = compile.get_default_mode().including("local_useless_slice") mode_opt = compile.get_default_mode().including("local_useless_slice").excluding("local_mul_canonizer")
# test with and without the useless slice # test with and without the useless slice
o = 2 * x[0, :] o = 2 * x[0, :]
...@@ -2124,7 +2124,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -2124,7 +2124,7 @@ class test_local_subtensor_lift(unittest.TestCase):
f1 = function([x], newx[:2, :5], mode=mode_opt) f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(check_stack_trace(f1, ops_to_check=[ self.assertTrue(check_stack_trace(f1, ops_to_check=[
Subtensor, tensor.Rebroadcast])) Subtensor, tensor.Rebroadcast]))
prog = f1.maker.fgraph.toposort() prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, tensor.Subtensor) assert isinstance(prog[0].op, tensor.Subtensor)
assert isinstance(prog[1].op, tensor.Rebroadcast) assert isinstance(prog[1].op, tensor.Rebroadcast)
...@@ -2140,7 +2140,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -2140,7 +2140,7 @@ class test_local_subtensor_lift(unittest.TestCase):
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
self.assertTrue(check_stack_trace(f2, ops_to_check=[ self.assertTrue(check_stack_trace(f2, ops_to_check=[
Subtensor, tensor.Rebroadcast])) Subtensor, tensor.Rebroadcast]))
prog = f2.maker.fgraph.toposort() prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, tensor.Subtensor) assert isinstance(prog[0].op, tensor.Subtensor)
assert isinstance(prog[1].op, tensor.Rebroadcast) assert isinstance(prog[1].op, tensor.Rebroadcast)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论