提交 301305e8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

...@@ -1169,6 +1169,58 @@ def local_subtensor_lift(node): ...@@ -1169,6 +1169,58 @@ def local_subtensor_lift(node):
new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim)) new_inputs.append(i.dimshuffle(['x']*node.outputs[0].ndim))
return [u.owner.op(*new_inputs)] return [u.owner.op(*new_inputs)]
def greedy_local_optimizer( list_optimizations, out):
'''
This function traverses the computation graph described by
``node`` and applies each of the local_optimizations on
all the nodes in the graph once.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
'''
def local_recursive_function( list_opt, out, optimized_vars, depth):
if not out.owner :
return [out]
node = out.owner
for idx, inp in enumerate(node.inputs):
if inp in optimized_vars:
nw_in = optimized_vars[inp]
else:
if inp.owner:
outs, optimized_vars = local_recursive_function(
list_opt
, inp
, optimized_vars
, depth+1)
for k,v in zip(inp.owner.outputs, outs):
optimized_vars[k] = v
nw_in = outs[inp.owner.outputs.index(inp)]
else:
nw_in = inp
optimized_vars[inp] = inp
node.inputs[idx] = nw_in
results = node.outputs
for opt in list_opt:
ret = opt.transform(node)
if ret is not False and ret is not None:
assert len(ret) == len(node.outputs)
for k,v in zip(node.outputs, ret):
optimized_vars[k] = v
results = ret
if ret[0].owner :
node = out.owner
else:
break
return results, optimized_vars
final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0)
return final_outs[0]
def merge_two_slices(slice1, len1, slice2, len2): def merge_two_slices(slice1, len1, slice2, len2):
''' '''
This function merges two slices into a single slice. The code works on This function merges two slices into a single slice. The code works on
...@@ -1184,18 +1236,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -1184,18 +1236,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
``len1`` is the length of the tensor **before** applying the first slice, ``len1`` is the length of the tensor **before** applying the first slice,
while ``len2`` is the length **after** applying the first slice. while ``len2`` is the length **after** applying the first slice.
''' '''
def const_fold(n): list_opt = [ constant_folding, local_remove_switch_const_cond ]
while True:
ret = constant_folding.transform(n)
if ret is not False and ret is not None:
#print n,ret
assert len(ret)==len(n.outputs)
assert len(ret)==1
n = ret[0].owner
else: break
return n.outputs
if type(slice1) is not slice: if type(slice1) is not slice:
...@@ -1292,10 +1333,9 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -1292,10 +1333,9 @@ def merge_two_slices(slice1, len1, slice2, len2):
step = T.switch( T.lt(reverse2*reverse1,0),n_step, p_step) step = T.switch( T.lt(reverse2*reverse1,0),n_step, p_step)
start = T.switch(T.le(flen,0), 0, start) start = T.switch(T.le(flen,0), 0, start)
stop = T.switch(T.le(flen,0), 0, stop) stop = T.switch(T.le(flen,0), 0, stop)
start = greedy_local_optimizer( list_opt, start)
start = const_fold(start.owner)[0] stop = greedy_local_optimizer( list_opt, stop)
stop = const_fold(stop.owner)[0] step = greedy_local_optimizer( list_opt, step)
step = const_fold(step.owner)[0]
start = theano.printing.Print('start')(start) start = theano.printing.Print('start')(start)
stop = theano.printing.Print('stop')(stop) stop = theano.printing.Print('stop')(stop)
......
...@@ -1544,13 +1544,13 @@ class test_local_subtensor_merge(unittest.TestCase): ...@@ -1544,13 +1544,13 @@ class test_local_subtensor_merge(unittest.TestCase):
#print topo[-1].op #print topo[-1].op
assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp) assert isinstance(topo[-1].op, theano.compile.function_module.DeepCopyOp)
b1r = self.rng.permutation(range(-8,8))[:4] b1r = self.rng.permutation(range(-8,8))[:2]
e1r = self.rng.permutation(range(-8,8))[:4] e1r = self.rng.permutation(range(-8,8))[:2]
b2r = self.rng.permutation(range(-8,8))[:4] b2r = self.rng.permutation(range(-8,8))[:2]
e2r = self.rng.permutation(range(-8,8))[:4] e2r = self.rng.permutation(range(-8,8))[:2]
s1r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:4] s1r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:2]
s2r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:4] s2r = self.rng.permutation([-7,-6,-5,-4,-3,-2,-1,1,2,3,4,5,6,7])[:2]
for x_s in self.x_shapes: for x_s in self.x_shapes:
x_val = self.rng.uniform(size=x_s).astype(config.floatX) x_val = self.rng.uniform(size=x_s).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论