提交 d1e5c5ec authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #1605 from lamblin/scan_minor

Various small fixes in scan
......@@ -196,9 +196,13 @@ def scan(fn,
inner_slices = [] # Actual slices if scan is removed from the picture
# go through sequences picking up time slices as needed
for i, seq in enumerate(seqs):
if isinstance(seq, dict):
seq = seq['input']
actual_slice = seq[0]
_seq_val = tensor.as_tensor_variable(seq)
_seq_val_slice = _seq_val[0]
nw_slice = _seq_val_slice.type()
# Try to transfer test_value to the new variable
if config.compute_test_value != 'off':
try:
......@@ -212,7 +216,6 @@ def scan(fn,
'the inner function of scan, input value '
'missing %s'), e)
nw_slice = _seq_val_slice.type()
if seq.name:
nw_slice.name = seq.name + '[t]'
scan_seqs.append(_seq_val)
......
......@@ -384,7 +384,7 @@ def scan(fn,
seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
elif seqs[i].get('taps', None):
seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
elif seqs[i].get('taps', True) is None:
elif seqs[i].get('taps', None) is None:
# seqs dictionary does not have the ``taps`` key
seqs[i]['taps'] = [0]
......
......@@ -1451,51 +1451,15 @@ def scan_merge_inouts(node):
seen.append((outer_i, inner_o, outer_o))
return outer_o
def map_nitsot_out(outer_i, inner_o, outer_o, sh, seen):
# Like map_out, but also checks the needed shape.
for p, (s_outer_i, s_inner_o, s_outer_o, ssh) in enumerate(seen):
if (equal_computations([inner_o], [s_inner_o], left, right)
and outer_i == s_outer_i):
if equal_computations([sh], [ssh]):
return s_outer_o
try:
vsh = int(opt.get_scalar_constant_value(sh))
vssh = int(opt.get_scalar_constant_value(ssh))
except tensor.NotScalarConstantError:
return outer_o
if vsh == vssh:
return s_outer_o
elif vsh > vssh:
seen[p] = (outer_i, inner_o, outer_o, sh)
return outer_o
else:
return s_outer_o[:vsh]
seen.append((outer_i, inner_o, outer_o, sh))
return outer_o
seen = []
shapes = []
for x in na.outer_in_nit_sot:
if x.ndim > 0:
if hasattr(node.fgraph, 'shape_feature'):
shapes.append(
node.fgraph.shape_feature.shape_of[x][0])
else:
shapes.append(x.shape[0])
else:
# If x is a scalar, then it means its value is the number of
# items scan is supposed to store for this nit_sot sequence
shapes.append(x)
assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot)
assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot)
assert len(na.outer_out_nit_sot) == len(shapes)
na.outer_out_nit_sot = [
map_nitsot_out(outer_i, inner_o, outer_o, sh, seen)
for outer_i, inner_o, outer_o, sh in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot,
shapes)]
map_out(outer_i, inner_o, outer_o, seen)
for outer_i, inner_o, outer_o in zip(na.outer_in_nit_sot,
na.inner_out_nit_sot,
na.outer_out_nit_sot)]
seen = []
assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot)
......
......@@ -2477,6 +2477,18 @@ class T_Scan(unittest.TestCase):
# Run it so DebugMode can detect optimization problems.
f(x_val, y_val)
def test_sequence_dict(self):
# Test that we can specify sequences as a dictionary with
# only the 'input' key
def incr(s):
return s + 1
x = theano.tensor.vector()
sx, upx = theano.scan(
fn=incr,
sequences=[{'input': x}])
f = theano.function([x], sx)
def test_hash(self):
x = theano.tensor.vector()
y = theano.tensor.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论