提交 9bcf61fa authored 作者: Vincent Michalski's avatar Vincent Michalski

added test case to test_local_subtensor_make_vector, where the opt inserts a node

上级 2a9da88c
...@@ -1950,6 +1950,9 @@ class test_local_subtensor_make_vector(unittest.TestCase): ...@@ -1950,6 +1950,9 @@ class test_local_subtensor_make_vector(unittest.TestCase):
x, y, z = tensor.lscalars('xyz') x, y, z = tensor.lscalars('xyz')
v = make_vector(x, y, z) v = make_vector(x, y, z)
# FIXME: remove the two test cases with v[0]? they are creating graphs
# without apply nodes, which don't require check_stack_trace.
# Compile function using only the 'local_subtensor_make_vector' optimization, # Compile function using only the 'local_subtensor_make_vector' optimization,
# which requires us to add the 'canonicalize' phase. # which requires us to add the 'canonicalize' phase.
mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_subtensor_make_vector") mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_subtensor_make_vector")
...@@ -1964,6 +1967,20 @@ class test_local_subtensor_make_vector(unittest.TestCase): ...@@ -1964,6 +1967,20 @@ class test_local_subtensor_make_vector(unittest.TestCase):
# local_subtensor_make_vector inserts a Subtensor node (See issue #4421) # local_subtensor_make_vector inserts a Subtensor node (See issue #4421)
# self.assertTrue(check_stack_trace(f, ops_to_check='all')) # self.assertTrue(check_stack_trace(f, ops_to_check='all'))
# Cases, in which local_subtensor_make_vector adds a new MakeVector
# node
# Compile function using only the 'local_subtensor_make_vector' optimization,
# which requires us to add the 'canonicalize' phase.
mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_subtensor_make_vector")
f = function([x, y, z], v[::2], mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='all'))
# Compile function using all optimizations in fast_compile mode,
# including the 'local_subtensor_make_vector' optimization
mode = theano.compile.mode.get_mode('FAST_COMPILE').including("local_subtensor_make_vector")
f = function([x, y, z], v[::2], mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='all'))
class test_local_subtensor_lift(unittest.TestCase): class test_local_subtensor_lift(unittest.TestCase):
def test0(self): def test0(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论