提交 62d8c5ea authored 作者: Razvan Pascanu's avatar Razvan Pascanu

A greedy, local way of applying constant folding and removing switches with constant conditions.
上级 711f5cfd
......@@ -1169,6 +1169,58 @@ def local_subtensor_lift(node):
new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)]
def greedy_local_optimizer(list_optimizations, out):
    '''
    Greedily apply local optimizations to the graph that computes ``out``.

    The graph is walked depth-first from ``out`` towards its inputs; the
    inputs of a node are processed before the node itself, and then every
    optimization in ``list_optimizations`` is tried once on that node.
    Its main use is to apply constant folding locally when generating
    the graph of the indices of a subtensor.

    :param list_optimizations: sequence of objects exposing
        ``transform(node)``; a return value of ``False`` or ``None`` means
        "not applicable", anything else must be a sequence with one
        replacement variable per output of ``node``.
    :param out: the variable whose graph is optimized. It may have no
        owner, in which case it is returned unchanged.

    :return: the variable that replaces ``out``.

    Note that the inputs of the traversed nodes are rewritten in place
    (``node.inputs[idx] = ...``), so the original graph is modified.
    '''
    def local_recursive_function(list_opt, out, optimized_vars, depth):
        # Leaf variable: nothing to optimize. Return the same
        # (outputs, cache) pair as the general case so that every caller
        # can unpack two values.
        if not out.owner:
            return [out], optimized_vars
        node = out.owner
        for idx, inp in enumerate(node.inputs):
            if inp in optimized_vars:
                # Already visited: reuse the cached replacement.
                nw_in = optimized_vars[inp]
            elif inp.owner:
                outs, optimized_vars = local_recursive_function(
                    list_opt, inp, optimized_vars, depth + 1)
                # Cache the replacement of every output of the producing
                # node, not only of the one consumed here.
                for k, v in zip(inp.owner.outputs, outs):
                    optimized_vars[k] = v
                nw_in = outs[inp.owner.outputs.index(inp)]
            else:
                nw_in = inp
                optimized_vars[inp] = inp
            # In-place rewrite of the graph.
            node.inputs[idx] = nw_in

        results = node.outputs
        for opt in list_opt:
            ret = opt.transform(node)
            if ret is not False and ret is not None:
                assert len(ret) == len(node.outputs)
                for k, v in zip(node.outputs, ret):
                    optimized_vars[k] = v
                results = ret
                # NOTE(review): the remaining optimizations are still
                # tried on the original node, not on ``ret[0].owner``
                # (the original code re-assigned ``node = out.owner``,
                # which is a no-op). Confirm whether that is intended.
                if not ret[0].owner:
                    # The replacement has no owner, so there is nothing
                    # left to optimize.
                    break
        return results, optimized_vars

    # Remember which output of its owner ``out`` is, so that the matching
    # replacement is returned for multi-output nodes.
    if out.owner:
        out_index = out.owner.outputs.index(out)
    else:
        out_index = 0
    final_outs, optimized_nodes = local_recursive_function(
        list_optimizations, out, {}, 0)
    return final_outs[out_index]
def merge_two_slices(slice1, len1, slice2, len2):
'''
This function merges two slices into a single slice. The code works on
......@@ -1184,18 +1236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
``len1`` is the length of the tensor **before** applying the first slice,
while ``len2`` is the length **after** applying the first slice.
'''
def const_fold(n):
    '''
    Repeatedly apply ``constant_folding`` to the single-output node ``n``.

    Folding is retried on the owner of each replacement until
    ``constant_folding.transform`` reports that it is not applicable
    (returns ``False`` or ``None``) or the replacement has no owner.

    :param n: the node to fold; it must have exactly one output whenever
        folding succeeds.

    :return: a one-element sequence holding the folded variable, or
        ``n.outputs`` when nothing could be folded.
    '''
    outputs = n.outputs
    while True:
        ret = constant_folding.transform(n)
        if ret is False or ret is None:
            break
        assert len(ret) == len(n.outputs)
        assert len(ret) == 1
        outputs = ret
        n = ret[0].owner
        if not n:
            # The replacement has no owner, so there is no node left to
            # fold; looping again would call ``transform`` on it and then
            # read ``.outputs`` from a non-node.
            break
    return outputs
list_opt = [ constant_folding, local_remove_switch_const_cond ]
if type(slice1) is not slice:
......@@ -1292,10 +1333,9 @@ def merge_two_slices(slice1, len1, slice2, len2):
step = T.switch( T.lt(reverse2*reverse1,0),n_step, p_step)
start = T.switch(T.le(flen,0), 0, start)
stop = T.switch(T.le(flen,0), 0, stop)
start = const_fold(start.owner)[0]
stop = const_fold(stop.owner)[0]
step = const_fold(step.owner)[0]
start = greedy_local_optimizer( list_opt, start)
stop = greedy_local_optimizer( list_opt, stop)
step = greedy_local_optimizer( list_opt, step)
start = theano.printing.Print('start')(start)
stop = theano.printing.Print('stop')(stop)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论