提交 c5dbc0d7 authored 作者: Frederic Bastien's avatar Frederic Bastien

Don't pre-apply optimization to node in the env.

上级 0bf76bf5
......@@ -1170,19 +1170,24 @@ def local_subtensor_lift(node):
return [u.owner.op(*new_inputs)]
def greedy_local_optimizer( list_optimizations, out):
def greedy_local_optimizer(list_optimizations, out, no_opt):
'''
This function traverses the computation graph described by
``node`` and applies each of the local_optimizations on
all the nodes in the graph once.
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the env.
it applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
We should not apply optimizations on node that are in env.
So we don't optimize node in no_opt.
'''
def local_recursive_function( list_opt, out, optimized_vars, depth):
if not out.owner :
return [out]
node = out.owner
if node in no_opt:
return node.outputs, optimized_vars
for idx, inp in enumerate(node.inputs):
if inp in optimized_vars:
nw_in = optimized_vars[inp]
......@@ -1333,13 +1338,29 @@ def merge_two_slices(slice1, len1, slice2, len2):
step = T.switch( T.lt(reverse2*reverse1,0),n_step, p_step)
start = T.switch(T.le(flen,0), 0, start)
stop = T.switch(T.le(flen,0), 0, stop)
start = greedy_local_optimizer( list_opt, start)
stop = greedy_local_optimizer( list_opt, stop)
step = greedy_local_optimizer( list_opt, step)
start = theano.printing.Print('start')(start)
stop = theano.printing.Print('stop')(stop)
step = theano.printing.Print('step')(step)
# Find the list of nodes in the env.
# We should not optimize them here!
list_no_opt = set()
for sl in [slice1, slice2]:
if isinstance(sl, slice):
for idx in [sl.start, sl.stop, sl.step]:
if isinstance(idx, Variable):
list_no_opt.update(sl.start.env.nodes)
if isinstance(sl, Variable):
list_no_opt.update(sl.env.nodes)
# The canonical form of the slice is pretty complicated
# and is not simplified. We simplify it in advance here
# as otherwise this create too many useless optimization that
# DebugMode must check.
start = greedy_local_optimizer( list_opt, start, list_no_opt)
stop = greedy_local_optimizer( list_opt, stop, list_no_opt)
step = greedy_local_optimizer( list_opt, step, list_no_opt)
#start = theano.printing.Print('start')(start)
#stop = theano.printing.Print('stop')(stop)
#step = theano.printing.Print('step')(step)
return slice(start, stop, step)
@register_canonicalize
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论