Commit b8165faa authored by nouiz

Merge pull request #621 from pascanur/fix_scan_savemem_matthias

Fix for the bug involving multiple subtensors on a scan output. The consequence was a warning printed to the user, and scan using more memory than necessary.
...@@ -609,7 +609,7 @@ def scan(fn, ...@@ -609,7 +609,7 @@ def scan(fn,
info['gpu'] = False info['gpu'] = False
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
info['_scan_merge_visited'] = True info['_scan_savemem_visited'] = True
local_op = scan_op.Scan(inner_inputs, new_outs, info) local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
...@@ -526,7 +526,6 @@ class ScanSaveMem(gof.Optimizer): ...@@ -526,7 +526,6 @@ class ScanSaveMem(gof.Optimizer):
if (isinstance(this_slice[0], slice) and if (isinstance(this_slice[0], slice) and
this_slice[0].stop is None): this_slice[0].stop is None):
global_nsteps = None global_nsteps = None
break
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = tensor.basic.extract_constant(cf_slice[0].stop) stop = tensor.basic.extract_constant(cf_slice[0].stop)
else: else:
...@@ -741,7 +740,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -741,7 +740,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.6 Compose the new scan # 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization # I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that # twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True info['_scan_savemem_visited'] = True
new_outs = scan_op.Scan(inps, new_outs = scan_op.Scan(inps,
outs, outs,
info).make_node(*node_ins).outputs info).make_node(*node_ins).outputs
...@@ -834,7 +833,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -834,7 +833,7 @@ class ScanSaveMem(gof.Optimizer):
nodelist = [x for x in env.toposort() if isinstance(x.op, nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)] scan_op.Scan)]
for node in nodelist: for node in nodelist:
if not hasattr(node.op, '_scan_merge_visited'): if not hasattr(node.op, '_scan_savemem_visited'):
self.process_node(env, node) self.process_node(env, node)
# Just before specialize to have the other optimization # Just before specialize to have the other optimization
......
...@@ -285,6 +285,39 @@ class T_Scan(unittest.TestCase): ...@@ -285,6 +285,39 @@ class T_Scan(unittest.TestCase):
theano_values = my_f(state, steps) theano_values = my_f(state, steps)
assert numpy.allclose(numpy_values, theano_values) assert numpy.allclose(numpy_values, theano_values)
def test_subtensor_multiple_slices(self):
    """Regression test for a bug reported by Matthias Zoehrer.

    Taking several subtensors of the same scan output — one of them
    through a reshape to a 3-d tensor — used to make the save-mem
    optimization bail out: a warning was printed and scan kept more
    memory than needed.
    """
    def double(prev):
        # Inner function: each step doubles the previous state.
        return 2 * prev

    init_state = theano.tensor.vector('state')
    n_steps = theano.tensor.iscalar('nsteps')
    out, updates = theano.scan(double,
                               [],
                               init_state,
                               [],
                               n_steps=n_steps,
                               truncate_gradient=-1,
                               go_backwards=False)
    new_shape = tensor.ivector('nw_shape')
    # Two overlapping subtensors on the scan output; the reshape to a
    # 3-d tensor is required to reproduce the original bug.
    fn = theano.function([init_state, n_steps, new_shape],
                         [tensor.reshape(out, new_shape, ndim=3)[:-2],
                          out[:-4]],
                         updates=updates,
                         allow_input_downcast=True)
    scan_nodes = [n for n in fn.maker.env.toposort()
                  if isinstance(n.op, theano.scan_module.scan_op.Scan)]
    # The flag is only set once the save-mem optimization has processed
    # the scan node; this assertion fails if the optimization bailed out.
    assert scan_nodes[0].op._scan_savemem_visited
    rng = numpy.random.RandomState(utt.fetch_seed())
    fn(rng.uniform(size=(3,)),
       4,
       numpy.int64([2, 2, 3]))
# simple rnn, one input, one state, weights for each; input/state # simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars # are vectors, weights are scalars
def test_one_sequence_one_output_weights(self): def test_one_sequence_one_output_weights(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论