提交 a0478c3e authored 作者: Vincent Michalski's avatar Vincent Michalski

some refactoring

上级 9bcf61fa
...@@ -1946,39 +1946,36 @@ class test_local_subtensor_make_vector(unittest.TestCase): ...@@ -1946,39 +1946,36 @@ class test_local_subtensor_make_vector(unittest.TestCase):
r = f(0, 1, 2) r = f(0, 1, 2)
assert r[0] == 0 and r[1] == 2 assert r[0] == 0 and r[1] == 2
def test_stacktrace(self): def test_stack_trace(self):
x, y, z = tensor.lscalars('xyz') x, y, z = tensor.lscalars('xyz')
v = make_vector(x, y, z) v = make_vector(x, y, z)
# FIXME: remove the two test cases with v[0]? they are creating graphs # Compile functions in two modes:
# without apply nodes, which don't require check_stack_trace. # - only with 'local_subtensor_make_vector' (requires adding
# the 'canonicalize' phase)
# - all optimizations in fast_compile including the
# 'local_subtensor_make_vector' optimization
modes = [
theano.compile.mode.Mode(optimizer=None).including(
'canonicalize_db').including("local_subtensor_make_vector"),
theano.compile.mode.get_mode('FAST_COMPILE').including(
"local_subtensor_make_vector")
]
# Compile function using only the 'local_subtensor_make_vector' optimization, # list of subtensor cases, where local_subtensor_make_vector
# which requires us to add the 'canonicalize' phase. # inserts a new MakeVector node
mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_subtensor_make_vector") v_subtensors = [v[:2], v[::2], v[[0, 2]]]
f = function([x, y, z], v[0], mode=mode)
# Compile function using all optimizations in fast_compile mode, for mode in modes:
# including the 'local_subtensor_make_vector' optimization # case, where local_subtensor_make_vector only removes nodes
mode = theano.compile.mode.get_mode('FAST_COMPILE').including("local_subtensor_make_vector") # FIXME: remove this useless case, where the graph only contains a
# DeepCopyOp? Or is there a meaningful test case for constant
# scalar index subtensor?
f = function([x, y, z], v[0], mode=mode) f = function([x, y, z], v[0], mode=mode)
# The two cases in this test do not check the case where # cases, where local_subtensor_make_vector inserts nodes
# local_subtensor_make_vector inserts a Subtensor node (See issue #4421) for v_subtensor in v_subtensors:
# self.assertTrue(check_stack_trace(f, ops_to_check='all')) f = function([x, y, z], v_subtensor, mode=mode)
# Cases, in which local_subtensor_make_vector adds a new MakeVector
# node
# Compile function using only the 'local_subtensor_make_vector' optimization,
# which requires us to add the 'canonicalize' phase.
mode = theano.compile.mode.Mode(optimizer=None).including('canonicalize_db').including("local_subtensor_make_vector")
f = function([x, y, z], v[::2], mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='all'))
# Compile function using all optimizations in fast_compile mode,
# including the 'local_subtensor_make_vector' optimization
mode = theano.compile.mode.get_mode('FAST_COMPILE').including("local_subtensor_make_vector")
f = function([x, y, z], v[::2], mode=mode)
self.assertTrue(check_stack_trace(f, ops_to_check='all')) self.assertTrue(check_stack_trace(f, ops_to_check='all'))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论