提交 d704a600 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5562 from nouiz/opt_prof

Speed merge optimizer and allow partial optimizer profile
...@@ -1466,6 +1466,10 @@ class FunctionMaker(object): ...@@ -1466,6 +1466,10 @@ class FunctionMaker(object):
theano.config.traceback.limit = theano.config.traceback.compile_limit theano.config.traceback.limit = theano.config.traceback.compile_limit
start_optimizer = time.time() start_optimizer = time.time()
# In case there is an error during optimization.
optimizer_profile = None
opt_time = None
# now optimize the graph # now optimize the graph
if theano.config.cache_optimizations: if theano.config.cache_optimizations:
optimizer_profile = self.optimize_graph_with_cache( optimizer_profile = self.optimize_graph_with_cache(
...@@ -1475,8 +1479,23 @@ class FunctionMaker(object): ...@@ -1475,8 +1479,23 @@ class FunctionMaker(object):
end_optimizer = time.time() end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer opt_time = end_optimizer - start_optimizer
_logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
theano.config.traceback.limit = limit_orig
# If the optimizer got interrupted
if opt_time is None:
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
theano.compile.profiling.total_graph_opt_time += opt_time theano.compile.profiling.total_graph_opt_time += opt_time
if profile: if profile:
if (optimizer_profile is None and
hasattr(optimizer, 'pre_profile')):
optimizer_profile = optimizer.pre_profile
profile.optimizer_time += opt_time profile.optimizer_time += opt_time
if theano.config.profile_optimizer: if theano.config.profile_optimizer:
profile.optimizer_profile = (optimizer, profile.optimizer_profile = (optimizer,
...@@ -1486,13 +1505,6 @@ class FunctionMaker(object): ...@@ -1486,13 +1505,6 @@ class FunctionMaker(object):
warnings.warn(( warnings.warn((
"config.profile_optimizer requires config.profile to " "config.profile_optimizer requires config.profile to "
" be set to True as well"), stacklevel=3) " be set to True as well"), stacklevel=3)
_logger.debug('Optimizing took %f seconds', opt_time)
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
theano.config.compute_test_value = compute_test_value_orig
theano.config.traceback.limit = limit_orig
# initialize the linker # initialize the linker
if not hasattr(linker, 'accept'): if not hasattr(linker, 'accept'):
...@@ -1784,7 +1796,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1784,7 +1796,7 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
if isinstance(mode, (list, tuple)): # "mode comparison" semantics if isinstance(mode, (list, tuple)): # "mode comparison" semantics
raise Exception("We do not support the passing of multiple modes") raise Exception("We do not support the passing of multiple modes")
else: try:
Maker = getattr(mode, 'function_maker', FunctionMaker) Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, fn = Maker(inputs,
outputs, outputs,
...@@ -1794,10 +1806,11 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False, ...@@ -1794,10 +1806,11 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys).create( output_keys=output_keys).create(
defaults) defaults)
finally:
t2 = time.time() t2 = time.time()
if profile: if profile:
profile.compile_time += t2 - t1 profile.compile_time += t2 - t1
# TODO: append
profile.nb_nodes = len(fn.maker.fgraph.apply_nodes) profile.nb_nodes = len(fn.maker.fgraph.apply_nodes)
fn.name = name fn.name = name
......
...@@ -228,6 +228,12 @@ class SeqOptimizer(Optimizer, list): ...@@ -228,6 +228,12 @@ class SeqOptimizer(Optimizer, list):
nb_node_before = len(fgraph.apply_nodes) nb_node_before = len(fgraph.apply_nodes)
sub_profs = [] sub_profs = []
nb_nodes = [] nb_nodes = []
self.pre_profile = (
self, l, -1, -1, nb_node_before,
-1, sub_profs, sub_validate_time,
nb_nodes, {})
try:
for optimizer in self: for optimizer in self:
try: try:
nb_nodes_before = len(fgraph.apply_nodes) nb_nodes_before = len(fgraph.apply_nodes)
...@@ -248,6 +254,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -248,6 +254,7 @@ class SeqOptimizer(Optimizer, list):
continue continue
else: else:
raise raise
finally:
if fgraph.profile: if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
...@@ -262,11 +269,12 @@ class SeqOptimizer(Optimizer, list): ...@@ -262,11 +269,12 @@ class SeqOptimizer(Optimizer, list):
else: else:
validate_time = None validate_time = None
callbacks_time = {} callbacks_time = {}
callback_time = fgraph.execute_callbacks_time - callback_before callback_time = fgraph.execute_callbacks_time - callback_before
return (self, l, validate_time, callback_time, nb_node_before, self.pre_profile = (
self, l, validate_time, callback_time, nb_node_before,
len(fgraph.apply_nodes), sub_profs, sub_validate_time, len(fgraph.apply_nodes), sub_profs, sub_validate_time,
nb_nodes, callbacks_time) nb_nodes, callbacks_time)
return self.pre_profile
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -877,6 +885,12 @@ class MergeOptimizer(Optimizer): ...@@ -877,6 +885,12 @@ class MergeOptimizer(Optimizer):
pairs = [(pairs[0][1], pairs[0][0])] pairs = [(pairs[0][1], pairs[0][0])]
try: try:
# If all are Constants, there is no need to call validate.
# We only need to check one variable of each pair:
# if it is a Constant, the other must also be a Constant, as we merge them.
if all([isinstance(old, graph.Constant) for old, new in pairs]):
fgraph.replace_all(pairs, 'MergeOptimizer')
else:
fgraph.replace_all_validate(pairs, 'MergeOptimizer') fgraph.replace_all_validate(pairs, 'MergeOptimizer')
except InconsistencyError: except InconsistencyError:
success = False success = False
......
...@@ -359,9 +359,7 @@ class TestAutoName: ...@@ -359,9 +359,7 @@ class TestAutoName:
assert r1 is r2 assert r1 is r2
r3 = tensor.constant(1.6) r3 = tensor.constant(1.6)
# The cache still create a new object that we don't return. assert r3.auto_name == "auto_" + str(autoname_id + 1)
# This is why we must increase by 2 and not 1.
assert r3.auto_name == "auto_" + str(autoname_id + 2)
def test_tensorvariable(self): def test_tensorvariable(self):
# Get counter value # Get counter value
......
...@@ -855,35 +855,60 @@ class Validator(object): ...@@ -855,35 +855,60 @@ class Validator(object):
If out is not valid and has no equivalent, None is returned. If out is not valid and has no equivalent, None is returned.
""" """
def get_value(out):
if out in self.valid: if out in self.valid:
return out, True return out, True
elif out in self.valid_equivalent: elif out in self.valid_equivalent:
return self.valid_equivalent[out], False return self.valid_equivalent[out], False
elif out in self.invalid: elif out in self.invalid:
return None return None
else:
raise RuntimeError("This should not happen")
q = [out]
while q:
out = q.pop()
if out in self.valid:
continue
elif out in self.invalid:
continue
if out.owner is None: if out.owner is None:
if isinstance(out, tensor.TensorConstant): if isinstance(out, tensor.TensorConstant):
# This might be a constant from the outer graph or a constant if hasattr(out, 'fgraph'):
# from the inner graph. In all cases, we can clone it to be # If out have an fgraph, we aren't sure if it
# certain we have a valid constant # is from the inner graph or outer graph, so
# clone it.
cloned_out = out.clone() cloned_out = out.clone()
self.valid.add(cloned_out) self.valid.add(cloned_out)
self.invalid.add(out) self.invalid.add(out)
self.valid_equivalent[out] = cloned_out self.valid_equivalent[out] = cloned_out
return cloned_out, False
else: else:
# This is an input node and it has not been explicitly marked self.valid.add(out)
# as invalid so we can use it continue
return out, True else:
# This is an input node and it has not been
# explicitly marked as invalid so we can use it
self.valid.add(out)
continue
# Recurse over inputs # Process the input if needed
inputs = [self.check(i) for i in out.owner.inputs] continue_while = False
for inp in out.owner.inputs:
if inp not in self.valid and inp not in self.invalid:
q.append(out)
q.extend(out.owner.inputs)
continue_while = True
break
if continue_while:
continue
inputs = [get_value(i) for i in out.owner.inputs]
# If some inputs are invalid without equivalent, so is out # If some inputs are invalid without equivalent, so is out
if None in inputs: if None in inputs:
self.invalid.add(out) self.invalid.add(out)
return None continue
# If some inputs are invalid with equivalent, # If some inputs are invalid with equivalent,
# an equivalent out should be built and returned # an equivalent out should be built and returned
...@@ -895,10 +920,12 @@ class Validator(object): ...@@ -895,10 +920,12 @@ class Validator(object):
self.invalid.add(out) self.invalid.add(out)
self.valid.add(cloned_out) self.valid.add(cloned_out)
self.valid_equivalent[out] = cloned_out self.valid_equivalent[out] = cloned_out
return cloned_out, False continue
# All inputs are valid, so is out # All inputs are valid, so is out
return out, True self.valid.add(out)
return get_value(out)
def scan_can_remove_outs(op, out_idxs): def scan_can_remove_outs(op, out_idxs):
......
...@@ -113,10 +113,7 @@ class TestGaussNewton(unittest.TestCase): ...@@ -113,10 +113,7 @@ class TestGaussNewton(unittest.TestCase):
def test_nobatch(self): def test_nobatch(self):
# This used to give an error due to optimization "scan_merge_inouts". # This used to give an error due to optimization "scan_merge_inouts".
# The batch size is set to 1 and the data is represented by a matrix. # The batch size is set to 1 and the data is represented by a matrix.
# As of 2013-10-24, it still triggers an optimization error due to self._run(100, 10, batch_size=1, mode=mode)
# "remove_constants_and_unused_inputs_scan".
mode_exc = mode.excluding("remove_constants_and_unused_inputs_scan")
self._run(100, 10, batch_size=1, mode=mode_exc)
class GaussNewtonMatrix(object): class GaussNewtonMatrix(object):
......
...@@ -19,7 +19,7 @@ from theano.gof.type import Generic ...@@ -19,7 +19,7 @@ from theano.gof.type import Generic
from theano.tensor import elemwise from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable, from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstant, TensorConstantSignature,
_tensor_py_operators) _tensor_py_operators)
from theano.tensor.type import TensorType, values_eq_approx_always_true from theano.tensor.type import TensorType, values_eq_approx_always_true
from theano.tensor.type_other import NoneConst from theano.tensor.type_other import NoneConst
...@@ -220,7 +220,7 @@ _as_tensor_variable = as_tensor_variable ...@@ -220,7 +220,7 @@ _as_tensor_variable = as_tensor_variable
as_tensor = as_tensor_variable as_tensor = as_tensor_variable
def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): def constant(x, name=None, ndim=None, dtype=None):
"""Return a symbolic `Constant` with value `x`. """Return a symbolic `Constant` with value `x`.
Raises Raises
...@@ -230,6 +230,16 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -230,6 +230,16 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
ValueError ValueError
`x` could not be expanded to have ndim dimensions. `x` could not be expanded to have ndim dimensions.
Note
----
We keep a small cache of frequently used constants.
This speeds up the Merge optimization on big graphs.
Caching scalar constants means they need to be merged less often,
but we don't want to cache too much.
So we cache scalars with an [u]int or float dtype whose value is
between -10 and 10.
We cache every broadcast pattern of such scalars.
""" """
x_ = scal.convert(x, dtype=dtype) x_ = scal.convert(x, dtype=dtype)
...@@ -245,45 +255,29 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -245,45 +255,29 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
assert len(bcastable) == ndim assert len(bcastable) == ndim
try: try:
if rtype is TensorConstant: ttype = TensorType(dtype=x_.dtype, broadcastable=bcastable)
rval = rtype( if not constant.enable:
TensorType(dtype=x_.dtype, broadcastable=bcastable), return TensorConstant(ttype, x_, name=name)
x_.copy(),
name=name)
return rval
else:
# leave the shape out of the type
return rtype(TensorType(dtype=x_.dtype, broadcastable=bcastable),
x_, name=name)
except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
sig = TensorConstantSignature((ttype, x_))
if sig in constant_cache:
return constant_cache[sig]
def constant(x, name=None, ndim=None, dtype=None): ret = TensorConstant(ttype, x_, name=name)
ret = constant_or_value(x, rtype=TensorConstant, name=name, ndim=ndim, if (x_.size == 1 and
dtype=dtype) (-10) <= x_ <= 10 and
(x_.dtype in int_dtypes or x_.dtype in uint_dtypes or
# We create a small cache of frequently used constant. (x_.dtype in float_dtypes and
# This speed up the Merge optimization for big graph.
# We want to cache all scalar to don't merge as frequently constants.
# But we don't want to cache too much stuff
# So we cache integer with dtype [u]int and float where the value is
# between -10 and 10
# We want to cache all broadcast pattern for scalar.
if not constant.enable:
return ret
sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and
(-10) <= ret.data <= 10 and
(ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
(ret.dtype in float_dtypes and
# Limit the size of the cache. # Limit the size of the cache.
len(constant_cache) < 10000))): len(constant_cache) < 10000))):
constant_cache[sig] = ret constant_cache[sig] = ret
# This is needed to raise a good error to the user. # This is needed to raise a good error to the user.
ret.cached = True ret.cached = True
return ret
except Exception:
raise TypeError("Could not convert %s to TensorType" % x, type(x))
return constant_cache.get(sig, ret)
constant.enable = True constant.enable = True
constant_cache = {} constant_cache = {}
......
...@@ -3251,6 +3251,9 @@ def local_IncSubtensor_serialize(node): ...@@ -3251,6 +3251,9 @@ def local_IncSubtensor_serialize(node):
if movable_inputs: if movable_inputs:
new_inputs = ([i for i in node.inputs if not movable(i)] + new_inputs = ([i for i in node.inputs if not movable(i)] +
[mi.owner.inputs[0] for mi in movable_inputs]) [mi.owner.inputs[0] for mi in movable_inputs])
if len(new_inputs) == 0:
new_add = new_inputs[0]
else:
new_add = T.add(*new_inputs) new_add = T.add(*new_inputs)
# Copy over stacktrace from original output, as an error # Copy over stacktrace from original output, as an error
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论