提交 e5b58faf authored 作者: Frederic's avatar Frederic

By default FAST_COMPILE do lazy evaluation

上级 d9fc9d73
...@@ -323,7 +323,9 @@ class Mode(object): ...@@ -323,7 +323,9 @@ class Mode(object):
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the # FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key # string as the key
FAST_COMPILE = Mode('py', 'fast_compile') # Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(theano.gof.vm.VM_Linker(use_cloop=False, c_thunks=False),
'fast_compile')
if theano.config.cxx: if theano.config.cxx:
FAST_RUN = Mode('cvm', 'fast_run') FAST_RUN = Mode('cvm', 'fast_run')
else: else:
......
...@@ -53,6 +53,24 @@ class TestCallbacks(unittest.TestCase): ...@@ -53,6 +53,24 @@ class TestCallbacks(unittest.TestCase):
assert self.n_callbacks['IfElse'] == 2 assert self.n_callbacks['IfElse'] == 2
def test_c_thunks():
a = tensor.scalars('a')
b, c = tensor.vectors('bc')
cases = [False]
if theano.config.cxx:
cases.append(True)
for c_thunks in cases:
f = function([a, b, c], ifelse(a, a*b, b*c),
mode=Mode(
optimizer=None,
linker=vm.VM_Linker(c_thunks=c_thunks,
use_cloop=False)))
f(1, [2], [3, 2])
from nose.tools import assert_raises
assert_raises(ValueError, f, 0, [2], [3, 4])
assert any([hasattr(t, 'cthunk') for t in f.fn.thunks]) == c_thunks
def test_speed(): def test_speed():
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
......
...@@ -686,7 +686,7 @@ class VM_Linker(link.LocalLinker): ...@@ -686,7 +686,7 @@ class VM_Linker(link.LocalLinker):
""" """
def __init__(self, allow_gc=None, use_cloop=False, callback=None, def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None, schedule=None): lazy=None, schedule=None, c_thunks=None):
""" """
allow_gc - force the virtual machine to clean up unnecessary allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on references, in order to allow garbage collection on
...@@ -705,6 +705,8 @@ class VM_Linker(link.LocalLinker): ...@@ -705,6 +705,8 @@ class VM_Linker(link.LocalLinker):
version. If lazy is True or False, we force the version used version. If lazy is True or False, we force the version used
between Loop/LoopGC and Stack. between Loop/LoopGC and Stack.
c_thunks - If None or True, don't change the default. If False,
don't compile c code for the thunks.
""" """
# Note: if more parameters are added to __init__, make sure to forward # Note: if more parameters are added to __init__, make sure to forward
# them in the "type(self)(...)" call in the "accept" method below. # them in the "type(self)(...)" call in the "accept" method below.
...@@ -715,6 +717,7 @@ class VM_Linker(link.LocalLinker): ...@@ -715,6 +717,7 @@ class VM_Linker(link.LocalLinker):
self.use_cloop = use_cloop self.use_cloop = use_cloop
self.callback = callback self.callback = callback
self.lazy = lazy self.lazy = lazy
self.c_thunks = c_thunks
self.updated_vars = {} self.updated_vars = {}
if schedule: if schedule:
self.schedule = schedule self.schedule = schedule
...@@ -1010,6 +1013,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1010,6 +1013,8 @@ class VM_Linker(link.LocalLinker):
for node in order: for node in order:
try: try:
if self.c_thunks is False:
node.op._op_use_c_code = False
thunks.append(node.op.make_thunk(node, thunks.append(node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论