提交 dc5617c1 authored 作者: lamblin's avatar lamblin

Merge pull request #1341 from pascanur/recent_scan_bugs

Recent scan bugs
...@@ -192,19 +192,18 @@ class Scan(PureOp): ...@@ -192,19 +192,18 @@ class Scan(PureOp):
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
Conventions: Conventions:
inner_? - the variable corresponding to ? in the inner function inner_X - the variable corresponding to X in the inner function
of scan (the lambda function executed at every time of scan (the lambda function executed at every time
step) step)
outer_? - the variable corresponding to ? in the outer graph, outer_X - the variable corresponding to X in the outer graph,
i.e. the main graph (where the scan op lives) i.e. the main graph (where the scan op lives)
inner_?_out - the variable representing the new value of ? after inner_X_out - the variable representing the new value of X after
executing one step of scan (i.e. outputs given by executing one step of scan (i.e. outputs given by
the inner function) the inner function)
""" """
assert numpy.all(isinstance(i, gof.Variable) for i in inputs) assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to # Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan # the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1 n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (len(self.inner_seqs(self.inputs)) + n_inner_ins = (len(self.inner_seqs(self.inputs)) +
len(self.mitmot_taps()) + len(self.mitmot_taps()) +
...@@ -215,7 +214,7 @@ class Scan(PureOp): ...@@ -215,7 +214,7 @@ class Scan(PureOp):
assert n_outer_ins == n_inner_ins, \ assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan" ("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.") " does not match the number of inputs given to scan.")
new_inputs = [inputs[0]]
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan the ' err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The ' 'following error has been encountered: The '
...@@ -235,42 +234,35 @@ class Scan(PureOp): ...@@ -235,42 +234,35 @@ class Scan(PureOp):
'and %d dimension(s). This could happen if the inner ' 'and %d dimension(s). This could happen if the inner '
'graph of scan results in an upcast or downcast. ' 'graph of scan results in an upcast or downcast. '
'Please make sure that you use dtypes consistently') 'Please make sure that you use dtypes consistently')
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + \
# self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self.vector_seqs = [seq.ndim == 1 for seq in def format(var, as_var):
inputs[1:1 + self.n_seqs]] """ This functions ensures that ``out`` has the same dtype as
self.vector_outs = [arg.ndim == 1 for arg in ``inp`` as well as calling filter_variable to make sure they are
inputs[1 + self.n_seqs: (1 + self.n_seqs + both TensorType or CudaNdarrayType. It internally deals with the
self.n_outs)]] corner case where inp.ndim + 1 = out.ndim
self.vector_outs += [False] * self.n_nit_sot """
if not hasattr(var, 'dtype'):
return var
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.__class__(
broadcastable=tuple(var.broadcastable[:1])+\
tuple(as_var.broadcastable),
dtype=as_var.dtype)
rval = tmp.filter_variable(rval)
return rval
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
argoffset = 0 argoffset = 0
for idx, (inner_seq, outer_seq) in enumerate( for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
zip(self.inner_seqs(self.inputs), self.outer_seqs(inputs)):
self.outer_seqs(inputs))): new_inputs.append(format(outer_seq, as_var=inner_seq))
if inner_seq.type.dtype != outer_seq[0].type.dtype:
assert isinstance(idx, int)
raise ValueError(err_msg1 % ('sequence',
str(outer_seq),
idx,
outer_seq.type.dtype,
outer_seq.ndim,
str(inner_seq),
inner_seq.type.dtype,
inner_seq.ndim))
argoffset += len(self.outer_seqs(inputs)) argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output # - initial state of the output
...@@ -280,10 +272,12 @@ class Scan(PureOp): ...@@ -280,10 +272,12 @@ class Scan(PureOp):
opos = 0 opos = 0
inner_mitmot = self.inner_mitmot(self.inputs) inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs) inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), zip(self.mitmot_taps(),
self.mitmot_out_taps(), self.mitmot_out_taps(),
self.outer_mitmot(inputs))): self.outer_mitmot(inputs))):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype != if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or outer_mitmot.type.dtype or
...@@ -316,13 +310,16 @@ class Scan(PureOp): ...@@ -316,13 +310,16 @@ class Scan(PureOp):
# Same checks as above but for outputs of type mit_sot # Same checks as above but for outputs of type mit_sot
ipos = 0 ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs) inner_mitsots = self.inner_mitsot(self.inputs)
for idx, (itaps, outer_mitsot, inner_mitsot_out) in enumerate( for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(self.mitsot_taps(), zip(self.mitsot_taps(),
self.outer_mitsot(inputs), self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))): self.inner_mitsot_outs(self.outputs))):
outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos])
new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)): for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \ if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1): inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
...@@ -346,12 +343,13 @@ class Scan(PureOp): ...@@ -346,12 +343,13 @@ class Scan(PureOp):
argoffset += len(self.outer_mitsot(inputs)) argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot # Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, outer_sitsot, inner_sitsot_out) in enumerate( for idx, (inner_sitsot, _outer_sitsot, inner_sitsot_out) in enumerate(
zip(self.inner_sitsot(self.inputs), zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs), self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))): self.inner_sitsot_outs(self.outputs))):
if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot)
inner_sitsot.ndim != outer_sitsot.ndim - 1): new_inputs.append(outer_sitsot)
if (inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info' raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ', ' in scan nomenclature) ',
str(outer_sitsot), str(outer_sitsot),
...@@ -374,10 +372,12 @@ class Scan(PureOp): ...@@ -374,10 +372,12 @@ class Scan(PureOp):
argoffset += len(self.outer_sitsot(inputs)) argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same # Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?! # dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, outer_shared) in enumerate( for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate(
zip(self.inner_shared(self.inputs), zip(self.inner_shared(self.inputs),
self.inner_shared_outs(self.outputs), self.inner_shared_outs(self.outputs),
self.outer_shared(inputs))): self.outer_shared(inputs))):
outer_shared = format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
if (hasattr(outer_shared, 'dtype') and if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared_out.dtype or (outer_shared.dtype != inner_shared_out.dtype or
outer_shared.ndim != inner_shared_out.ndim)): outer_shared.ndim != inner_shared_out.ndim)):
...@@ -400,13 +400,25 @@ class Scan(PureOp): ...@@ -400,13 +400,25 @@ class Scan(PureOp):
str(inner_shared), str(inner_shared),
inner_shared.dtype, inner_shared.dtype,
inner_shared.ndim)) inner_shared.ndim))
for inner_nonseq, outer_nonseq in zip( # We do not need to call `format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument, the input
# in this case is just a int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)): self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
raise ValueError(('Argument %s given to scan node does not' raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') % ' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq))) (str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input a int/uint that # For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is # depicts the size in memory for that sequence. This feature is
...@@ -415,9 +427,16 @@ class Scan(PureOp): ...@@ -415,9 +427,16 @@ class Scan(PureOp):
outer_nitsot.ndim != 0): outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a ' raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot)) 'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs)
self.vector_seqs = [seq.ndim == 1 for seq in
new_inputs[1:1 + self.n_seqs]]
self.vector_outs = [arg.ndim == 1 for arg in
new_inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot
apply_node = Apply(self, apply_node = Apply(self,
inputs, new_inputs,
[t() for t in self.output_types]) [t() for t in self.output_types])
return apply_node return apply_node
...@@ -1199,6 +1218,9 @@ class Scan(PureOp): ...@@ -1199,6 +1218,9 @@ class Scan(PureOp):
return scan_outs return scan_outs
def get_input_pos(self, output_index): def get_input_pos(self, output_index):
""" For a given ``output_index``, an index in the inner outputs of
scan, find a corresponding first index in the inner inputs of scan
"""
ipos = self.n_seqs ipos = self.n_seqs
opos = output_index opos = output_index
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()): for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
...@@ -1219,6 +1241,9 @@ class Scan(PureOp): ...@@ -1219,6 +1241,9 @@ class Scan(PureOp):
return -1 return -1
def get_output_pos(self, input_index): def get_output_pos(self, input_index):
""" For a given ``input_index``, an index in the inner inputs of
scan, find a corresponding first index in the inner outputs of scan
"""
ipos = input_index ipos = input_index
opos = 0 opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()): for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
...@@ -1239,6 +1264,9 @@ class Scan(PureOp): ...@@ -1239,6 +1264,9 @@ class Scan(PureOp):
return -1 return -1
def get_output_slice_idx(self, output_index): def get_output_slice_idx(self, output_index):
""" For an ``output_index``, an index in the outter ouputs of scan,
find a corresponding index in the inner outputs of scan.
"""
ipos = 0 ipos = 0
opos = output_index opos = output_index
for otaps in zip(self.mitmot_out_taps()): for otaps in zip(self.mitmot_out_taps()):
...@@ -1339,6 +1367,7 @@ class Scan(PureOp): ...@@ -1339,6 +1367,7 @@ class Scan(PureOp):
# Applying Floyd-Warshall to find all paths connecting inputs to # Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an # outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
...@@ -1429,13 +1458,15 @@ class Scan(PureOp): ...@@ -1429,13 +1458,15 @@ class Scan(PureOp):
odx = get_out_idx(self_outputs.index(y)) odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y]) wrt = [x for x in theano.gof.graph.inputs([y])
if (x in diff_inputs) and if (x in diff_inputs) and
connection_pattern[get_inp_idx(self_inputs.index(x))][odx]] (connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad( grads = gradient.grad(
cost=None, cost=None,
known_grads={y: g_y}, known_grads={y: g_y},
wrt=wrt, consider_constant=wrt, wrt=wrt,
disconnected_inputs='ignore', consider_constant=wrt,
return_disconnected='None') disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads)) gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
......
...@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer):
Questionable, we should also consider profile ? Questionable, we should also consider profile ?
""" """
rep = set_nodes[0] rep = set_nodes[0]
if not rep.op.as_while and node.op.as_while: if rep.op.as_while != node.op.as_while:
return False return False
nsteps = node.inputs[0] nsteps = node.inputs[0]
......
...@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None): ...@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype: if dtype and x.dtype != dtype:
return x.clone().astype(dtype) casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name)
nwx.tag = copy(x.tag)
return nwx
else: else:
return x.clone() return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
...@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None): ...@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
...@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None): ...@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
# This means `x` has no test value. # This means `x` has no test value.
pass pass
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
return nw_x return nw_x
......
...@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase): ...@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True), assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815) 1.21000003815)
def test_grad_find_input(self):
w = theano.shared(numpy.array(0, dtype='float32'), name='w')
init = tensor.fscalar('init')
out, _ = theano.scan(
fn=lambda prev: w,
outputs_info=init,
n_steps=2,
)
tensor.grad(out[-1], w)
def test_scan_merge_nodes(self):
inps = tensor.vector()
state = tensor.scalar()
y1, _ = theano.scan(lambda x,y: x*y,
sequences = inps,
outputs_info = state,
n_steps = 5)
y2, _ = theano.scan(lambda x,y : (x+y, theano.scan_module.until(x>0)),
sequences = inps,
outputs_info = state,
n_steps = 5)
scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, theano.scan_module.scan_op.Scan)
scan_node2 = y2.owner.inputs[0].owner
assert isinstance(scan_node2.op, theano.scan_module.scan_op.Scan)
opt_obj = theano.scan_module.scan_opt.ScanMerge()
# Test the method belongs_to of this class. Specifically see if it
# detects the two scan_nodes as not being similar
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
def test_speed(): def test_speed():
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论