Commit e0f91f59, authored by Pascal Lamblin

Merge pull request #3038 from nouiz/default_mode

[MRG] FAST_COMPILE now lazy by default
...@@ -55,7 +55,7 @@ AddConfigVar('DebugMode.check_finite', ...@@ -55,7 +55,7 @@ AddConfigVar('DebugMode.check_finite',
AddConfigVar('DebugMode.check_strides', AddConfigVar('DebugMode.check_strides',
("Check that Python- and C-produced ndarrays have same strides. " ("Check that Python- and C-produced ndarrays have same strides. "
"On difference: (0) - ignore, (1) warn, or (2) raise error"), "On difference: (0) - ignore, (1) warn, or (2) raise error"),
IntParam(1, lambda i: i in (0, 1, 2)), IntParam(0, lambda i: i in (0, 1, 2)),
in_c_key=False) in_c_key=False)
AddConfigVar('DebugMode.warn_input_not_reused', AddConfigVar('DebugMode.warn_input_not_reused',
......
...@@ -324,7 +324,9 @@ class Mode(object): ...@@ -324,7 +324,9 @@ class Mode(object):
# If a string is passed as the mode argument in function or # If a string is passed as the mode argument in function or
# FunctionMaker, the Mode will be taken from this dictionary using the # FunctionMaker, the Mode will be taken from this dictionary using the
# string as the key # string as the key
FAST_COMPILE = Mode('py', 'fast_compile') # Use VM_linker to allow lazy evaluation by default.
FAST_COMPILE = Mode(theano.gof.vm.VM_Linker(use_cloop=False, c_thunks=False),
'fast_compile')
if theano.config.cxx: if theano.config.cxx:
FAST_RUN = Mode('cvm', 'fast_run') FAST_RUN = Mode('cvm', 'fast_run')
else: else:
......
...@@ -55,6 +55,24 @@ class TestCallbacks(unittest.TestCase): ...@@ -55,6 +55,24 @@ class TestCallbacks(unittest.TestCase):
assert self.n_callbacks['IfElse'] == 2 assert self.n_callbacks['IfElse'] == 2
def test_c_thunks():
    # Exercise VM_Linker(c_thunks=...) on an ifelse graph: for each setting,
    # the compiled function must report C thunks exactly when c_thunks is
    # True.
    from nose.tools import assert_raises

    a = tensor.scalars('a')
    b, c = tensor.vectors('bc')

    # c_thunks=True is only tried when a C++ compiler is configured.
    settings = [False]
    if theano.config.cxx:
        settings.append(True)

    for c_thunks in settings:
        linker = vm.VM_Linker(c_thunks=c_thunks,
                              use_cloop=False)
        f = function([a, b, c], ifelse(a, a*b, b*c),
                     mode=Mode(optimizer=None, linker=linker))

        # With a true condition, this call is expected to succeed even
        # though b and c have different lengths (presumably because the
        # b*c branch is not evaluated -- confirm against ifelse semantics).
        f(1, [2], [3, 2])
        # With a false condition, the mismatched b*c product must raise.
        assert_raises(ValueError, f, 0, [2], [3, 4])

        has_cthunk = any(hasattr(t, 'cthunk') for t in f.fn.thunks)
        assert has_cthunk == c_thunks
def test_speed(): def test_speed():
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
......
...@@ -106,7 +106,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, ...@@ -106,7 +106,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_out: if reuse_out is not None:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and if (getattr(out, 'ndim', None) == 0 and
...@@ -115,6 +115,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, ...@@ -115,6 +115,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
break
elif ins in view_of: elif ins in view_of:
origin = view_of[ins] origin = view_of[ins]
if ins in viewed_by[origin]: if ins in viewed_by[origin]:
...@@ -124,7 +125,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, ...@@ -124,7 +125,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
not isinstance(origin, theano.Constant)): not isinstance(origin, theano.Constant)):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_out: if reuse_out is not None:
break break
for out in order[i].outputs: for out in order[i].outputs:
if (getattr(out, 'ndim', None) == 0 and if (getattr(out, 'ndim', None) == 0 and
...@@ -133,8 +134,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, ...@@ -133,8 +134,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
reuse_out = out reuse_out = out
pre_allocated.add(out) pre_allocated.add(out)
allocated.add(ins) allocated.add(ins)
break
if reuse_out: if reuse_out is not None:
reallocated_info[ins] = [ins, reuse_out] reallocated_info[ins] = [ins, reuse_out]
return reallocated_info return reallocated_info
...@@ -688,7 +689,7 @@ class VM_Linker(link.LocalLinker): ...@@ -688,7 +689,7 @@ class VM_Linker(link.LocalLinker):
""" """
def __init__(self, allow_gc=None, use_cloop=False, callback=None, def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None, schedule=None): lazy=None, schedule=None, c_thunks=None):
""" """
allow_gc - force the virtual machine to clean up unnecessary allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on references, in order to allow garbage collection on
...@@ -707,6 +708,8 @@ class VM_Linker(link.LocalLinker): ...@@ -707,6 +708,8 @@ class VM_Linker(link.LocalLinker):
version. If lazy is True or False, we force the version used version. If lazy is True or False, we force the version used
between Loop/LoopGC and Stack. between Loop/LoopGC and Stack.
c_thunks - If None or True, don't change the default. If False,
don't compile c code for the thunks.
""" """
# Note: if more parameters are added to __init__, make sure to forward # Note: if more parameters are added to __init__, make sure to forward
# them in the "type(self)(...)" call in the "accept" method below. # them in the "type(self)(...)" call in the "accept" method below.
...@@ -717,6 +720,7 @@ class VM_Linker(link.LocalLinker): ...@@ -717,6 +720,7 @@ class VM_Linker(link.LocalLinker):
self.use_cloop = use_cloop self.use_cloop = use_cloop
self.callback = callback self.callback = callback
self.lazy = lazy self.lazy = lazy
self.c_thunks = c_thunks
self.updated_vars = {} self.updated_vars = {}
if schedule: if schedule:
self.schedule = schedule self.schedule = schedule
...@@ -755,7 +759,8 @@ class VM_Linker(link.LocalLinker): ...@@ -755,7 +759,8 @@ class VM_Linker(link.LocalLinker):
use_cloop=self.use_cloop, use_cloop=self.use_cloop,
callback=self.callback, callback=self.callback,
lazy=self.lazy, lazy=self.lazy,
schedule=self.schedule schedule=self.schedule,
c_thunks=self.c_thunks,
).accept(fgraph, no_recycling) ).accept(fgraph, no_recycling)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
...@@ -1012,6 +1017,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1012,6 +1017,8 @@ class VM_Linker(link.LocalLinker):
for node in order: for node in order:
try: try:
if self.c_thunks is False:
node.op._op_use_c_code = False
thunks.append(node.op.make_thunk(node, thunks.append(node.op.make_thunk(node,
storage_map, storage_map,
compute_map, compute_map,
...@@ -1071,3 +1078,8 @@ class VM_Linker(link.LocalLinker): ...@@ -1071,3 +1078,8 @@ class VM_Linker(link.LocalLinker):
for output, storage in zip(fgraph.outputs, output_storage)], for output, storage in zip(fgraph.outputs, output_storage)],
thunks, thunks,
order) order)
def __setstate__(self, d):
    # Restore the instance attributes from the state dict ``d``.
    vars(self).update(d)
    # If the restored state carries no ``c_thunks`` attribute, fall back to
    # True (presumably for state saved before ``c_thunks`` was introduced --
    # confirm against older pickles).
    try:
        self.c_thunks
    except AttributeError:
        self.c_thunks = True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论