提交 83c7e294 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

I changed the perform of scan to use the linker directly to execute the

inner function, without creating a special type of function. This reduces the complexity of the code and also removes a potential source of bugs (e.g. when unpickling, or when outputs are aliased). I also added a new test for the bug Arnaud discovered, which this fix also addresses in a broader way.
上级 b5fe3cd1
......@@ -846,15 +846,7 @@ def scan( fn
info['inplace'] = False
info['gpu'] = False
revised_outs = []
for o in new_outs:
if (o in inner_inputs or
isinstance(o, tensor.Constant)):
revised_outs.append( scan_utils.cloneOp(o))
else:
revised_outs.append(o)
local_op = scan_op.Scan( inner_inputs, revised_outs, info )
local_op = scan_op.Scan( inner_inputs, new_outs, info )
##
### Step 8. Compute the outputs using the scan op
......
......@@ -18,7 +18,8 @@ import logging
import numpy
import sys
from theano.compile import SharedVariable, function, Param
from theano.compile import SharedVariable, function, Param, Out
from theano.compile.function_module import ViewOp, DeepCopyOp
from theano import compile
from theano import gradient
from theano.gof.python25 import all
......@@ -166,47 +167,25 @@ class Scan(Op):
self.info['name'] = self.name
self.info['mode_instance'] = self.mode_instance
if isinstance(self.mode_instance, compile.debugmode.DebugMode):
theano_fn = function(
inputs
, outputs
, mode = self.mode_instance
, name = self.name )
def fn_wrapper(ins_storage, outs_storage):
'''
Wrap theano_fn to have same interface as scan_utils's
scan_function
'''
outputs = theano_fn(*ins_storage)
for (out,out_storage) in zip( outputs, outs_storage):
if out_storage[0] is not None and out_storage[0].shape:
out_storage[0][:] = out
elif out_storage[0] is not None:
out_storage[0].itemset(out)
return [[o] for o in outputs ]
self.fn = fn_wrapper
self.fn.maker = scan_utils.EmptyObject()
self.fn.maker.inputs = inputs
self.fn.maker.outputs = outputs
self.fn.maker.env = theano_fn.maker.env
self.mask = [ 0 for x in xrange(self.n_shared_outs)]
else:
self.mask, self.fn = scan_utils.scan_function(
inputs
, outputs
, nonmutable
, mode = self.mode_instance
, name = self.name
, slices = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
)
# check for shared variables in the inputs
assert not numpy.any( [isinstance(x, SharedVariable) for x
in self.fn.maker.inputs])
wrapped_inputs = [Param(x,borrow=True) for x in inputs ]
wrapped_outputs = [Out(x, borrow=True) for x in outputs ]
self.fn = function(wrapped_inputs,
wrapped_outputs,
mode = self.mode_instance,
name = self.name )
self.mask = [ 0 for x in xrange(self.n_shared_outs) ]
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
for i in xrange(slices, slices+self.n_shared_outs):
if isinstance(self.fn.maker.env.outputs[i].owner.op,
ViewOp):
self.mask[i-slices] = 1
# Pre-computing some values to speed up perform
self.mintaps = [ numpy.min(x) for x in self.tap_array]
......@@ -406,10 +385,8 @@ class Scan(Op):
if n_steps < 0:
n_steps = abs(n_steps)
seqs = [ seq[::-1] for seq in args[1:self.seqs_arg_offset]]
seqs = zip( seqs, self.vector_seqs )
else:
seqs = args[1:self.seqs_arg_offset]
seqs = zip( seqs, self.vector_seqs )
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containting the length of each output
......@@ -447,62 +424,81 @@ class Scan(Op):
offset = self.nit_sot_arg_offset + self.n_nit_sot
other_args = args[offset:]
zipped_outs = [(outs[idx], self.vector_outs[idx], tap,
store_steps[idx], idx) for idx in xrange(self.n_outs)
for tap in self.tap_array[idx] ]
end = self.n_outs + self.n_nit_sot
sot_outs = zip( outs[self.n_mit_mot:end]
, self.vector_outs[self.n_mit_mot:end]
, store_steps[self.n_mit_mot:end]
, range(self.n_mit_mot, end ))
input_storage = self.fn.input_storage
output_storage = self.fn.output_storage
fn = self.fn.fn
offset = ( self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs)
for idx in xrange(len(other_args)):
input_storage[idx+offset].storage[0] = other_args[idx]
############## THE MAIN LOOP #########################
for i in xrange(n_steps):
# sequences over which scan iterates
# 3. collect input slices
if i == 1 and self.n_nit_sot > 0 :
sot_outs = zip( outs[self.n_mit_mot:end]
, self.vector_outs[self.n_mit_mot:end]
, store_steps[self.n_mit_mot:end]
, range(self.n_mit_mot, end ))
for idx in xrange(self.n_seqs):
if self.vector_seqs[idx]:
input_storage[idx].storage[0] = seqs[idx][i:i+1].reshape(())
else:
input_storage[idx].storage[0] = seqs[idx][i]
offset = self.n_seqs
for idx in xrange(self.n_outs):
if self.vector_outs[idx]:
for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
input_storage[offset].storage[0] =\
outs[idx][0][_idx:_idx+1].reshape(())
offset += 1
else:
for tap in self.tap_array[idx]:
_idx = (pos[idx]+tap)%store_steps[idx]
input_storage[offset].storage[0] = outs[idx][0][_idx]
offset += 1
fn_args = [ seq[i:i+1].reshape(()) if c else seq[i]
for seq,c in seqs]
fn_args += [ out[0][(pos[j]+tap)%sz:
(pos[j]+tap)%sz+1].reshape(())
if c else out[0][(pos[j]+tap)%sz]
for (out, c, tap, sz, j) in zipped_outs ]
a_offset = self.shared_arg_offset
o_offset = self.n_outs + self.n_nit_sot
fn_args += [ args[a_offset+j] if i==0 else outs[o_offset+j][0]
for j in xrange(self.n_shared_outs) ]
fn_args += other_args
if i == 0:
for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = args[a_offset+j]
offset += 1
else:
for j in xrange(self.n_shared_outs):
input_storage[offset].storage[0] = outs[o_offset+j][0]
offset += 1
# 4. collecting slices where the output should be stored
fn_out_storage = [ [None] for x in xrange(self.n_mit_mot_outs)]
if i == 0 and self.n_nit_sot > 0:
fn_out_storage += [
[None] if store == 1 or c else [out[0][pos[j]]]
for out,c,store,j in sot_outs[:-self.n_nit_sot] ]
fn_out_storage += [[None]]*self.n_nit_sot
for idx in xrange(self.n_mit_mot_outs):
output_storage[idx].storage[0] = None
offset = self.n_mit_mot_outs
if i !=0 and self.n_nit_sot >0:
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
if ( store_steps[idx+self.n_mit_mot] == 1 or
self.vector_outs[idx+self.n_mit_mot]):
output_storage[idx+offset].storage[0] = None
else:
output_storage[idx+offset].storage[0] =\
outs[idx+self.n_mit_mot][0][pos[idx+self.n_mit_mot]]
else:
fn_out_storage += [
[ None ] if store == 1 or c else [out[0][pos[j]]]
for out,c,store,j in sot_outs ]
fn_out_storage += [ [None] for x in xrange(self.n_shared_outs) ]
for idx in xrange(self.n_outs + self.n_nit_sot -
self.n_mit_mot):
output_storage[idx+offset].storage[0] = None
offset += self.n_outs+self.n_nit_sot - self.n_mit_mot
for idx in xrange(self.n_shared_outs):
output_storage[idx+offset].storage[0] = None
# 5. compute outputs
something = self.fn(fn_args, fn_out_storage)
fn()
offset_out = 0
# 5.1 Copy over the values for mit_mot outputs
for j in xrange(self.n_mit_mot):
for k in self.mit_mot_out_slices[j]:
outs[j][0][k+pos[j]] = something[offset_out][0]
outs[j][0][k+pos[j]] = output_storage[offset_out].storage[0]
offset_out += 1
# 5.2 Copy over the values for mit_sot/sit_sot outputs
......@@ -511,8 +507,10 @@ class Scan(Op):
offset_out -= self.n_mit_mot
for j in xrange(begin, end):
if store_steps[j] == 1 or self.vector_outs[j]:
outs[j][0][pos[j]] = something[offset_out+j][0]
if ( store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[offset_out+j].storage[0]):
outs[j][0][pos[j]] = output_storage[offset_out+j].storage[0]
# 5.3 Copy over the values for nit_sot outputs
begin = end
......@@ -520,10 +518,10 @@ class Scan(Op):
for j in xrange(begin,end):
if i == 0:
jout = j+offset_out
shape = (store_steps[j],) + something[jout][0].shape
if len(something[jout][0].shape) == 0:
shape = (store_steps[j],) + output_storage[jout].storage[0].shape
if len(output_storage[jout].storage[0].shape) == 0:
self.vector_outs[j] = True
dtype = something[jout][0].dtype
dtype = output_storage[jout].storage[0].dtype
if (outs[j][0] is None or
outs[j][0].shape[0] < store_steps[j] or
outs[j][0].shape[1:] != shape[1:] or
......@@ -534,9 +532,10 @@ class Scan(Op):
outs[j][0] = numpy.zeros(shape, dtype)
elif outs[j][0].shape[0] != store_steps[j]:
outs[j][0] = outs[j][0][:store_steps[j]]
outs[j][0][pos[j]] = something[jout][0]
elif store_steps[j] == 1 or self.vector_outs[j]:
outs[j][0][pos[j]] = something[j+offset_out][0]
outs[j][0][pos[j]] = output_storage[jout].storage[0]
elif (store_steps[j] == 1 or self.vector_outs[j] or
outs[j][0][pos[j]] is not output_storage[j+offset_out].storage[0]):
outs[j][0][pos[j]] = output_storage[j+offset_out].storage[0]
# 5.4 Copy over the values for outputs corresponding to shared
......@@ -545,7 +544,7 @@ class Scan(Op):
end += self.n_shared_outs
for j in xrange(begin,end):
jout = j +offset_out
outs[j][0] = something[jout][0]
outs[j][0] = output_storage[jout].storage[0]
pos = [ (idx+1)%store for idx,store in
itertools.izip(pos, store_steps)
......
......@@ -86,279 +86,6 @@ def traverse(out, x,x_copy, d):
d = traverse(inp, x, x_copy, d)
return d
class EmptyObject(object):
    """Bare attribute container used to emulate a function ``maker``.

    Callers attach attributes directly to instances after creation.
    """

    def __init__(self):
        """Initialise no state; attributes are assigned by the caller."""
        pass
class ScanInnerFunction(object):
    """
    Stripped down, simplified version of the theano.function class that has
    a low overhead at calling a function.

    The constructor stores a linker-produced thunk together with its
    input/output storage cells, and builds a small ``maker`` stand-in with
    just enough attributes for the introspection done by Scan.
    """
    def __init__(self
                 , fn
                 , input_storage
                 , output_storage
                 , env
                 , inputs
                 , outputs
                 , nonmutable_indices
                 , mode
                 , name
                 ):
        # `fn` is the thunk produced by the linker; the storage lists are
        # shared with that thunk, so writing into them feeds the thunk.
        self.fn = fn
        self.input_storage = input_storage
        self.n_ins = len(input_storage)
        self.n_outs = len(output_storage)
        self.outputs_storage = output_storage
        # Emulate enough of theano.function's `maker` for callers.
        self.maker = EmptyObject()
        self.maker.env = env
        self.maker.inputs = inputs
        for i in inputs:
            i.update = None
        self.maker.expanded_inputs = inputs
        self.maker.outputs = outputs
        self.maker.nonmutable_indices = nonmutable_indices
        self.maker.mode = mode
        self.name = name

    def __call__(self, inputs, outputs):
        """Run the compiled thunk.

        :param inputs: list of values, one per input variable
        :param outputs: list of one-element storage cells; when a cell
            already holds an array the result is copied into it in place
        :return: the internal output storage (list of one-element lists)
        """
        t0 = time.time()
        # put data into the storage
        for idx in xrange(self.n_ins):
            self.input_storage[idx][0] = inputs[idx]
        for idx in xrange(self.n_outs):
            self.outputs_storage[idx][0] = outputs[idx][0]
        _t0 = time.time()
        self.fn()
        dt_fn = time.time() - _t0
        # If the caller provided pre-allocated outputs that the thunk did
        # not write into directly, copy the results back in place.
        for idx in xrange(self.n_outs):
            if outputs[idx][0] is not None:
                if outputs[idx][0] is not self.outputs_storage[idx][0]:
                    if outputs[idx][0].shape:
                        outputs[idx][0][:] = self.outputs_storage[idx][0]
                    else:
                        outputs[idx][0].itemset(self.outputs_storage[idx][0])
        dt_call = time.time() - t0
        # Profiling bookkeeping (only when the mode collects statistics).
        if hasattr(self.maker.mode, 'fct_call_time'):
            self.maker.mode.fct_call_time[self] += dt_call
            self.maker.mode.fct_call[self] += 1
            self.maker.mode.fn_time += dt_fn
            self.maker.mode.call_time += dt_call
        return self.outputs_storage

    def __getstate__(self):
        # The thunk, its storage and the env are not picklable; they are
        # rebuilt in __setstate__ from the symbolic inputs/outputs.
        state = self.__dict__.copy()
        del state['fn']
        del state['input_storage']
        del state['outputs_storage']
        # BUG FIX: `state` is a shallow copy of __dict__, so deleting the
        # env from state['maker'] used to mutate the live object as well;
        # copy the maker before stripping the env.
        state['maker'] = copy.copy(self.maker)
        del state['maker'].env
        return state

    def __setstate__(self, state):
        # BUG FIX: the `state` argument was missing from the signature,
        # which made unpickling raise a TypeError/NameError.
        self.__dict__ = state
        mode = self.maker.mode
        inputs = self.maker.inputs
        # BUG FIX: was `ouputs` (NameError on unpickling).
        outputs = self.maker.outputs
        nonmutable_indices = self.maker.nonmutable_indices
        new_inputs, new_outputs = gof.graph.clone(inputs, outputs)
        env = gof.env.Env(new_inputs, new_outputs)
        nonmutable = []
        for idx in nonmutable_indices:
            nonmutable.append(new_inputs[idx])
        env.extend(
            Supervisor(inp for inp in nonmutable if
                       not (hasattr(env, 'destroyers') and
                            env.destroyers(inp))))
        # If named nodes are replaced, keep the name
        env.extend(gof.toolbox.PreserveNames())
        optimizer, linker = mode.optimizer, copy.copy(mode.linker)
        # optimize the env
        t0 = time.time()
        optimizer(env)
        _logger.debug('Optimizing took %f seconds' % (time.time() - t0))
        if not hasattr(linker, 'accept'):
            raise ValueError(("'linker' parameter of FunctionFactory "
                              "should be a Linker with an accept method "
                              "or one of %s") %
                             mode_module.predefined_linkers.keys())
        my_linker = linker.accept(env)
        input_storage = [[None] for input in inputs]
        output_storage = [[None] for output in outputs]
        t0 = time.time()
        _fn, _i, _o = my_linker.make_thunk(input_storage=input_storage,
                                           output_storage=output_storage)
        _logger.debug('Linking took %f seconds' % (time.time() - t0))
        # BUG FIX: this used to build a new ScanInnerFunction with the
        # wrong number of constructor arguments and then discard it;
        # restore this instance's state directly instead.
        self.fn = _fn
        self.input_storage = input_storage
        self.outputs_storage = output_storage
        self.maker.env = env
        if hasattr(mode, 'fct_call_time'):
            mode.fct_call_time.setdefault(self, 0)
        if hasattr(mode, 'fct_call'):
            # BUG FIX: was `set_default`, an AttributeError on dicts.
            mode.fct_call.setdefault(self, 0)
def scan_function(inputs
                  , outputs
                  , nonmutable_indices=None
                  , mode=None
                  , name=None
                  , slices=0
                  ):
    """
    ``Constructor`` of the ScanInnerFunction (a simplified version of
    theano.function). This should only be used internally by Scan.

    :param inputs: theano variables that represent the inputs of the function
    :param outputs: theano expressions that represent the outputs of the
        function
    :param nonmutable_indices: the subset of indices corresponding to
        nonmutable inputs (defaults to no nonmutable inputs)
    :param mode: compilation mode for the function
    :param name: name of the function
    :param slices: number of leading (recurrent-slice) outputs; only the
        outputs after this offset get an entry in the aliasing mask
    :return: (mask, fn) where mask[i] is 1 when output ``slices + i`` must
        be copied by the caller because it aliases another output or an
        input, and fn is the compiled ScanInnerFunction
    """
    t1 = time.time()
    mode = mode_module.get_mode(mode)
    if isinstance(mode, (list, tuple)):  # "mode comparison" semantics
        _logger.warning('Passing multiple modes is deprecated (20091019)')
        if not mode:
            raise ValueError("Please provide at least one mode.")
        else:
            mode = mode[0]
    ## Replacing the Function Maker
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    # BUG FIX: the default value of None used to crash the iteration
    # below with a TypeError when no indices were provided.
    if nonmutable_indices is None:
        nonmutable_indices = []
    new_inputs, new_outputs = gof.graph.clone(inputs, outputs)
    env = gof.env.Env(new_inputs, new_outputs)
    nonmutable = []
    for idx in nonmutable_indices:
        nonmutable.append(new_inputs[idx])
    env.extend(
        Supervisor(inp for inp in nonmutable if
                   not (hasattr(env, 'destroyers') and env.destroyers(inp))))
    # If named nodes are replaced, keep the name
    env.extend(gof.toolbox.PreserveNames())
    optimizer, linker = mode.optimizer, copy.copy(mode.linker)
    # optimize the env
    t0 = time.time()
    optimizer(env)
    _logger.debug('Optimizing took %f seconds' % (time.time() - t0))
    mask = [0 for x in env.outputs[slices:]]
    # Outputs that are an input, a constant, or a duplicate of a later
    # output are routed through a Clone op so each output position gets
    # its own variable.
    for i, out in enumerate(env.outputs):
        if (out in env.inputs or
                isinstance(out, tensor.Constant) or
                out in env.outputs[i + 1:]):
            env.change_input('output', i, Clone()(out))
    for i in xrange(len(env.outputs[slices:])):
        # NOTE(review): this inspects env.outputs[i] while mask[i] guards
        # output `slices + i`; confirm whether the index should be offset
        # by `slices` as in Scan.__init__'s mask computation.
        views_of_output_i = set()
        view_tree_set(alias_root(env.outputs[i]), views_of_output_i)
        copied = False
        # do not allow outputs to be aliased
        for j in xrange(i + 1, len(env.outputs)):
            if (env.outputs[j] in views_of_output_i or
                    env.outputs[j] == env.outputs[i]):
                mask[i] = 1
                copied = True
                break
        if not copied:
            for input_j in env.inputs:
                # do not allow outputs to be aliased to an input (j), unless
                # a) that j'th input has been 'destroyed' by e.g. in-place
                #    computations
                if (hasattr(env, 'get_destroyers_of') and
                        env.get_destroyers_of(input_j)):
                    continue
                if input_j in views_of_output_i:
                    mask[i] = 1
                    break
    if not hasattr(linker, 'accept'):
        raise ValueError(("'linker' parameter of FunctionFactory "
                          "should be a Linker with an accept method "
                          "or one of %s") %
                         mode_module.predefined_linkers.keys())
    my_linker = linker.accept(env)
    input_storage = [[None] for input in inputs]
    output_storage = [[None] for output in outputs]
    t0 = time.time()
    _fn, _i, _o = my_linker.make_thunk(input_storage=input_storage,
                                       output_storage=output_storage)
    _logger.debug('Linking took %f seconds' % (time.time() - t0))
    # Profiling bookkeeping (only when the mode collects statistics).
    if hasattr(mode, 'apply_time'):
        for i, node in enumerate(env.toposort()):
            mode.apply_time[(i, node)] = 0.0
            assert len(_fn.thunk_groups[i]) == 1
            mode.op_cimpl[node.op] = hasattr(_fn.thunk_groups[i][0],
                                             'cthunk')
    fn = ScanInnerFunction(_fn
                           , input_storage
                           , output_storage
                           , env
                           , inputs
                           , outputs
                           , nonmutable_indices
                           , mode
                           , name
                           )
    t2 = time.time()
    if hasattr(mode, 'compile_time'):
        mode.compile_time += t2 - t1
    if hasattr(mode, 'fct_call_time'):
        mode.fct_call_time.setdefault(fn, 0)
    if hasattr(mode, 'fct_call'):
        mode.fct_call.setdefault(fn, 0)
    return mask, fn
# Hashing a dictionary/list/tuple by xoring the hash of each element
def hash_listsDictsTuples(x):
......@@ -519,33 +246,8 @@ def expand( tensor_var, size):
class Clone(Op):
    """Identity operation whose output is declared as a view of its input.

    Used to give an output its own variable when it would otherwise be
    the same variable as an input, a constant, or another output.
    """

    def __init__(self):
        # Output 0 is a view of input 0.
        self.view_map = {0: [0]}

    def __eq__(self, other):
        # All Clone instances are interchangeable.
        return type(self) is type(other)

    def __hash__(self):
        return hash(type(self))

    def __str__(self):
        return 'clone[as_view]'

    def make_node(self, *inputs):
        first = inputs[0]
        return Apply(self, inputs, [first.type()])

    def perform(self, node, args, outs):
        # Pass the value through untouched; aliasing is declared via
        # view_map rather than by copying here.
        outs[0][0] = args[0]

    def infer_shape(self, node, input_shapes):
        # The output has exactly the shape of the input.
        return input_shapes

    def grad(self, args, g_outs):
        # Identity: gradients flow through unchanged.
        return g_outs
cloneOp = Clone()
def equal_computations(x,y, strict=False):
'''
......
......@@ -2007,7 +2007,34 @@ class T_Scan(unittest.TestCase):
assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op)
def test_same(self):
    # Regression test for a bug discovered by Arnaud (based on his code):
    # two functions compiled from different outputs of the same scan must
    # produce the same values.
    x = theano.tensor.fmatrix('x')
    mem_val = numpy.zeros((2,), dtype='float32')
    memory = theano.shared(mem_val.copy())
    W = theano.shared(numpy.random.random((5, 2)).astype('float32'))

    def step(inp, mem):
        joined = theano.tensor.join(0, inp, mem)
        projected = theano.tensor.dot(joined, W)
        return projected, projected

    outs, updts = theano.scan(step,
                              sequences=[x],
                              non_sequences=[],
                              outputs_info=[None, memory])
    fn_first = theano.function([x], outs[0])
    fn_second = theano.function([x], outs[1])
    x_val = numpy.random.random((4, 3)).astype('float32')
    first_vals = fn_first(x_val)
    # Reset the shared memory so both functions start from the same state.
    memory.set_value(mem_val.copy())
    second_vals = fn_second(x_val)
    assert numpy.allclose(first_vals, second_vals)
if __name__ == '__main__':
#'''
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论