提交 4c303896 authored 作者: Frederic's avatar Frederic

add test for an optimization about scan sequences

上级 c1ebebce
...@@ -3564,11 +3564,10 @@ class T_Scan(unittest.TestCase): ...@@ -3564,11 +3564,10 @@ class T_Scan(unittest.TestCase):
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2]) assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1]) assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
def test_remove_constants_and_unused_inputs_scan(self): def test_remove_constants_and_unused_inputs_scan_non_seqs(self):
""" """Test the opt remove_constants_and_unused_inputs_scan for
Test the opt remove_constants_and_unused_inputs_scan non sequences.
TODO: currently we only test non_seqs, should test
""" """
W = theano.tensor.matrix(name='W') W = theano.tensor.matrix(name='W')
v = theano.tensor.ivector(name='v') v = theano.tensor.ivector(name='v')
...@@ -3594,17 +3593,61 @@ class T_Scan(unittest.TestCase): ...@@ -3594,17 +3593,61 @@ class T_Scan(unittest.TestCase):
f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2]) f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2])
scan_node = f.maker.fgraph.toposort()[-1] scan_node = f.maker.fgraph.toposort()[-1]
# TODO: Why this assert always fail? # The first input is the number of iteration.
# assert (len(scan_node.inputs) == assert (len(scan_node.inputs[1:]) ==
# len(set(scan_node.inputs))) len(set(scan_node.inputs[1:])))
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs) inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1 assert len(inp) == 1
assert (len(inp) == len(set(inp))) assert (len(inp) == len(set(inp)))
inp = scan_node.op.outer_non_seqs(scan_node) inp = scan_node.op.outer_non_seqs(scan_node)
assert len(inp) == 1 assert len(inp) == 1
assert (len(inp) == len(set(inp))) assert (len(inp) == len(set(inp)))
#import pdb;pdb.set_trace()
#utt.assert_allclose(f([1, 2]), [[0, 0, 0], [1, 1, 1], [1, 1, 1]]) def test_remove_constants_and_unused_inputs_scan_seqs(self):
"""
Test the opt remove_constants_and_unused_inputs_scan for sequences.
"""
W = theano.tensor.matrix(name='W')
v = theano.tensor.ivector(name='v')
vv = theano.tensor.matrix(name='vv')
y1, _ = theano.scan(lambda i, W: W[i], sequences=v,
outputs_info=None, non_sequences=[W])
y2, _ = theano.scan(lambda i, _, W: W[i], sequences=[v, v],
outputs_info=None, non_sequences=W)
y3, _ = theano.scan(lambda i, _, W: W[i], sequences=[v, vv[0]],
outputs_info=None, non_sequences=W)
y4, _ = theano.scan(lambda _, i, W: W[i], sequences=[vv[0], v],
outputs_info=None, non_sequences=W)
y5, _ = theano.scan(lambda _, i, _2, W: W[i], sequences=[vv, v, vv[0]],
outputs_info=None, non_sequences=W)
y6, _ = theano.scan(lambda _, _2, i, W: W[i], sequences=[vv[0], vv, v],
outputs_info=None, non_sequences=W)
y7, _ = theano.scan(lambda i, _, _2, W: W[i],
sequences=[v, vv[0], vv[0]],
outputs_info=None, non_sequences=W)
y8, _ = theano.scan(lambda _, i, W, _2, _3: W[i], sequences=[vv[0], v],
outputs_info=None, non_sequences=[W, W[0], W[0]])
for out in [y1, y2, y3, y4, y5, y6, y7, y8]:
#This used to raise an exception
f = theano.function([W, v, vv], out, on_unused_input='ignore',
mode=mode_with_opt)
f(numpy.zeros((3, 3), theano.config.floatX),
[1, 2],
numpy.zeros((3, 3), theano.config.floatX))
scan_node = f.maker.fgraph.toposort()[-1]
# The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) ==
len(set(scan_node.inputs[1:])))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node)
assert len(inp) == 1
def test_speed(): def test_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论