提交 9eb0b187 authored 作者: Frederic Bastien's avatar Frederic Bastien

moved greedy_local_optimizer to theano/got/opt.py and renamed it to pre_greedy_local_optimizer.

上级 34123842
......@@ -1142,6 +1142,66 @@ def check_chain(r, *chain):
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_local_optimizer(list_optimizations, out):
'''
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the env.
it applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
We should not apply optimizations on node that are in env.
So we don't optimize node that have an attribute env.
:note: This don't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
def local_recursive_function( list_opt, out, optimized_vars, depth):
if not out.owner :
return [out], optimized_vars
node = out.owner
if hasattr(node, 'env'):
return node.outputs, optimized_vars
for idx, inp in enumerate(node.inputs):
if inp in optimized_vars:
nw_in = optimized_vars[inp]
else:
if inp.owner:
outs, optimized_vars = local_recursive_function(
list_opt
, inp
, optimized_vars
, depth+1)
for k,v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)]
else:
nw_in = inp
optimized_vars[inp] = inp
node.inputs[idx] = nw_in
results = node.outputs
for opt in list_opt:
ret = opt.transform(node)
if ret is not False and ret is not None:
assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret):
optimized_vars[k] = v
results = ret
if ret[0].owner :
node = out.owner
else:
break
return results, optimized_vars
final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0)
return final_outs[0]
......
......@@ -25,7 +25,7 @@ import basic as T
from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer, pre_constant_merge
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value
......@@ -1230,67 +1230,6 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)]
def greedy_local_optimizer(list_optimizations, out):
'''
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the env.
it applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
We should not apply optimizations on node that are in env.
So we don't optimize node that have an attribute env.
:note: This don't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
def local_recursive_function( list_opt, out, optimized_vars, depth):
if not out.owner :
return [out], optimized_vars
node = out.owner
if hasattr(node, 'env'):
return node.outputs, optimized_vars
for idx, inp in enumerate(node.inputs):
if inp in optimized_vars:
nw_in = optimized_vars[inp]
else:
if inp.owner:
outs, optimized_vars = local_recursive_function(
list_opt
, inp
, optimized_vars
, depth+1)
for k,v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)]
else:
nw_in = inp
optimized_vars[inp] = inp
node.inputs[idx] = nw_in
results = node.outputs
for opt in list_opt:
ret = opt.transform(node)
if ret is not False and ret is not None:
assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret):
optimized_vars[k] = v
results = ret
if ret[0].owner :
node = out.owner
else:
break
return results, optimized_vars
final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0)
return final_outs[0]
def merge_two_slices(slice1, len1, slice2, len2):
'''
This function merges two slices into a single slice. The code works on
......@@ -1408,12 +1347,12 @@ def merge_two_slices(slice1, len1, slice2, len2):
# and is not simplified. We simplify it in advance here
# as otherwise this create too many useless optimization that
# DebugMode must check.
start = greedy_local_optimizer( list_opt, start)
stop = greedy_local_optimizer( list_opt, stop)
step = greedy_local_optimizer( list_opt, step)
start = greedy_local_optimizer( list_opt, start)
stop = greedy_local_optimizer( list_opt, stop)
step = greedy_local_optimizer( list_opt, step)
start = pre_greedy_local_optimizer( list_opt, start)
stop = pre_greedy_local_optimizer( list_opt, stop)
step = pre_greedy_local_optimizer( list_opt, step)
start = pre_greedy_local_optimizer( list_opt, start)
stop = pre_greedy_local_optimizer( list_opt, stop)
step = pre_greedy_local_optimizer( list_opt, step)
#Pre merge constant for the same reason.
start, stop, step = pre_constant_merge([start, stop, step])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论