提交 ba0db2f3 authored 作者: Frederic Bastien's avatar Frederic Bastien

make insert_inplace_optimization faster.

上级 c981568e
......@@ -92,6 +92,10 @@ def broadcast_like(value, template, env):
return rval
theano.configparser.AddConfigVar('tensor.insert_inplace_optimizer_validate_nb',
"-1: auto, if graph have less then 500 nodes 1, else 10",
theano.configparser.IntParam(-1))
@gof.optimizer
def insert_inplace_optimizer(env):
"""
......@@ -107,16 +111,48 @@ def insert_inplace_optimizer(env):
x + y + z -> x += y += z
(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)
"""
#we should not validate too often as this take too much time to execute!
#It is the _dfs_toposort() fct in theano/gof/destroyhandler.py
#that take so much time.
#Should we try to use another lib that do toposort?
# igraph: http://igraph.sourceforge.net/
# networkx: https://networkx.lanl.gov/
#Should we try to use cython?
# compiling only that fct is not enought, should we try to add the deque class too?
# and init the deque and other list to an upper bound number of element?
#Should Theano do online toposort as in http://code.google.com/p/acyclic/?
#
#The next longuest optimizer is the canonizer phase
#Then I think it is the [io_?]toposort(need to validate) so check if the solution is also applicable their.
#we execute validate after this number of change.
validate_each_change = config.tensor.insert_inplace_optimizer_validate_nb
if validate_each_change==-1:
if len(env.nodes)>500:
validate_each_change = 10
else: validate_each_change = 1
nb_change_no_validate = 0
chk = env.checkpoint()
for node in list(graph.io_toposort(env.inputs, env.outputs)):
op = node.op
if not isinstance(op, Elemwise):
continue
baseline = op.inplace_pattern
protected_inputs = [f.protected for f in node.env._features if isinstance(f,theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs,[])#flatten the list
protected_inputs.extend(env.outputs)
candidate_outputs = [i for i in xrange(len(node.outputs)) if i not in baseline]
#Constant and input that are already destroyed can't be used. Remove here as faster.
#node inputs that are Constant, already destroyed,
# env protected inputs and env outputs can't be used as inplace target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs)) if i not in baseline.values() \
and not isinstance(node.inputs[i],Constant)\
and not env.destroyers(node.inputs[i])]
and not env.destroyers(node.inputs[i])\
and node.inputs[i] not in protected_inputs]
raised_warning = False
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
#remove inputs that don't have the same dtype as the output.
......@@ -137,14 +173,34 @@ def insert_inplace_optimizer(env):
for i in xrange(len(node.outputs))]))
new = Elemwise(new_scal,inplace_pattern).make_node(*node.inputs)
env.replace_all_validate(zip(node.outputs, new.outputs),
for r,new_r in zip(node.outputs,new.outputs):
env.replace(r,new_r,
reason="insert_inplace_optimizer")
nb_change_no_validate +=1
if nb_change_no_validate >= validate_each_change:
env.validate()
chk = env.checkpoint()
nb_change_no_validate = 0
except (ValueError, TypeError, InconsistencyError), e:
if validate_each_change!=1 and not raised_warning:
print >> sys.stderr, "Their was some inplace optimization that was not done due to unexpected error:"
print >> sys.stderr, e
raised_warning = True
env.revert(chk)
continue
candidate_inputs.remove(candidate_input)
node = new
baseline = inplace_pattern
break
if nb_change_no_validate>0:
try:
env.validate()
except Exception, e:
if not raised_warning:
print >> sys.stderr, "Their was some inplace optimization that was not done due to unexpected error"
env.revert(chk)
compile.optdb.register('inplace_opt', insert_inplace_optimizer, 75, 'fast_run', 'inplace')
def register_canonicalize(lopt, *tags, **kwargs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论