提交 dc5617c1 authored 作者: lamblin's avatar lamblin

Merge pull request #1341 from pascanur/recent_scan_bugs

Recent scan bugs
......@@ -192,19 +192,18 @@ class Scan(PureOp):
def make_node(self, *inputs):
"""
Conventions:
inner_? - the variable corresponding to ? in the inner function
inner_X - the variable corresponding to X in the inner function
of scan (the lambda function executed at every time
step)
outer_? - the variable corresponding to ? in the outer graph,
outer_X - the variable corresponding to X in the outer graph,
i.e. the main graph (where the scan op lives)
inner_?_out - the variable representing the new value of ? after
inner_X_out - the variable representing the new value of X after
executing one step of scan (i.e. outputs given by
the inner function)
"""
assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
n_inner_ins = (len(self.inner_seqs(self.inputs)) +
len(self.mitmot_taps()) +
......@@ -215,7 +214,7 @@ class Scan(PureOp):
assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.")
new_inputs = [inputs[0]]
# assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan the '
'following error has been encountered: The '
......@@ -235,42 +234,35 @@ class Scan(PureOp):
'and %d dimension(s). This could happen if the inner '
'graph of scan results in an upcast or downcast. '
'Please make sure that you use dtypes consistently')
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + \
# self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self.vector_seqs = [seq.ndim == 1 for seq in
inputs[1:1 + self.n_seqs]]
self.vector_outs = [arg.ndim == 1 for arg in
inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot
def format(var, as_var):
    """Coerce ``var`` so that it matches the dtype and type of ``as_var``.

    ``var`` is returned untouched when it has no ``dtype`` attribute.
    Otherwise it is cast to ``as_var.type.dtype`` if the dtypes differ and
    then passed through ``filter_variable``:

    * when both variables have the same number of dimensions, the type of
      ``as_var`` itself does the filtering;
    * otherwise a new type of the same class as ``as_var.type`` is built,
      whose broadcastable pattern is the first entry of
      ``var.broadcastable`` followed by ``as_var.broadcastable``, and that
      type does the filtering.

    :param var: the variable to convert
    :param as_var: the variable whose dtype and type serve as the template
    :return: ``var`` itself, or the converted variable
    """
    # Objects without a dtype cannot be cast or filtered; pass them through.
    if not hasattr(var, 'dtype'):
        return var

    converted = var
    dtype_mismatch = converted.type.dtype != as_var.type.dtype
    if dtype_mismatch:
        converted = converted.astype(as_var.type.dtype)

    if converted.ndim == as_var.ndim:
        return as_var.type.filter_variable(converted)

    # Dimension counts differ: prepend the leading broadcastable flag of
    # ``var`` to the pattern of ``as_var`` and filter with that type instead.
    leading_pattern = tuple(var.broadcastable[:1])
    extended_type = as_var.type.__class__(
        broadcastable=leading_pattern + tuple(as_var.broadcastable),
        dtype=as_var.dtype)
    return extended_type.filter_variable(converted)
# Check if input sequences and variables representing a slice of
# them have the same dtype
argoffset = 0
for idx, (inner_seq, outer_seq) in enumerate(
zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs))):
if inner_seq.type.dtype != outer_seq[0].type.dtype:
assert isinstance(idx, int)
raise ValueError(err_msg1 % ('sequence',
str(outer_seq),
idx,
outer_seq.type.dtype,
outer_seq.ndim,
str(inner_seq),
inner_seq.type.dtype,
inner_seq.ndim))
for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs)):
new_inputs.append(format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs))
# Check that these 3 things have the same dtype for mit_mot:
# - initial state of the output
......@@ -280,10 +272,12 @@ class Scan(PureOp):
opos = 0
inner_mitmot = self.inner_mitmot(self.inputs)
inner_mitmot_outs = self.inner_mitmot_outs(self.outputs)
for idx, (itaps, otaps, outer_mitmot) in enumerate(
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(),
self.mitmot_out_taps(),
self.outer_mitmot(inputs))):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
for k in xrange(len(itaps)):
if (inner_mitmot[ipos + k].type.dtype !=
outer_mitmot.type.dtype or
......@@ -316,13 +310,16 @@ class Scan(PureOp):
# Same checks as above but for outputs of type mit_sot
ipos = 0
inner_mitsots = self.inner_mitsot(self.inputs)
for idx, (itaps, outer_mitsot, inner_mitsot_out) in enumerate(
for idx, (itaps, _outer_mitsot, inner_mitsot_out) in enumerate(
zip(self.mitsot_taps(),
self.outer_mitsot(inputs),
self.inner_mitsot_outs(self.outputs))):
outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos])
new_inputs.append(outer_mitsot)
for k in xrange(len(itaps)):
if (inner_mitsots[ipos + k].type.dtype != \
outer_mitsot.type.dtype or
outer_mitsot.type.dtype or
inner_mitsots[ipos + k].ndim != outer_mitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
......@@ -346,12 +343,13 @@ class Scan(PureOp):
argoffset += len(self.outer_mitsot(inputs))
# Same checks as above but for outputs of type sit_sot
for idx, (inner_sitsot, outer_sitsot, inner_sitsot_out) in enumerate(
for idx, (inner_sitsot, _outer_sitsot, inner_sitsot_out) in enumerate(
zip(self.inner_sitsot(self.inputs),
self.outer_sitsot(inputs),
self.inner_sitsot_outs(self.outputs))):
if (inner_sitsot.type.dtype != outer_sitsot.type.dtype or
inner_sitsot.ndim != outer_sitsot.ndim - 1):
outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot)
new_inputs.append(outer_sitsot)
if (inner_sitsot.ndim != outer_sitsot.ndim - 1):
raise ValueError(err_msg1 % ('initial state (outputs_info'
' in scan nomenclature) ',
str(outer_sitsot),
......@@ -374,10 +372,12 @@ class Scan(PureOp):
argoffset += len(self.outer_sitsot(inputs))
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
for idx, (inner_shared, inner_shared_out, outer_shared) in enumerate(
for idx, (inner_shared, inner_shared_out, _outer_shared) in enumerate(
zip(self.inner_shared(self.inputs),
self.inner_shared_outs(self.outputs),
self.outer_shared(inputs))):
outer_shared = format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
if (hasattr(outer_shared, 'dtype') and
(outer_shared.dtype != inner_shared_out.dtype or
outer_shared.ndim != inner_shared_out.ndim)):
......@@ -400,13 +400,25 @@ class Scan(PureOp):
str(inner_shared),
inner_shared.dtype,
inner_shared.ndim))
for inner_nonseq, outer_nonseq in zip(
# We do not need to call `format` on outer_nitsot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument; the input
# in this case is just an int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output; it is always a scalar int.
new_inputs += self.outer_nitsot(inputs)
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type:
raise ValueError(('Argument %s given to scan node does not'
' match its correspondance %s') %
(str(outer_nonseq), str(inner_nonseq)))
for outer_nitsot in self.outer_nitsot(inputs):
# For every nit_sot input we get as input an int/uint that
# depicts the size in memory for that sequence. This feature is
......@@ -415,9 +427,16 @@ class Scan(PureOp):
outer_nitsot.ndim != 0):
raise ValueError('For output %s you need to provide a '
'scalar int !', str(outer_nitsot))
assert len(new_inputs) == len(inputs)
self.vector_seqs = [seq.ndim == 1 for seq in
new_inputs[1:1 + self.n_seqs]]
self.vector_outs = [arg.ndim == 1 for arg in
new_inputs[1 + self.n_seqs: (1 + self.n_seqs +
self.n_outs)]]
self.vector_outs += [False] * self.n_nit_sot
apply_node = Apply(self,
inputs,
new_inputs,
[t() for t in self.output_types])
return apply_node
......@@ -1199,6 +1218,9 @@ class Scan(PureOp):
return scan_outs
def get_input_pos(self, output_index):
""" For a given ``output_index``, an index in the inner outputs of
scan, find a corresponding first index in the inner inputs of scan
"""
ipos = self.n_seqs
opos = output_index
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
......@@ -1219,6 +1241,9 @@ class Scan(PureOp):
return -1
def get_output_pos(self, input_index):
""" For a given ``input_index``, an index in the inner inputs of
scan, find a corresponding first index in the inner outputs of scan
"""
ipos = input_index
opos = 0
for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
......@@ -1239,6 +1264,9 @@ class Scan(PureOp):
return -1
def get_output_slice_idx(self, output_index):
""" For an ``output_index``, an index in the outter ouputs of scan,
find a corresponding index in the inner outputs of scan.
"""
ipos = 0
opos = output_index
for otaps in zip(self.mitmot_out_taps()):
......@@ -1339,6 +1367,7 @@ class Scan(PureOp):
# Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs)
for steps in xrange(n_outs):
for iidx in xrange(n_outs):
......@@ -1429,13 +1458,15 @@ class Scan(PureOp):
odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y])
if (x in diff_inputs) and
connection_pattern[get_inp_idx(self_inputs.index(x))][odx]]
(connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
cost=None,
known_grads={y: g_y},
wrt=wrt,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
......
......@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer):
Questionable, we should also consider profile ?
"""
rep = set_nodes[0]
if not rep.op.as_while and node.op.as_while:
if rep.op.as_while != node.op.as_while:
return False
nsteps = node.inputs[0]
......
......@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
nw_name = None
if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype:
return x.clone().astype(dtype)
casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name)
nwx.tag = copy(x.tag)
return nwx
else:
return x.clone()
# Note, as_tensor_variable will convert the Scalar into a
......@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
# ndarrays
pass
nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
......@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
# This means `x` has no test value.
pass
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
return nw_x
......
......@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815)
def test_grad_find_input(self):
    """Smoke test: build the gradient of a scan output with respect to a
    shared variable that the step function returns directly.

    The test has no assertion; it passes as long as constructing the
    gradient graph does not raise.
    """
    weight = theano.shared(numpy.array(0, dtype='float32'), name='w')
    initial_state = tensor.fscalar('init')

    def step(prev):
        # The previous state is ignored; every step yields the shared var.
        return weight

    states, _updates = theano.scan(
        fn=step,
        outputs_info=initial_state,
        n_steps=2,
    )
    tensor.grad(states[-1], weight)
def test_scan_merge_nodes(self):
    """Check that ``ScanMerge.belongs_to_set`` does not group a scan
    without a stopping condition together with a scan that uses
    ``until``, whichever of the two is the candidate node.
    """
    seq = tensor.vector()
    init_state = tensor.scalar()

    def multiply_step(x, y):
        return x * y

    def add_until_step(x, y):
        # Returning an ``until`` condition alongside the output.
        return (x + y, theano.scan_module.until(x > 0))

    plain_out, _ = theano.scan(multiply_step,
                               sequences=seq,
                               outputs_info=init_state,
                               n_steps=5)
    until_out, _ = theano.scan(add_until_step,
                               sequences=seq,
                               outputs_info=init_state,
                               n_steps=5)

    plain_node = plain_out.owner.inputs[0].owner
    assert isinstance(plain_node.op, theano.scan_module.scan_op.Scan)
    until_node = until_out.owner.inputs[0].owner
    assert isinstance(until_node.op, theano.scan_module.scan_op.Scan)

    merger = theano.scan_module.scan_opt.ScanMerge()
    # Exercise ``belongs_to_set`` in both directions: neither node may be
    # accepted into a set that contains the other one.
    for candidate, other in ((plain_node, until_node),
                             (until_node, plain_node)):
        assert not merger.belongs_to_set(candidate, [other])
def test_speed():
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论