提交 80197cf5 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #3580 from carriepl/scan_infer_shape

[ENH] Infer more things in Scan.infer_shape()
......@@ -1566,23 +1566,46 @@ class Scan(PureOp):
# Infer Shape
def infer_shape(self, node, input_shapes):
# input_shapes correspond to the shapes of node.inputs
# Here, we build a list inner_ins_shape, such that inner_ins_shape[i]
# is the shape of self.inputs[i]
for inp, inp_shp in izip(node.inputs, input_shapes):
assert inp_shp is None or len(inp_shp) == inp.type.ndim
# sequences
# We skip input_shapes[0] as it is the total or current number
# Here we build 2 variables;
# - A list `inner_ins_shapes`, such that inner_ins_shapes[i] is the
# shape of self.inputs[i]
# - A dictionary `out_equivalent` containing, for every inner input,
# an equivalent variable computed from the outer inputs.
# NOTE : For non-sequences, this equivalence is trivial. For
# sequences and recurrent states, there is no direct equivalence
# between outer and inner inputs. However, because every iteration
# of the Scan needs to give the same output shapes, we can give an
# equivalence between these inner inputs and the subelements of the
# corresponding outer inputs that the Scan would use as input for
# any given iteration. For simplicity, we use iteration 0.
inner_ins_shapes = []
out_equivalent = OrderedDict()
# We skip the first outer input as it is the total or current number
# of iterations.
# sequences
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
inner_seqs = self.inputs[:self.n_seqs]
outer_seqs = node.inputs[1:1 + self.n_seqs]
for in_s, out_s in izip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs
inner_inp_idx = self.n_seqs
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = []
for idx in xrange(n_outs):
mintap = abs(min(self.tap_array[idx]))
for k in self.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
inner_inp_idx += 1
outer_inp_idx += 1
# shared_outs
offset = 1 + self.n_seqs + n_outs
......@@ -1597,9 +1620,9 @@ class Scan(PureOp):
# Non-sequences have a direct equivalent from self.inputs in
# node.inputs
inner_non_sequences = self.inputs[len(seqs_shape) + len(outs_shape):]
out_equivalent = OrderedDict()
for in_ns, out_ns in izip(inner_non_sequences, node.inputs[offset:]):
out_equivalent[in_ns] = out_ns
if self.as_while:
self_outs = self.outputs[:-1]
else:
......
......@@ -857,16 +857,19 @@ class Validator(object):
return None
if out.owner is None:
# This is an unknown input node, so it is invalid.
self.invalid.add(out)
if isinstance(out, tensor.TensorConstant):
# We can clone it to get a valid constant
# This might be a constant from the outer graph or a constant
# from the inner graph. In all cases, we can clone it to be
# certain we have a valid constant
cloned_out = out.clone()
self.valid.add(cloned_out)
self.invalid.add(out)
self.valid_equivalent[out] = cloned_out
return cloned_out, False
return None
else:
# This is an input node and it has not been explicitly marked
# as invalid so we can use it
return out, True
# Recurse over inputs
inputs = [self.check(i) for i in out.owner.inputs]
......
......@@ -2549,6 +2549,44 @@ class T_Scan(unittest.TestCase):
output, g_output = fct(i)
assert len(output) == g_output
def test_infer_shape2(self):
    """Shape inference must be able to drop the Scan node even when the
    inner graph mixes a sequence with sitsot and mitsot recurrent states."""
    seq = tensor.lvector()
    sitsot_init = tensor.lscalar()
    mitsot_init = tensor.lvector()

    def inner_step(seq_elem, sitsot_tm1, mitsot_tm2, mitsot_tm1):
        # The sitsot state shrinks by exactly the amount the mitsot state
        # grows, so their sum is invariant across iterations.  That sum is
        # used as the length of the nitsot output, which therefore keeps
        # the same shape at every step -- a requirement for Scan outputs.
        delta = mitsot_tm1 + seq_elem
        new_mitsot = mitsot_tm2 + delta
        new_sitsot = sitsot_tm1 - delta
        empty_out = tensor.AllocEmpty('float32')(new_mitsot + new_sitsot)
        return new_sitsot, new_mitsot, empty_out

    outputs, updates = theano.scan(
        fn=inner_step,
        sequences=seq,
        outputs_info=[sitsot_init,
                      {'initial': mitsot_init, 'taps': [-2, -1]},
                      None],
        n_steps=5)

    fct = theano.function([seq, sitsot_init, mitsot_init],
                          outputs[2].shape, mode='FAST_RUN')

    # The shape-only graph must compile without any Scan node left in it.
    assert len(scan_nodes_from_fct(fct)) == 0

    shape_val = fct(numpy.arange(5), 5, [1, 2])
    assert all(shape_val == (5, 6))
# The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape
def test_remove_stuff(self):
......@@ -3946,7 +3984,11 @@ class T_Scan(unittest.TestCase):
assert numpy.all(exp_out == f(inp))
def test_borrow_bug_jeremiah(self):
# This test fails if scan uses wrongly the borrow flag
# This tests two things. The first is a bug occurring when scan wrongly
# used the borrow flag. The second thing is that Scan's infer_shape()
# method will be able to remove the Scan node from the graph in this
# case.
inp = numpy.arange(10).reshape(-1, 1).astype(theano.config.floatX)
exp_out = numpy.zeros((10, 1)).astype(theano.config.floatX)
exp_out[4:] = inp[:-4]
......@@ -3967,8 +4009,17 @@ class T_Scan(unittest.TestCase):
updates = OrderedDict([(sharedvar, results[0][-1:])])
f = theano.function([seq], results[1], updates=updates)
# This fails if scan uses wrongly the borrow flag
assert numpy.all(exp_out == f(inp))
# This fails if Scan's infer_shape() is unable to remove the Scan
# node from the graph.
f_infershape = theano.function([seq], results[1].shape,
mode='FAST_RUN')
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
assert(len(scan_nodes_infershape) == 0)
def test_memory_reuse_with_outputs_as_inputs(self):
# Test the memory pre-allocation feature in scan for the following
# cases :
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论