提交 6253b797 authored 作者: abergeron's avatar abergeron

Merge pull request #1717 from nouiz/faster_opt

Faster opt
...@@ -166,9 +166,10 @@ yourself. Here is some code that will help you. ...@@ -166,9 +166,10 @@ yourself. Here is some code that will help you.
cd OpenBLAS cd OpenBLAS
make FC=gfortran make FC=gfortran
sudo make PREFIX=/usr/local/ install sudo make PREFIX=/usr/local/ install
cd /usr/local/lib # Tell Theano to use OpenBLAS.
ln -s libopenblas.so /usr/lib/libblas.so # This work only for the current user.
ln -s libopenblas.so.0 /usr/lib/libblas.so.3gf # Each Theano user on that computer should run that line.
echo -e "\n[blas]\nldflags = -lopenblas\n" >> ~/.theanorc
Contributed GPU instruction Contributed GPU instruction
......
...@@ -787,8 +787,8 @@ class ProfileStats(object): ...@@ -787,8 +787,8 @@ class ProfileStats(object):
if self.variable_shape or self.variable_strides: if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print) self.summary_memory(file, n_apply_to_print)
if self.optimizer_profile: if self.optimizer_profile:
print "Optimizer Profile" print >> file, "Optimizer Profile"
print "-----------------" print >> file, "-----------------"
self.optimizer_profile[0].print_profile(file, self.optimizer_profile[0].print_profile(file,
self.optimizer_profile[1]) self.optimizer_profile[1])
......
...@@ -1252,7 +1252,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1252,7 +1252,7 @@ class NavigatorOptimizer(Optimizer):
pruner(node) pruner(node)
if chin is not None: if chin is not None:
def on_change_input(self, fgraph, node, i, r, new_r, reason): def on_change_input(self, fgraph, node, i, r, new_r, reason):
chin(node, i, r, new_r) chin(node, i, r, new_r, reason)
u = Updater() u = Updater()
fgraph.attach_feature(u) fgraph.attach_feature(u)
...@@ -1701,7 +1701,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1701,7 +1701,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
lopt)) lopt))
count_opt = [] count_opt = []
not_used = 0 not_used = []
not_used_time = 0 not_used_time = 0
process_count = {} process_count = {}
for o in opt.global_optimizers + list(opt.get_local_optimizers()): for o in opt.global_optimizers + list(opt.get_local_optimizers()):
...@@ -1713,7 +1713,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1713,7 +1713,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if count > 0: if count > 0:
count_opt.append((time_opts[opt], count, opt)) count_opt.append((time_opts[opt], count, opt))
else: else:
not_used += 1 not_used.append((time_opts[opt], opt))
not_used_time += time_opts[opt] not_used_time += time_opts[opt]
if count_opt: if count_opt:
...@@ -1724,7 +1724,10 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1724,7 +1724,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> stream, blanc, ' %.3fs - %d - %s' % ( print >> stream, blanc, ' %.3fs - %d - %s' % (
t, count, opt) t, count, opt)
print >> stream, blanc, ' %.3fs - in %d optimization that where not used' % ( print >> stream, blanc, ' %.3fs - in %d optimization that where not used' % (
not_used_time, not_used) not_used_time, len(not_used))
not_used.sort()
for (t, opt) in not_used[::-1]:
print >> stream, blanc + " ", ' %.3fs - %s' % (t, opt)
print >> stream print >> stream
@staticmethod @staticmethod
......
...@@ -76,11 +76,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -76,11 +76,11 @@ class GpuElemwise(HideC, Elemwise):
try: try:
inps = [make_argument(i, 'i%d' % (n,)) for n, i in inps = [make_argument(i, 'i%d' % (n,)) for n, i in
enumerate(node.inputs)] enumerate(node.inputs)]
scal_ins = [scalar.Scalar(i.dtype) for i in node.inputs] scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
outs = [make_argument(o, 'o%d' % (n,)) for n, o in outs = [make_argument(o, 'o%d' % (n,)) for n, o in
enumerate(node.outputs) if not n in self.inplace_pattern] enumerate(node.outputs) if not n in self.inplace_pattern]
scal_out = [scalar.Scalar(o.dtype) for o in node.outputs] scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins], fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
[o() for o in scal_out]) [o() for o in scal_out])
...@@ -103,11 +103,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -103,11 +103,11 @@ class GpuElemwise(HideC, Elemwise):
def generate_kernel(self, node, nodename): def generate_kernel(self, node, nodename):
inps = [make_argument(i, 'i%d' % (n,)) for n, i in inps = [make_argument(i, 'i%d' % (n,)) for n, i in
enumerate(node.inputs)] enumerate(node.inputs)]
scal_ins = [scalar.Scalar(i.dtype) for i in node.inputs] scal_ins = [scalar.get_scalar_type(i.dtype) for i in node.inputs]
outs = [make_argument(o, 'o%d' % (n,)) for n, o in outs = [make_argument(o, 'o%d' % (n,)) for n, o in
enumerate(node.outputs) if not n in self.inplace_pattern] enumerate(node.outputs) if not n in self.inplace_pattern]
scal_out = [scalar.Scalar(o.dtype) for o in node.outputs] scal_out = [scalar.get_scalar_type(o.dtype) for o in node.outputs]
fake_node = Apply(self.scalar_op, [i() for i in scal_ins], fake_node = Apply(self.scalar_op, [i() for i in scal_ins],
[o() for o in scal_out]) [o() for o in scal_out])
......
...@@ -69,6 +69,18 @@ def upcast(dtype, *dtypes): ...@@ -69,6 +69,18 @@ def upcast(dtype, *dtypes):
return rval return rval
def get_scalar_type(dtype):
"""
Return an Scalar(dtype) object.
This cache objects to save allocation and run time.
"""
if dtype not in get_scalar_type.cache:
get_scalar_type.cache[dtype] = Scalar(dtype=dtype)
return get_scalar_type.cache[dtype]
get_scalar_type.cache = {}
def as_scalar(x, name=None): def as_scalar(x, name=None):
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
...@@ -91,7 +103,7 @@ def constant(x): ...@@ -91,7 +103,7 @@ def constant(x):
# purpose typically. # purpose typically.
if hasattr(x, 'dtype'): if hasattr(x, 'dtype'):
assert x.ndim == 0 assert x.ndim == 0
return ScalarConstant(Scalar(str(x.dtype)), x) return ScalarConstant(get_scalar_type(str(x.dtype)), x)
if isinstance(x, builtin_float): if isinstance(x, builtin_float):
for dtype in ['float32', 'float64']: for dtype in ['float32', 'float64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
...@@ -99,7 +111,7 @@ def constant(x): ...@@ -99,7 +111,7 @@ def constant(x):
break break
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(get_scalar_type(str(x_.dtype)), x)
if isinstance(x, builtin_int): if isinstance(x, builtin_int):
for dtype in ['int8', 'int16', 'int32', 'int64']: for dtype in ['int8', 'int16', 'int32', 'int64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
...@@ -107,7 +119,7 @@ def constant(x): ...@@ -107,7 +119,7 @@ def constant(x):
break break
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(get_scalar_type(str(x_.dtype)), x)
if isinstance(x, builtin_complex): if isinstance(x, builtin_complex):
#TODO: We have added the complex type, so this should be tested #TODO: We have added the complex type, so this should be tested
raise NotImplementedError() raise NotImplementedError()
...@@ -457,18 +469,18 @@ theano.compile.register_view_op_c_code( ...@@ -457,18 +469,18 @@ theano.compile.register_view_op_c_code(
1) 1)
int8 = Scalar('int8') int8 = get_scalar_type('int8')
int16 = Scalar('int16') int16 = get_scalar_type('int16')
int32 = Scalar('int32') int32 = get_scalar_type('int32')
int64 = Scalar('int64') int64 = get_scalar_type('int64')
uint8 = Scalar('uint8') uint8 = get_scalar_type('uint8')
uint16 = Scalar('uint16') uint16 = get_scalar_type('uint16')
uint32 = Scalar('uint32') uint32 = get_scalar_type('uint32')
uint64 = Scalar('uint64') uint64 = get_scalar_type('uint64')
float32 = Scalar('float32') float32 = get_scalar_type('float32')
float64 = Scalar('float64') float64 = get_scalar_type('float64')
complex64 = Scalar('complex64') complex64 = get_scalar_type('complex64')
complex128 = Scalar('complex128') complex128 = get_scalar_type('complex128')
int_types = int8, int16, int32, int64 int_types = int8, int16, int32, int64
uint_types = uint8, uint16, uint32, uint64 uint_types = uint8, uint16, uint32, uint64
...@@ -584,7 +596,7 @@ class _scalar_py_operators: ...@@ -584,7 +596,7 @@ class _scalar_py_operators:
# The second is needed for Elemwise ops to work right # The second is needed for Elemwise ops to work right
if dtype is None: if dtype is None:
dtype = str(self.type.dtype) dtype = str(self.type.dtype)
return second(self, ScalarConstant(Scalar(dtype), 0)) return second(self, ScalarConstant(get_scalar_type(dtype), 0))
def astype(self, dtype): def astype(self, dtype):
return cast(self, dtype) return cast(self, dtype)
...@@ -628,7 +640,8 @@ complexs128 = _multi(complex128) ...@@ -628,7 +640,8 @@ complexs128 = _multi(complex128)
# necessary to use this same mechanism in other places as well in the future. # necessary to use this same mechanism in other places as well in the future.
class upcast_out(object): class upcast_out(object):
def __new__(self, *types): def __new__(self, *types):
return Scalar(dtype=Scalar.upcast(*types)), dtype = Scalar.upcast(*types)
return get_scalar_type(dtype),
class upgrade_to_float(object): class upgrade_to_float(object):
...@@ -644,7 +657,7 @@ class upgrade_to_float(object): ...@@ -644,7 +657,7 @@ class upgrade_to_float(object):
uint16: float32, uint16: float32,
uint32: float64, uint32: float64,
uint64: float64} uint64: float64}
return Scalar(Scalar.upcast(*[conv.get(type, type) return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
for type in types])), for type in types])),
...@@ -656,7 +669,7 @@ class same_out(object): ...@@ -656,7 +669,7 @@ class same_out(object):
def upcast_out_no_complex(*types): def upcast_out_no_complex(*types):
if any([type in complex_types for type in types]): if any([type in complex_types for type in types]):
raise TypeError('complex type are not supported') raise TypeError('complex type are not supported')
return Scalar(dtype=Scalar.upcast(*types)), return get_scalar_type(dtype=Scalar.upcast(*types)),
def same_out_float_only(type): def same_out_float_only(type):
...@@ -1455,7 +1468,7 @@ def div_proxy(x, y): ...@@ -1455,7 +1468,7 @@ def div_proxy(x, y):
class TrueDiv(BinaryScalarOp): class TrueDiv(BinaryScalarOp):
def output_types(self, types): def output_types(self, types):
if all(t in discrete_types for t in types): if all(t in discrete_types for t in types):
return [Scalar(config.floatX)] return [get_scalar_type(config.floatX)]
else: else:
return super(TrueDiv, self).output_types(types) return super(TrueDiv, self).output_types(types)
......
...@@ -59,7 +59,7 @@ def safe_new(x, tag='', dtype=None): ...@@ -59,7 +59,7 @@ def safe_new(x, tag='', dtype=None):
# making the pushout optimization fail # making the pushout optimization fail
elif isinstance(x, scalar.ScalarVariable): elif isinstance(x, scalar.ScalarVariable):
if dtype: if dtype:
nw_x = scalar.Scalar(dtype=dtype)() nw_x = scalar.get_scalar_type(dtype=dtype)()
else: else:
nw_x = x.type() nw_x = x.type()
nw_x.name = nw_name nw_x.name = nw_name
......
...@@ -1113,8 +1113,11 @@ class test_structureddot(unittest.TestCase): ...@@ -1113,8 +1113,11 @@ class test_structureddot(unittest.TestCase):
utt.assert_allclose(scipy_result, theano_result) utt.assert_allclose(scipy_result, theano_result)
if (not theano.config.mode in ["DebugMode", "DEBUG_MODE"] and if (not theano.config.mode in ["DebugMode", "DEBUG_MODE"] and
theano.config.cxx): theano.config.cxx):
self.assertFalse(theano_time > overhead_rtol * scipy_time + self.assertFalse(
overhead_tol) theano_time > overhead_rtol * scipy_time + overhead_tol,
(theano_time,
overhead_rtol * scipy_time + overhead_tol,
scipy_time, overhead_rtol, overhead_tol))
class DotTests(utt.InferShapeTester): class DotTests(utt.InferShapeTester):
......
...@@ -993,7 +993,7 @@ class ScalarFromTensor(Op): ...@@ -993,7 +993,7 @@ class ScalarFromTensor(Op):
assert t.type.broadcastable == () assert t.type.broadcastable == ()
return Apply(self, return Apply(self,
[t], [t],
[scal.Scalar(dtype=t.type.dtype).make_variable()]) [scal.get_scalar_type(dtype=t.type.dtype).make_variable()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
s, = inp s, = inp
......
...@@ -8,7 +8,7 @@ import theano ...@@ -8,7 +8,7 @@ import theano
from theano import gof from theano import gof
from theano.gof import Apply, Op from theano.gof import Apply, Op
from theano import scalar from theano import scalar
from theano.scalar import Scalar from theano.scalar import Scalar, get_scalar_type
from theano.printing import pprint from theano.printing import pprint
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.tensor.utils import hash_from_dict from theano.tensor.utils import hash_from_dict
...@@ -515,7 +515,7 @@ class Elemwise(Op): ...@@ -515,7 +515,7 @@ class Elemwise(Op):
""" """
inputs = map(as_tensor_variable, inputs) inputs = map(as_tensor_variable, inputs)
shadow = self.scalar_op.make_node( shadow = self.scalar_op.make_node(
*[Scalar(dtype=i.type.dtype)() for i in inputs]) *[get_scalar_type(dtype=i.type.dtype)() for i in inputs])
target_length = max([input.type.ndim for input in inputs]) target_length = max([input.type.ndim for input in inputs])
...@@ -718,7 +718,7 @@ class Elemwise(Op): ...@@ -718,7 +718,7 @@ class Elemwise(Op):
def as_scalar(t): def as_scalar(t):
if isinstance(t.type, (NullType, DisconnectedType)): if isinstance(t.type, (NullType, DisconnectedType)):
return t return t
return Scalar(t.type.dtype)() return get_scalar_type(t.type.dtype)()
scalar_inputs = map(as_scalar, inputs) scalar_inputs = map(as_scalar, inputs)
scalar_ograds = map(as_scalar, ograds) scalar_ograds = map(as_scalar, ograds)
...@@ -1039,9 +1039,9 @@ class Elemwise(Op): ...@@ -1039,9 +1039,9 @@ class Elemwise(Op):
# We generate the C code of the inner loop using the scalar op # We generate the C code of the inner loop using the scalar op
task_code = self.scalar_op.c_code( task_code = self.scalar_op.c_code(
Apply(self.scalar_op, Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() [get_scalar_type(dtype=input.type.dtype)()
for input in node.inputs], for input in node.inputs],
[Scalar(dtype=output.type.dtype)() [get_scalar_type(dtype=output.type.dtype)()
for output in node.outputs]), for output in node.outputs]),
nodename + '_scalar_', nodename + '_scalar_',
["%s_i" % s for s in _inames], ["%s_i" % s for s in _inames],
...@@ -1161,11 +1161,11 @@ class Elemwise(Op): ...@@ -1161,11 +1161,11 @@ class Elemwise(Op):
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() for input in node.inputs], [get_scalar_type(dtype=input.type.dtype)() for input in node.inputs],
[Scalar(dtype=output.type.dtype)() for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype)() for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(Scalar(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
if all(version): if all(version):
return tuple(version) return tuple(version)
else: else:
...@@ -1531,9 +1531,9 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1531,9 +1531,9 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
task1_code = self.scalar_op.c_code( task1_code = self.scalar_op.c_code(
Apply( Apply(
self.scalar_op, self.scalar_op,
[Scalar(dtype=input.type.dtype)() [get_scalar_type(dtype=input.type.dtype)()
for input in (node.inputs * 2)], for input in (node.inputs * 2)],
[Scalar(dtype=output.type.dtype)() [get_scalar_type(dtype=output.type.dtype)()
for input in node.outputs]), for input in node.outputs]),
None, None,
["%s_i" % aname, "%s_i" % inames[0]], ["%s_i" % aname, "%s_i" % inames[0]],
...@@ -1583,11 +1583,11 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){ ...@@ -1583,11 +1583,11 @@ for(int i=0;i<PyArray_NDIM(%(iname)s);i++){
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply(self.scalar_op, scalar_node = Apply(self.scalar_op,
[Scalar(dtype=input.type.dtype)() for input in node.inputs], [get_scalar_type(dtype=input.type.dtype)() for input in node.inputs],
[Scalar(dtype=output.type.dtype)() for output in node.outputs]) [get_scalar_type(dtype=output.type.dtype)() for output in node.outputs])
version.append(self.scalar_op.c_code_cache_version_apply(scalar_node)) version.append(self.scalar_op.c_code_cache_version_apply(scalar_node))
for i in node.inputs + node.outputs: for i in node.inputs + node.outputs:
version.append(Scalar(dtype=i.type.dtype).c_code_cache_version()) version.append(get_scalar_type(dtype=i.type.dtype).c_code_cache_version())
if all(version): if all(version):
return tuple(version) return tuple(version)
else: else:
...@@ -1665,7 +1665,7 @@ class CAReduceDtype(CAReduce): ...@@ -1665,7 +1665,7 @@ class CAReduceDtype(CAReduce):
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None): def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
""" """
Usage: CAReduceDtype(scalar_op, axis=None, dtype=None) Usage: CAReduceDtype(scalar_op, axis=None, dtype=None, acc_dtype=None)
:param scalar_op: a binary scalar op with only one output. :param scalar_op: a binary scalar op with only one output.
It must be commutative and associative. It must be commutative and associative.
......
...@@ -162,7 +162,7 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -162,7 +162,7 @@ class T_sigmoid_opts(unittest.TestCase):
f = theano.function([x], (T.fill(x, -1.0) * T.exp(x)) / f = theano.function([x], (T.fill(x, -1.0) * T.exp(x)) /
((1 + T.exp(x)) * (1 + T.exp(-x))), mode=m) ((1 + T.exp(x)) * (1 + T.exp(-x))), mode=m)
assert [node.op for node in f.maker.fgraph.toposort()] == [sigmoid, assert [node.op for node in f.maker.fgraph.toposort()] == [sigmoid,
T.mul, theano.tensor.inplace.neg_inplace] T.mul]
f(data) f(data)
f = theano.function([x], (T.fill(x, -1.1) * T.exp(x)) / f = theano.function([x], (T.fill(x, -1.1) * T.exp(x)) /
((1 + T.exp(x)) * (1 + T.exp(-x))), mode=m) ((1 + T.exp(x)) * (1 + T.exp(-x))), mode=m)
...@@ -238,7 +238,7 @@ class T_sigmoid_opts(unittest.TestCase): ...@@ -238,7 +238,7 @@ class T_sigmoid_opts(unittest.TestCase):
tensor.exp(x * y) * tensor.exp(y)), tensor.exp(x * y) * tensor.exp(y)),
mode=m) mode=m)
match(f, [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid, match(f, [sigmoid, tensor.mul, tensor.neg, tensor.exp, sigmoid,
tensor.mul, tensor.neg]) tensor.mul])
def test_perform_sigm_times_exp(self): def test_perform_sigm_times_exp(self):
""" """
......
...@@ -2559,12 +2559,12 @@ def local_fill_cut(node): ...@@ -2559,12 +2559,12 @@ def local_fill_cut(node):
# scalars, but we can't ignore the large matrix because it gives # scalars, but we can't ignore the large matrix because it gives
# the shape of the result. # the shape of the result.
if not opt.check_chain(node, T.Elemwise): if node.op != T.Elemwise:
return False return False
output = node.outputs[0] output = node.outputs[0]
try: try:
#reference is some input with the same type as the input but #reference is some input with the same type as the output but
#that is not produced by a fill #that is not produced by a fill
reference = [input reference = [input
for input in node.inputs for input in node.inputs
...@@ -2574,16 +2574,18 @@ def local_fill_cut(node): ...@@ -2574,16 +2574,18 @@ def local_fill_cut(node):
return False return False
new_inputs = [] new_inputs = []
new = False
for input in node.inputs: for input in node.inputs:
if opt.check_chain(input, T.fill): if input.owner and input.owner.op == T.fill:
model, filling = input.owner.inputs model, filling = input.owner.inputs
if encompasses_broadcastable(reference.type.broadcastable, if encompasses_broadcastable(reference.type.broadcastable,
filling.type.broadcastable): filling.type.broadcastable):
new_inputs.append(filling) new_inputs.append(filling)
new = True
continue continue
new_inputs.append(input) new_inputs.append(input)
if new_inputs == node.inputs: if not new:
return False return False
rval = node.op(*new_inputs) rval = node.op(*new_inputs)
...@@ -2787,9 +2789,9 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2787,9 +2789,9 @@ class Canonizer(gof.LocalOptimizer):
pairs = [self.get_num_denum(input2) for input2 in parent.inputs] pairs = [self.get_num_denum(input2) for input2 in parent.inputs]
if parent.op == self.main: if parent.op == self.main:
# If we have main(x, y), numx, denumx, numy and denumy # If we have main(x, y, ...), numx, denumx, numy, denumy, ...
# then num is concat(numx, numy) and denum is # then num is concat(numx, numy, num...) and denum is
# concat(denumx, denumy) note that main() can have any # concat(denumx, denumy, denum...) note that main() can have any
# number of arguments >= 0 concat is list concatenation # number of arguments >= 0 concat is list concatenation
num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs))
denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs))
...@@ -2865,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2865,12 +2867,13 @@ class Canonizer(gof.LocalOptimizer):
else: else:
return v return v
def simplify(self, num, denum): def simplify(self, num, denum, out_type):
""" """
Shorthand for: Shorthand for:
self.simplify_constants(*self.simplify_factors(num, denum)) self.simplify_constants(*self.simplify_factors(num, denum))
""" """
rval = self.simplify_constants(*self.simplify_factors(num, denum)) rval = self.simplify_constants(*self.simplify_factors(num, denum),
out_type=out_type)
for reason, simplifier in self.external_simplifiers: for reason, simplifier in self.external_simplifiers:
# TODO: document that 'reason' is associated with this # TODO: document that 'reason' is associated with this
# simplification to help auditing when things go # simplification to help auditing when things go
...@@ -2894,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2894,7 +2897,7 @@ class Canonizer(gof.LocalOptimizer):
denum.remove(v) denum.remove(v)
return num, denum return num, denum
def simplify_constants(self, orig_num, orig_denum): def simplify_constants(self, orig_num, orig_denum, out_type=None):
""" """
Finds all constants in orig_num and orig_denum (using Finds all constants in orig_num and orig_denum (using
...@@ -2912,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2912,7 +2915,6 @@ class Canonizer(gof.LocalOptimizer):
# Lists representing the numerator and denumerator # Lists representing the numerator and denumerator
num, denum = list(orig_num), list(orig_denum) num, denum = list(orig_num), list(orig_denum)
out_type = self.merge_num_denum(orig_num, orig_denum).type
# Lists representing the *constant* elements of num and denum # Lists representing the *constant* elements of num and denum
numct, denumct = [], [] numct, denumct = [], []
...@@ -2981,29 +2983,26 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2981,29 +2983,26 @@ class Canonizer(gof.LocalOptimizer):
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
out = node.outputs[0]
assert len(node.outputs) == 1 assert len(node.outputs) == 1
out = node.outputs[0]
# check if any of the clients of this node would be part of # check if any of the clients of this node would be part of
# this canonized graph... if so, we do nothing and wait for # this canonized graph... if so, we do nothing and wait for
# them to be transformed. # them to be transformed.
def _bypass_dimshuffle(n):
if (isinstance(getattr(n, 'op', None), DimShuffle) and
len(n.outputs[0].clients) <= 1):
return _bypass_dimshuffle(n.outputs[0].clients[0][0])
else:
return n
for c, c_idx in out.clients: for c, c_idx in out.clients:
if c == 'output': if c == 'output':
continue continue
if getattr(_bypass_dimshuffle(c), 'op', '') in [ while (isinstance(getattr(c, 'op', None), DimShuffle) and
self.main, self.inverse, self.reciprocal]: len(c.outputs[0].clients) <= 1):
c = c.outputs[0].clients[0][0]
if getattr(c, 'op', '') in [self.main, self.inverse,
self.reciprocal]:
return False return False
# Here we make the canonical version of the graph around this node # Here we make the canonical version of the graph around this node
# See the documentation of get_num_denum and simplify # See the documentation of get_num_denum and simplify
orig_num, orig_denum = self.get_num_denum(node.outputs[0]) orig_num, orig_denum = self.get_num_denum(node.outputs[0])
num, denum = self.simplify(list(orig_num), list(orig_denum)) num, denum = self.simplify(list(orig_num), list(orig_denum), out.type)
def same(x, y): def same(x, y):
return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in return len(x) == len(y) and all(N.all(xe == ye) for xe, ye in
...@@ -3387,20 +3386,6 @@ def local_sum_alloc(node): ...@@ -3387,20 +3386,6 @@ def local_sum_alloc(node):
pass pass
@gof.local_optimizer([T.mul])
def local_mul_to_neg(node):
if node.op == T.mul and N.all(
local_mul_canonizer.get_constant(node.inputs[0]) == -1.0):
other_prod = local_mul_canonizer.merge_num_denum(node.inputs[1:], [])
if other_prod.type == node.outputs[0].type:
return [-other_prod]
# else the multiplication is also acting as a cast, so we
# might as well leave it alone. I don't think it's better to
# turn this into a negation in the wrong type, followed by an
# explicit cast.
register_specialize(local_mul_to_neg)
@register_specialize @register_specialize
@gof.local_optimizer([T.neg]) @gof.local_optimizer([T.neg])
def local_neg_neg(node): def local_neg_neg(node):
...@@ -3447,7 +3432,7 @@ def local_mul_zero(node): ...@@ -3447,7 +3432,7 @@ def local_mul_zero(node):
except NotScalarConstantError: except NotScalarConstantError:
continue continue
#print 'MUL by value', value, node.inputs #print 'MUL by value', value, node.inputs
if N.all(value == 0): if value == 0:
#print '... returning zeros' #print '... returning zeros'
return _fill_chain(theano._asarray(0, dtype=otype.dtype), return _fill_chain(theano._asarray(0, dtype=otype.dtype),
node.inputs) node.inputs)
...@@ -3485,9 +3470,9 @@ register_canonicalize(local_inv_canon) ...@@ -3485,9 +3470,9 @@ register_canonicalize(local_inv_canon)
@gof.local_optimizer([T.pow]) @gof.local_optimizer([T.pow])
def local_pow_canonicalize(node): def local_pow_canonicalize(node):
if node.op == T.pow: if node.op == T.pow:
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 0): if local_mul_canonizer.get_constant(node.inputs[1]) == 0:
return [broadcast_like(1, node.outputs[0], node.fgraph)] return [broadcast_like(1, node.outputs[0], node.fgraph)]
if N.all(local_mul_canonizer.get_constant(node.inputs[1]) == 1): if local_mul_canonizer.get_constant(node.inputs[1]) == 1:
return [broadcast_like(node.inputs[0], node.outputs[0], node.fgraph)] return [broadcast_like(node.inputs[0], node.outputs[0], node.fgraph)]
else: else:
return False return False
...@@ -3581,7 +3566,7 @@ def local_pow_specialize_device(node): ...@@ -3581,7 +3566,7 @@ def local_pow_specialize_device(node):
# 512 is too small for the cpu and too big for some gpu! # 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512: if abs(y) == int(abs(y)) and abs(y) <= 512:
pow2 = [xsym] pow2 = [xsym]
pow2_scal = [theano.scalar.Scalar(xsym.dtype)()] pow2_scal = [theano.scalar.get_scalar_type(xsym.dtype)()]
y_to_do = abs(y) y_to_do = abs(y)
for i in xrange(int(numpy.log2(y_to_do))): for i in xrange(int(numpy.log2(y_to_do))):
pow2.append(T.sqr(pow2[i])) pow2.append(T.sqr(pow2[i]))
...@@ -3616,7 +3601,15 @@ def local_pow_specialize_device(node): ...@@ -3616,7 +3601,15 @@ def local_pow_specialize_device(node):
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_specialize(node): def local_mul_specialize(node):
"""Remove special-case constants from mul arguments """Remove special-case constants from mul arguments and useless neg in inputs.
mul(-1, x) -> neg(x)
mul(1, x, y) -> mul(x, y)
mul(0, ...) -> alloc(0, shapes...)
This is not done if we would add more nodes in the graph, like with:
mul(-1, x, y) -/-> neg(mul(x, y))
""" """
# here, we are past the point of canonicalization, so we don't # here, we are past the point of canonicalization, so we don't
# want to put in un-necessary fills. # want to put in un-necessary fills.
...@@ -3626,19 +3619,23 @@ def local_mul_specialize(node): ...@@ -3626,19 +3619,23 @@ def local_mul_specialize(node):
#the idea here is that we have pow(x, y) #the idea here is that we have pow(x, y)
neg = False neg = False
new_inputs = [] new_inputs = []
nb_neg_node = 0
nb_cst = 0
for input in node.inputs: for input in node.inputs:
# remove any neg arguments # remove any neg arguments
while input.owner and input.owner.op == T.neg: while input.owner and input.owner.op == T.neg:
neg ^= True neg ^= True
input = input.owner.inputs[0] input = input.owner.inputs[0]
nb_neg_node += 1
# remove special case arguments of 1, -1 or 0 # remove special case arguments of 1, -1 or 0
y = local_mul_canonizer.get_constant(input) y = local_mul_canonizer.get_constant(input)
if N.all(y == 1.0): if y == 1.0:
continue nb_cst += 1
elif N.all(y == -1.0): elif y == -1.0:
nb_cst += 1
neg ^= True # toggles neg ^= True # toggles
elif N.all(y == 0.0): elif y == 0.0:
# if we find any zero, we just return right away # if we find any zero, we just return right away
return [broadcast_like(0, node.outputs[0], node.fgraph)] return [broadcast_like(0, node.outputs[0], node.fgraph)]
else: else:
...@@ -3652,9 +3649,16 @@ def local_mul_specialize(node): ...@@ -3652,9 +3649,16 @@ def local_mul_specialize(node):
else: else:
rval = new_inputs[0] rval = new_inputs[0]
else: else:
if neg: # The next case would cause a replace by an equivalent case.
rval = -T.mul(*new_inputs) if (neg and
else: nb_neg_node == 0 and
nb_cst == 1):
return
elif neg:
# Don't add an extra neg node as we can't
# fully replace this mul by a neg.
m1 = numpy.asarray(-1, dtype=node.outputs[0].dtype)
new_inputs = [m1] + new_inputs
rval = T.mul(*new_inputs) rval = T.mul(*new_inputs)
return [broadcast_like(rval, node.outputs[0], node.fgraph)] return [broadcast_like(rval, node.outputs[0], node.fgraph)]
...@@ -3712,9 +3716,6 @@ def local_add_specialize(node): ...@@ -3712,9 +3716,6 @@ def local_add_specialize(node):
return False return False
register_specialize(local_add_specialize) register_specialize(local_add_specialize)
# neg_to_mul = out2in(gof.LocalOptGroup(local_neg_to_mul))
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut,
local_fill_sink), local_fill_sink),
name='mul_canonizer_groups') name='mul_canonizer_groups')
...@@ -3871,7 +3872,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer') ...@@ -3871,7 +3872,8 @@ register_canonicalize(local_add_canonizer, name='local_add_canonizer')
################## ##################
def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): def distribute_greedy(pos_pairs, neg_pairs, num, denum,
out_type, minscore=0):
# each pair in pos_pairs and neg_pairs is a num/denum pair. this # each pair in pos_pairs and neg_pairs is a num/denum pair. this
# function attempts to add num and denum to the corresponding parts # function attempts to add num and denum to the corresponding parts
# of each pair, and counts how many multiplications/divisions can # of each pair, and counts how many multiplications/divisions can
...@@ -3887,10 +3889,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): ...@@ -3887,10 +3889,10 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
# score is number of operations saved, higher is better # score is number of operations saved, higher is better
score = len(num) + div_cost * len(denum) score = len(num) + div_cost * len(denum)
new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify, new_pos_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d) [(n + num, d + denum, out_type) for (n, d)
in pos_pairs])) in pos_pairs]))
new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify, new_neg_pairs = list(itertools.starmap(local_mul_canonizer.simplify,
[(n + num, d + denum) for (n, d) [(n + num, d + denum, out_type) for (n, d)
in neg_pairs])) in neg_pairs]))
for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs +
new_neg_pairs): new_neg_pairs):
...@@ -3903,7 +3905,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0): ...@@ -3903,7 +3905,7 @@ def distribute_greedy(pos_pairs, neg_pairs, num, denum, minscore=0):
return True, new_pos_pairs, new_neg_pairs return True, new_pos_pairs, new_neg_pairs
def attempt_distribution(factor, num, denum): def attempt_distribution(factor, num, denum, out_type):
# we try to insert each num and each denum in the factor # we try to insert each num and each denum in the factor
# returns: changes?, new_factor, new_num, new_denum # returns: changes?, new_factor, new_num, new_denum
# if there are changes, new_num and new_denum contain all the numerators # if there are changes, new_num and new_denum contain all the numerators
...@@ -3916,13 +3918,13 @@ def attempt_distribution(factor, num, denum): ...@@ -3916,13 +3918,13 @@ def attempt_distribution(factor, num, denum):
change = False change = False
for n in list(num): for n in list(num):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [n], []) neg_pairs, [n], [], out_type)
if success: if success:
change = True change = True
num.remove(n) num.remove(n)
for d in list(denum): for d in list(denum):
success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs, success, pos_pairs, neg_pairs = distribute_greedy(pos_pairs,
neg_pairs, [], [d]) neg_pairs, [], [d], out_type)
if success: if success:
change = True change = True
denum.remove(d) denum.remove(d)
...@@ -3967,12 +3969,13 @@ def local_greedy_distributor(node): ...@@ -3967,12 +3969,13 @@ def local_greedy_distributor(node):
change = False change = False
out_type = out.type
for candidate in list(num): for candidate in list(num):
if candidate not in num: if candidate not in num:
continue continue
num.remove(candidate) num.remove(candidate)
_change, candidate, num, denum = attempt_distribution(candidate, _change, candidate, num, denum = attempt_distribution(candidate,
num, denum) num, denum, out_type)
change |= _change change |= _change
new_num.append(candidate) new_num.append(candidate)
...@@ -3981,7 +3984,7 @@ def local_greedy_distributor(node): ...@@ -3981,7 +3984,7 @@ def local_greedy_distributor(node):
continue continue
denum.remove(candidate) denum.remove(candidate)
_change, candidate, denum, num = attempt_distribution(candidate, _change, candidate, denum, num = attempt_distribution(candidate,
denum, num) denum, num, out_type)
change |= _change change |= _change
new_denum.append(candidate) new_denum.append(candidate)
...@@ -4636,7 +4639,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4636,7 +4639,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
elif ii in tmp_input: elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
tmp = scalar.Scalar(ii.dtype).make_variable() tmp = scalar.get_scalar_type(ii.dtype).make_variable()
try: try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0] tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0]
except AttributeError: except AttributeError:
...@@ -4690,7 +4693,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4690,7 +4693,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
if inputs.count(i) == node.inputs.count(i): if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
s = scalar.Scalar(i.dtype).make_variable() s = scalar.get_scalar_type(i.dtype).make_variable()
try: try:
if theano.config.compute_test_value != 'off': if theano.config.compute_test_value != 'off':
v = gof.op.get_test_value(i) v = gof.op.get_test_value(i)
......
...@@ -318,11 +318,11 @@ class Subtensor(Op): ...@@ -318,11 +318,11 @@ class Subtensor(Op):
if (isinstance(entry, gof.Variable) if (isinstance(entry, gof.Variable)
and entry.type in tensor_types and entry.type in tensor_types
and numpy.all(entry.type.broadcastable)): and numpy.all(entry.type.broadcastable)):
return scal.Scalar(entry.type.dtype) return scal.get_scalar_type(entry.type.dtype)
elif (isinstance(entry, gof.Type) elif (isinstance(entry, gof.Type)
and entry in tensor_types and entry in tensor_types
and numpy.all(entry.broadcastable)): and numpy.all(entry.broadcastable)):
return scal.Scalar(entry.dtype) return scal.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
b = entry.stop b = entry.stop
......
...@@ -2838,7 +2838,7 @@ def test_local_mul_specialize(): ...@@ -2838,7 +2838,7 @@ def test_local_mul_specialize():
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
print nodes print nodes
theano.printing.debugprint(f) theano.printing.debugprint(f)
assert nodes == [T.mul, inplace.neg_inplace] assert nodes == [T.mul]
f = function([v, m], v * 0 * (-m), mode=mode) f = function([v, m], v * 0 * (-m), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()] nodes = [node.op for node in f.maker.fgraph.toposort()]
...@@ -2852,6 +2852,12 @@ def test_local_mul_specialize(): ...@@ -2852,6 +2852,12 @@ def test_local_mul_specialize():
theano.printing.debugprint(f) theano.printing.debugprint(f)
assert nodes == [T.mul] assert nodes == [T.mul]
f = function([v, m], v * (-1) * m, mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
print nodes
theano.printing.debugprint(f)
assert nodes == [T.mul]
def speed_local_pow_specialize_range(): def speed_local_pow_specialize_range():
val = numpy.random.rand(1e7) val = numpy.random.rand(1e7)
...@@ -4000,27 +4006,6 @@ def test_local_join_1(): ...@@ -4000,27 +4006,6 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
def test_local_mul_to_neg():
"""
Test that a multiplication by -1 or -1.0 yields the appropriate data type
"""
a = T.imatrix()
f1 = theano.function([a], -1 * a)
f2 = theano.function([a], -1.0 * a)
aval = numpy.random.randint(0, 10, (2, 2)).astype('int32')
if config.cast_policy == 'custom':
assert f1(aval).dtype == a.dtype
assert f2(aval).dtype == 'float64'
elif config.cast_policy == 'numpy':
assert f1(aval).dtype == str(numpy.array(0).dtype)
assert f2(aval).dtype == 'float64'
elif config.cast_policy == 'numpy+floatX':
assert f1(aval).dtype == str(numpy.array(0).dtype)
assert f2(aval).dtype == config.floatX
else:
raise NotImplementedError(config.cast_policy)
def test_local_add_specialize(): def test_local_add_specialize():
# test of non-zero dimension # test of non-zero dimension
a = tensor.vector() a = tensor.vector()
......
...@@ -240,7 +240,7 @@ class TensorType(Type): ...@@ -240,7 +240,7 @@ class TensorType(Type):
% (self.__class__.__name__, self.dtype)) % (self.__class__.__name__, self.dtype))
def to_scalar_type(self): def to_scalar_type(self):
return scal.Scalar(dtype=self.dtype) return scal.get_scalar_type(dtype=self.dtype)
def __eq__(self, other): def __eq__(self, other):
"""Compare True iff other is the same kind of TensorType""" """Compare True iff other is the same kind of TensorType"""
...@@ -538,23 +538,23 @@ class TensorType(Type): ...@@ -538,23 +538,23 @@ class TensorType(Type):
def c_headers(self): def c_headers(self):
"""Override `CLinkerObject.c_headers` """ """Override `CLinkerObject.c_headers` """
return scal.Scalar(self.dtype).c_headers() return scal.get_scalar_type(self.dtype).c_headers()
def c_libraries(self): def c_libraries(self):
return scal.Scalar(self.dtype).c_libraries() return scal.get_scalar_type(self.dtype).c_libraries()
def c_compile_args(self): def c_compile_args(self):
return scal.Scalar(self.dtype).c_compile_args() return scal.get_scalar_type(self.dtype).c_compile_args()
def c_support_code(self): def c_support_code(self):
"""Override `CLinkerObject.c_support_code` """ """Override `CLinkerObject.c_support_code` """
return scal.Scalar(self.dtype).c_support_code() return scal.get_scalar_type(self.dtype).c_support_code()
def c_init_code(self): def c_init_code(self):
return scal.Scalar(self.dtype).c_init_code() return scal.get_scalar_type(self.dtype).c_init_code()
def c_code_cache_version(self): def c_code_cache_version(self):
scalar_version = scal.Scalar(self.dtype).c_code_cache_version() scalar_version = scal.get_scalar_type(self.dtype).c_code_cache_version()
if scalar_version: if scalar_version:
return (11,) + scalar_version return (11,) + scalar_version
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论