提交 e0821760 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6000 from ReyhaneAskari/new_destroy_handler

New destroy handler
...@@ -43,6 +43,8 @@ CPUs. In fact, Theano asks g++ what are the equivalent flags it uses, and re-use ...@@ -43,6 +43,8 @@ CPUs. In fact, Theano asks g++ what are the equivalent flags it uses, and re-use
them directly. them directly.
.. _faster-theano-function-compilation:
Faster Theano Function Compilation Faster Theano Function Compilation
---------------------------------- ----------------------------------
...@@ -67,6 +69,13 @@ compilation but it will also use more memory because ...@@ -67,6 +69,13 @@ compilation but it will also use more memory because
resulting in a trade off between speed of compilation and memory resulting in a trade off between speed of compilation and memory
usage. usage.
Alternatively, if the graph is big, using the flag ``cycle_detection=fast``
will speed up the computations by removing some of the inplace
optimizations. This allows Theano to skip a time-consuming cycle
detection algorithm. If the graph is big enough, we suggest that you use
this flag instead of ``optimizer_excluding=inplace``. It will result in a
computation time that is in between fast compile and fast run.
Theano flag `reoptimize_unpickled_function` controls if an unpickled Theano flag `reoptimize_unpickled_function` controls if an unpickled
theano function should reoptimize its graph or not. Theano users can theano function should reoptimize its graph or not. Theano users can
use the standard python pickle tools to save a compiled theano use the standard python pickle tools to save a compiled theano
......
...@@ -225,7 +225,8 @@ stabilize "+++++" "++" Only applies stability opts ...@@ -225,7 +225,8 @@ stabilize "+++++" "++" Only applies stability opts
================= ============ ============== ================================================== ================= ============ ============== ==================================================
For a detailed list of the specific optimizations applied for each of these For a detailed list of the specific optimizations applied for each of these
optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization`. optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization` and
:ref:`faster-theano-function-compilation` for other trade-off.
.. _using_debugmode: .. _using_debugmode:
......
...@@ -2273,25 +2273,26 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2273,25 +2273,26 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
"of", len(li), "events was stable.", "of", len(li), "events was stable.",
file=sys.stderr) file=sys.stderr)
self.fgraph = fgraph self.fgraph = fgraph
destroy_handler_added = False if theano.config.cycle_detection == 'regular':
for feature in fgraph._features: destroy_handler_added = False
if isinstance(feature, gof.DestroyHandler): for feature in fgraph._features:
destroy_handler_added = True if isinstance(feature, gof.DestroyHandler):
break destroy_handler_added = True
if not destroy_handler_added: break
fgraph.attach_feature(gof.DestroyHandler()) if not destroy_handler_added:
for o in fgraph.outputs: fgraph.attach_feature(gof.DestroyHandler())
try: for o in fgraph.outputs:
with change_flags(compute_test_value=config.compute_test_value_opt): try:
fgraph.replace_validate(o, _output_guard(o), reason='output_guard') with change_flags(compute_test_value=config.compute_test_value_opt):
raise Exception("Output variable %s required output_guard, " fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
"how was this output left unprotected against " raise Exception("Output variable %s required output_guard, "
"destructive operations?" % o) "how was this output left unprotected against "
"destructive operations?" % o)
except gof.InconsistencyError:
# This output is already impossible to destroy. except gof.InconsistencyError:
# No guard necessary # This output is already impossible to destroy.
pass # No guard necessary
pass
linker = _Linker(self) linker = _Linker(self)
......
...@@ -132,6 +132,11 @@ class Supervisor: ...@@ -132,6 +132,11 @@ class Supervisor:
self.protected = list(protected) self.protected = list(protected)
def validate(self, fgraph): def validate(self, fgraph):
if config.cycle_detection == 'fast' and hasattr(fgraph, 'has_destroyers'):
if fgraph.has_destroyers(self.protected):
raise gof.InconsistencyError("Trying to destroy a protected"
"Variable.")
return True
if not hasattr(fgraph, 'destroyers'): if not hasattr(fgraph, 'destroyers'):
return True return True
for r in self.protected + list(fgraph.outputs): for r in self.protected + list(fgraph.outputs):
...@@ -190,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -190,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
for spec, input in zip(input_specs, fgraph.inputs) for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input))))) fgraph.has_destroyers([input])))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
for feature in std_fgraph.features: for feature in std_fgraph.features:
...@@ -1111,7 +1116,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1111,7 +1116,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of') has_destroyers_attr = hasattr(fgraph, 'has_destroyers')
for i in xrange(len(fgraph.outputs)): for i in xrange(len(fgraph.outputs)):
views_of_output_i = set() views_of_output_i = set()
...@@ -1142,7 +1147,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1142,7 +1147,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# being updated # being updated
if input_j in updated_fgraph_inputs: if input_j in updated_fgraph_inputs:
continue continue
if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)): if input_j in views_of_output_i and not (has_destroyers_attr and fgraph.has_destroyers([input_j])):
# We don't put deep_copy_op if the input and the # We don't put deep_copy_op if the input and the
# output have borrow==True # output have borrow==True
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
......
...@@ -1575,7 +1575,7 @@ AddConfigVar('cycle_detection', ...@@ -1575,7 +1575,7 @@ AddConfigVar('cycle_detection',
"The interaction of which one give the lower peak memory usage is" "The interaction of which one give the lower peak memory usage is"
"complicated and not predictable, so if you are close to the peak" "complicated and not predictable, so if you are close to the peak"
"memory usage, triyng both could give you a small gain. ", "memory usage, triyng both could give you a small gain.",
EnumStr('regular', 'fast'), EnumStr('regular', 'fast'),
in_c_key=False) in_c_key=False)
......
...@@ -250,7 +250,7 @@ def fast_inplace_check(inputs): ...@@ -250,7 +250,7 @@ def fast_inplace_check(inputs):
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i, graph.Constant) and not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and not fgraph.has_destroyers([i]) and
i not in protected_inputs] i not in protected_inputs]
return inputs return inputs
...@@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -297,7 +297,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
<unknown> <unknown>
""" """
pickle_rm_attr = ["destroyers"] pickle_rm_attr = ["destroyers", "has_destroyers"]
def __init__(self, do_imports_on_attach=True, algo=None): def __init__(self, do_imports_on_attach=True, algo=None):
self.fgraph = None self.fgraph = None
...@@ -394,6 +394,41 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -394,6 +394,41 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
return [] return []
fgraph.destroyers = get_destroyers_of fgraph.destroyers = get_destroyers_of
def has_destroyers(protected_list):
    # Return True if any variable in `protected_list` is destroyed
    # (directly, or through a chain of views) by some Apply node in
    # the graph; False otherwise.
    if self.algo != 'fast':
        # Regular algorithm: consult the destroy handler's
        # droot / root_destroyer maps (refreshed lazily).
        droot, _, root_destroyer = self.refresh_droot_impact()
        for protected_var in protected_list:
            try:
                # KeyError here means no destroyer is registered
                # for this variable's root.
                root_destroyer[droot[protected_var]]
                return True
            except KeyError:
                pass
        return False

    def recursive_destroys_finder(protected_var):
        # protected_var is the idx'th input of app.
        for (app, idx) in protected_var.clients:
            if app == 'output':
                continue
            destroy_maps = getattr(app.op, 'destroy_map', {}).values()
            # If True, the apply node destroys protected_var.
            if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
                return True
            for var_idx in getattr(app.op, 'view_map', {}).keys():
                if idx in app.op.view_map[var_idx]:
                    # We need to recursively check the destroy_map of all
                    # the outputs that we have a view_map on.
                    if recursive_destroys_finder(app.outputs[var_idx]):
                        return True
        return False

    # 'fast' algorithm: walk client edges following view relations.
    for protected_var in protected_list:
        if recursive_destroys_finder(protected_var):
            return True
    return False
fgraph.has_destroyers = has_destroyers
def refresh_droot_impact(self): def refresh_droot_impact(self):
""" """
Makes sure self.droot, self.impact, and self.root_destroyer are up to Makes sure self.droot, self.impact, and self.root_destroyer are up to
...@@ -416,6 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -416,6 +451,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
del self.stale_droot del self.stale_droot
assert self.fgraph.destroyer_handler is self assert self.fgraph.destroyer_handler is self
delattr(self.fgraph, 'destroyers') delattr(self.fgraph, 'destroyers')
delattr(self.fgraph, 'has_destroyers')
delattr(self.fgraph, 'destroy_handler') delattr(self.fgraph, 'destroy_handler')
self.fgraph = None self.fgraph = None
...@@ -452,11 +488,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -452,11 +488,11 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
if len(v) > 0: if len(v) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has view_map. " + str(reason)) "Destroyed variable has view_map. " + str(reason))
elif d: elif d:
d = d.get(inp_idx2, []) d = d.get(inp_idx2, [])
if len(d) > 0: if len(d) > 0:
self.fail_validate[app] = theano.gof.InconsistencyError( self.fail_validate[app] = theano.gof.InconsistencyError(
"Destroyed variable has destroy_map. " + str(reason)) "Destroyed variable has destroy_map. " + str(reason))
# These 2 assertions are commented since this function is called so many times # These 2 assertions are commented since this function is called so many times
# but they should be true. # but they should be true.
...@@ -474,13 +510,15 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -474,13 +510,15 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', None): dmap = getattr(app.op, 'destroy_map', None)
vmap = getattr(app.op, 'view_map', {})
if dmap:
self.destroyers.add(app) self.destroyers.add(app)
if self.algo == 'fast': if self.algo == 'fast':
self.fast_destroy(app, reason) self.fast_destroy(app, reason)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in iteritems(getattr(app.op, 'view_map', {})): for o_idx, i_idx_list in iteritems(vmap):
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError( raise NotImplementedError(
'destroying this output invalidates multiple inputs', 'destroying this output invalidates multiple inputs',
......
...@@ -11,6 +11,7 @@ from theano.gof.opt import (OpKeyOptimizer, PatternSub, NavigatorOptimizer, ...@@ -11,6 +11,7 @@ from theano.gof.opt import (OpKeyOptimizer, PatternSub, NavigatorOptimizer,
from theano.gof import destroyhandler from theano.gof import destroyhandler
from theano.gof.fg import FunctionGraph, InconsistencyError from theano.gof.fg import FunctionGraph, InconsistencyError
from theano.gof.toolbox import ReplaceValidate from theano.gof.toolbox import ReplaceValidate
from theano.tests.unittest_tools import assertFailure_fast
from theano.configparser import change_flags from theano.configparser import change_flags
...@@ -169,6 +170,7 @@ def test_misc(): ...@@ -169,6 +170,7 @@ def test_misc():
###################### ######################
@assertFailure_fast
def test_aliased_inputs_replacement(): def test_aliased_inputs_replacement():
x, y, z = inputs() x, y, z = inputs()
tv = transpose_view(x) tv = transpose_view(x)
...@@ -200,6 +202,7 @@ def test_indestructible(): ...@@ -200,6 +202,7 @@ def test_indestructible():
consistent(g) consistent(g)
@assertFailure_fast
def test_usage_loop_through_views_2(): def test_usage_loop_through_views_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(sigmoid(x))) e0 = transpose_view(transpose_view(sigmoid(x)))
...@@ -210,6 +213,7 @@ def test_usage_loop_through_views_2(): ...@@ -210,6 +213,7 @@ def test_usage_loop_through_views_2():
inconsistent(g) # we cut off the path to the sigmoid inconsistent(g) # we cut off the path to the sigmoid
@assertFailure_fast
def test_destroyers_loop(): def test_destroyers_loop():
# AddInPlace(x, y) and AddInPlace(y, x) should not coexist # AddInPlace(x, y) and AddInPlace(y, x) should not coexist
x, y, z = inputs() x, y, z = inputs()
...@@ -259,6 +263,7 @@ def test_aliased_inputs2(): ...@@ -259,6 +263,7 @@ def test_aliased_inputs2():
inconsistent(g) inconsistent(g)
@assertFailure_fast
def test_aliased_inputs_tolerate(): def test_aliased_inputs_tolerate():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_2(x, x) e = add_in_place_2(x, x)
...@@ -273,6 +278,7 @@ def test_aliased_inputs_tolerate2(): ...@@ -273,6 +278,7 @@ def test_aliased_inputs_tolerate2():
inconsistent(g) inconsistent(g)
@assertFailure_fast
def test_same_aliased_inputs_ignored(): def test_same_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, x) e = add_in_place_3(x, x)
...@@ -280,6 +286,7 @@ def test_same_aliased_inputs_ignored(): ...@@ -280,6 +286,7 @@ def test_same_aliased_inputs_ignored():
consistent(g) consistent(g)
@assertFailure_fast
def test_different_aliased_inputs_ignored(): def test_different_aliased_inputs_ignored():
x, y, z = inputs() x, y, z = inputs()
e = add_in_place_3(x, transpose_view(x)) e = add_in_place_3(x, transpose_view(x))
...@@ -314,6 +321,7 @@ def test_indirect(): ...@@ -314,6 +321,7 @@ def test_indirect():
inconsistent(g) inconsistent(g)
@assertFailure_fast
def test_indirect_2(): def test_indirect_2():
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(x) e0 = transpose_view(x)
...@@ -325,6 +333,7 @@ def test_indirect_2(): ...@@ -325,6 +333,7 @@ def test_indirect_2():
consistent(g) consistent(g)
@assertFailure_fast
def test_long_destroyers_loop(): def test_long_destroyers_loop():
x, y, z = inputs() x, y, z = inputs()
e = dot(dot(add_in_place(x, y), e = dot(dot(add_in_place(x, y),
...@@ -366,6 +375,7 @@ def test_multi_destroyers(): ...@@ -366,6 +375,7 @@ def test_multi_destroyers():
pass pass
@assertFailure_fast
def test_multi_destroyers_through_views(): def test_multi_destroyers_through_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add(transpose_view(z), y), add(z, x)) e = dot(add(transpose_view(z), y), add(z, x))
...@@ -408,6 +418,7 @@ def test_usage_loop_through_views(): ...@@ -408,6 +418,7 @@ def test_usage_loop_through_views():
consistent(g) consistent(g)
@assertFailure_fast
def test_usage_loop_insert_views(): def test_usage_loop_insert_views():
x, y, z = inputs() x, y, z = inputs()
e = dot(add_in_place(x, add(y, z)), e = dot(add_in_place(x, add(y, z)),
...@@ -442,6 +453,7 @@ def test_value_repl_2(): ...@@ -442,6 +453,7 @@ def test_value_repl_2():
consistent(g) consistent(g)
@assertFailure_fast
def test_multiple_inplace(): def test_multiple_inplace():
# this tests issue #5223 # this tests issue #5223
# there were some problems with Ops that have more than # there were some problems with Ops that have more than
......
...@@ -1754,6 +1754,7 @@ def test_without_dnn_batchnorm_train_without_running_averages(): ...@@ -1754,6 +1754,7 @@ def test_without_dnn_batchnorm_train_without_running_averages():
f_abstract(X, Scale, Bias, Dy) f_abstract(X, Scale, Bias, Dy)
@utt.assertFailure_fast
def test_dnn_batchnorm_train_inplace(): def test_dnn_batchnorm_train_inplace():
# test inplace_running_mean and inplace_running_var # test inplace_running_mean and inplace_running_var
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
...@@ -1876,6 +1877,7 @@ def test_batchnorm_inference(): ...@@ -1876,6 +1877,7 @@ def test_batchnorm_inference():
utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar utt.assert_allclose(outputs_abstract[5], outputs_ref[5], rtol=2e-3, atol=4e-5) # dvar
@utt.assertFailure_fast
def test_batchnorm_inference_inplace(): def test_batchnorm_inference_inplace():
# test inplace # test inplace
if not dnn.dnn_available(test_ctx_name): if not dnn.dnn_available(test_ctx_name):
......
...@@ -175,6 +175,7 @@ class TestGpuCholesky(unittest.TestCase): ...@@ -175,6 +175,7 @@ class TestGpuCholesky(unittest.TestCase):
GpuCholesky(lower=True, inplace=False)(A) GpuCholesky(lower=True, inplace=False)(A)
self.assertRaises(AssertionError, invalid_input_func) self.assertRaises(AssertionError, invalid_input_func)
@utt.assertFailure_fast
def test_diag_chol(self): def test_diag_chol(self):
# Diagonal matrix input Cholesky test. # Diagonal matrix input Cholesky test.
for lower in [True, False]: for lower in [True, False]:
...@@ -183,6 +184,7 @@ class TestGpuCholesky(unittest.TestCase): ...@@ -183,6 +184,7 @@ class TestGpuCholesky(unittest.TestCase):
A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1) A_val = np.diag(np.random.uniform(size=5).astype("float32") + 1)
self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace) self.compare_gpu_cholesky_to_np(A_val, lower=lower, inplace=inplace)
@utt.assertFailure_fast
def test_dense_chol_lower(self): def test_dense_chol_lower(self):
# Dense matrix input lower-triangular Cholesky test. # Dense matrix input lower-triangular Cholesky test.
for lower in [True, False]: for lower in [True, False]:
...@@ -243,6 +245,7 @@ class TestMagma(unittest.TestCase): ...@@ -243,6 +245,7 @@ class TestMagma(unittest.TestCase):
A_val_inv = fn(A_val) A_val_inv = fn(A_val)
utt.assert_allclose(np.eye(N), np.dot(A_val_inv, A_val), atol=1e-2) utt.assert_allclose(np.eye(N), np.dot(A_val_inv, A_val), atol=1e-2)
@utt.assertFailure_fast
def test_gpu_matrix_inverse_inplace(self): def test_gpu_matrix_inverse_inplace(self):
N = 1000 N = 1000
test_rng = np.random.RandomState(seed=1) test_rng = np.random.RandomState(seed=1)
...@@ -258,6 +261,7 @@ class TestMagma(unittest.TestCase): ...@@ -258,6 +261,7 @@ class TestMagma(unittest.TestCase):
fn() fn()
utt.assert_allclose(np.eye(N), np.dot(A_val_gpu.get_value(), A_val_copy), atol=5e-3) utt.assert_allclose(np.eye(N), np.dot(A_val_gpu.get_value(), A_val_copy), atol=5e-3)
@utt.assertFailure_fast
def test_gpu_matrix_inverse_inplace_opt(self): def test_gpu_matrix_inverse_inplace_opt(self):
A = theano.tensor.fmatrix("A") A = theano.tensor.fmatrix("A")
fn = theano.function([A], matrix_inverse(A), mode=mode_with_gpu) fn = theano.function([A], matrix_inverse(A), mode=mode_with_gpu)
...@@ -360,6 +364,7 @@ class TestMagma(unittest.TestCase): ...@@ -360,6 +364,7 @@ class TestMagma(unittest.TestCase):
assert any([isinstance(node.op, GpuMagmaCholesky) assert any([isinstance(node.op, GpuMagmaCholesky)
for node in fn.maker.fgraph.toposort()]) for node in fn.maker.fgraph.toposort()])
@utt.assertFailure_fast
def test_gpu_cholesky_inplace(self): def test_gpu_cholesky_inplace(self):
A = self.rand_symmetric(1000) A = self.rand_symmetric(1000)
A_gpu = gpuarray_shared_constructor(A) A_gpu = gpuarray_shared_constructor(A)
...@@ -375,6 +380,7 @@ class TestMagma(unittest.TestCase): ...@@ -375,6 +380,7 @@ class TestMagma(unittest.TestCase):
L = A_gpu.get_value() L = A_gpu.get_value()
utt.assert_allclose(np.dot(L, L.T), A_copy, atol=1e-3) utt.assert_allclose(np.dot(L, L.T), A_copy, atol=1e-3)
@utt.assertFailure_fast
def test_gpu_cholesky_inplace_opt(self): def test_gpu_cholesky_inplace_opt(self):
A = theano.tensor.fmatrix("A") A = theano.tensor.fmatrix("A")
fn = theano.function([A], GpuMagmaCholesky()(A), mode=mode_with_gpu) fn = theano.function([A], GpuMagmaCholesky()(A), mode=mode_with_gpu)
......
...@@ -585,6 +585,7 @@ def test_no_complex(): ...@@ -585,6 +585,7 @@ def test_no_complex():
mode=mode_with_gpu) mode=mode_with_gpu)
@utt.assertFailure_fast
def test_local_lift_solve(): def test_local_lift_solve():
if not cusolver_available: if not cusolver_available:
raise SkipTest('No cuSolver') raise SkipTest('No cuSolver')
...@@ -619,6 +620,7 @@ def test_gpu_solve_not_inplace(): ...@@ -619,6 +620,7 @@ def test_gpu_solve_not_inplace():
utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val)) utt.assert_allclose(f_cpu(A_val, b_val), f_gpu(A_val, b_val))
@utt.assertFailure_fast
def test_local_lift_cholesky(): def test_local_lift_cholesky():
if not cusolver_available: if not cusolver_available:
raise SkipTest('No cuSolver') raise SkipTest('No cuSolver')
......
...@@ -886,6 +886,7 @@ class T_Scan(unittest.TestCase): ...@@ -886,6 +886,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(numpy_out, theano_out) utt.assert_allclose(numpy_out, theano_out)
# simple rnn ; compute inplace version 1 # simple rnn ; compute inplace version 1
@utt.assertFailure_fast
def test_inplace1(self): def test_inplace1(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
vW = asarrayX(np.random.uniform()) vW = asarrayX(np.random.uniform())
...@@ -950,6 +951,7 @@ class T_Scan(unittest.TestCase): ...@@ -950,6 +951,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(theano_x1, numpy_x1) utt.assert_allclose(theano_x1, numpy_x1)
# simple rnn ; compute inplace version 2 # simple rnn ; compute inplace version 2
@utt.assertFailure_fast
def test_inplace2(self): def test_inplace2(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
vW = asarrayX(np.random.uniform()) vW = asarrayX(np.random.uniform())
...@@ -1021,6 +1023,7 @@ class T_Scan(unittest.TestCase): ...@@ -1021,6 +1023,7 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(theano_x0, numpy_x0) utt.assert_allclose(theano_x0, numpy_x0)
utt.assert_allclose(theano_x1, numpy_x1) utt.assert_allclose(theano_x1, numpy_x1)
@utt.assertFailure_fast
def test_inplace3(self): def test_inplace3(self):
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
......
...@@ -3201,6 +3201,7 @@ import theano.tensor.tests.test_sharedvar ...@@ -3201,6 +3201,7 @@ import theano.tensor.tests.test_sharedvar
theano_fct_=lambda a: dense_from_sparse(a * 2.), theano_fct_=lambda a: dense_from_sparse(a * 2.),
ref_fct_=lambda a: np.asarray((a * 2).todense()), ref_fct_=lambda a: np.asarray((a * 2).todense()),
cast_value_=scipy.sparse.csr_matrix, cast_value_=scipy.sparse.csr_matrix,
expect_fail_fast_shape_inplace=False,
) )
class test_shared_options(object): class test_shared_options(object):
pass pass
......
...@@ -579,7 +579,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -579,7 +579,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph) theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert (fgraph.outputs[0].owner.op == assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
...@@ -652,7 +651,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -652,7 +651,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
# print node.op # print node.op
# print '====' # print '===='
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
assert (fgraph.outputs[0].owner.op == assert (fgraph.outputs[0].owner.op ==
crossentropy_softmax_argmax_1hot_with_bias) crossentropy_softmax_argmax_1hot_with_bias)
...@@ -1382,6 +1380,7 @@ def test_argmax_pushdown_bias(): ...@@ -1382,6 +1380,7 @@ def test_argmax_pushdown_bias():
# print node.op # print node.op
types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax) types_to_check = (tensor.DimShuffle, tensor.Elemwise, tensor.Argmax)
assert len(fgraph.toposort()) == 3 assert len(fgraph.toposort()) == 3
for i, type in enumerate(types_to_check): for i, type in enumerate(types_to_check):
assert isinstance(fgraph.toposort()[i].op, type) assert isinstance(fgraph.toposort()[i].op, type)
assert check_stack_trace(fgraph, ops_to_check=types_to_check) assert check_stack_trace(fgraph, ops_to_check=types_to_check)
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import theano import theano
from theano import tensor from theano import tensor
from theano.tests.unittest_tools import assertFailure_fast
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tensor.nnet.blocksparse import ( from theano.tensor.nnet.blocksparse import (
sparse_block_dot, sparse_block_gemv_inplace, sparse_block_outer_inplace, sparse_block_dot, sparse_block_gemv_inplace, sparse_block_outer_inplace,
...@@ -25,6 +26,9 @@ def test_blocksparse_inplace_gemv_opt(): ...@@ -25,6 +26,9 @@ def test_blocksparse_inplace_gemv_opt():
assert f.maker.fgraph.toposort()[-1].op.inplace assert f.maker.fgraph.toposort()[-1].op.inplace
assert check_stack_trace(f, ops_to_check=[sparse_block_gemv_inplace]) assert check_stack_trace(f, ops_to_check=[sparse_block_gemv_inplace])
if theano.config.mode != 'FAST_COMPILE':
test_blocksparse_inplace_gemv_opt = assertFailure_fast(test_blocksparse_inplace_gemv_opt)
def test_blocksparse_inplace_outer_opt(): def test_blocksparse_inplace_outer_opt():
b = tensor.fmatrix() b = tensor.fmatrix()
......
...@@ -265,8 +265,8 @@ class InplaceElemwiseOptimizer(Optimizer): ...@@ -265,8 +265,8 @@ class InplaceElemwiseOptimizer(Optimizer):
candidate_inputs = [i for i in xrange(len(node.inputs)) candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() and if i not in baseline.values() and
not isinstance(node.inputs[i], Constant) and not isinstance(node.inputs[i], Constant) and
# Is next line costly? # the next line should not be costly most of the time.
not fgraph.destroyers(node.inputs[i]) and not fgraph.has_destroyers([node.inputs[i]]) and
node.inputs[i] not in protected_inputs] node.inputs[i] not in protected_inputs]
else: else:
baseline = [] baseline = []
...@@ -277,7 +277,7 @@ class InplaceElemwiseOptimizer(Optimizer): ...@@ -277,7 +277,7 @@ class InplaceElemwiseOptimizer(Optimizer):
# Remove here as faster. # Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs)) candidate_inputs = [i for i in xrange(len(node.inputs))
if not isinstance(node.inputs[i], Constant) and if not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and not fgraph.has_destroyers([node.inputs[i]]) and
node.inputs[i] not in protected_inputs] node.inputs[i] not in protected_inputs]
verbose = False verbose = False
......
...@@ -4806,6 +4806,9 @@ class T_exp(unittest.TestCase): ...@@ -4806,6 +4806,9 @@ class T_exp(unittest.TestCase):
np.asarray([[1.5089518, 1.48439076, -4.7820262], np.asarray([[1.5089518, 1.48439076, -4.7820262],
[2.04832468, 0.50791564, -1.58892269]])]) [2.04832468, 0.50791564, -1.58892269]])])
if theano.config.cycle_detection == 'fast' and theano.config.mode != 'FAST_COMPILE':
test_grad_1 = unittest.expectedFailure(test_grad_1)
def test_int(self): def test_int(self):
x = ivector() x = ivector()
f = function([x], exp(x)) f = function([x], exp(x))
......
...@@ -500,6 +500,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], ...@@ -500,6 +500,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
raise raise
@unittest_tools.assertFailure_fast
def test_gemm_opt0(): def test_gemm_opt0():
# Many subgraphs whose dots can be eliminated # Many subgraphs whose dots can be eliminated
X, Y, Z, a, b = XYZab() X, Y, Z, a, b = XYZab()
...@@ -528,6 +529,7 @@ def test_gemm_opt0(): ...@@ -528,6 +529,7 @@ def test_gemm_opt0():
just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)]) just_gemm([X, Y, Z, a, b], [Z - a * b * a * T.dot(X, Y)])
@unittest_tools.assertFailure_fast
def test_gemm_opt_double_gemm(): def test_gemm_opt_double_gemm():
# This is the pattern that shows up in the autoencoder # This is the pattern that shows up in the autoencoder
X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar() X, Y, Z, a, b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
......
...@@ -1367,6 +1367,7 @@ class TestCompositeCodegen(unittest.TestCase): ...@@ -1367,6 +1367,7 @@ class TestCompositeCodegen(unittest.TestCase):
utt.assert_allclose(f([[1.]]), [[0.]]) utt.assert_allclose(f([[1.]]), [[0.]])
@utt.assertFailure_fast
def test_log1p(): def test_log1p():
m = theano.config.mode m = theano.config.mode
if m == 'FAST_COMPILE': if m == 'FAST_COMPILE':
...@@ -1989,6 +1990,7 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -1989,6 +1990,7 @@ class test_local_subtensor_lift(unittest.TestCase):
assert len(prog) == 3 assert len(prog) == 3
f([4, 5]) # let debugmode test something f([4, 5]) # let debugmode test something
@utt.assertFailure_fast
def test4(self): def test4(self):
# basic test that the optimization doesn't work with broadcasting # basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to, # ... It *could* be extended to,
......
...@@ -27,6 +27,7 @@ def makeSharedTester(shared_constructor_, ...@@ -27,6 +27,7 @@ def makeSharedTester(shared_constructor_,
theano_fct_, theano_fct_,
ref_fct_, ref_fct_,
cast_value_=np.asarray, cast_value_=np.asarray,
expect_fail_fast_shape_inplace=True,
): ):
""" """
This is a generic fct to allow reusing the same test function This is a generic fct to allow reusing the same test function
...@@ -549,6 +550,10 @@ def makeSharedTester(shared_constructor_, ...@@ -549,6 +550,10 @@ def makeSharedTester(shared_constructor_,
assert sum([node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"] for node in topo]) == 1 assert sum([node.op.__class__.__name__ in ["Gemm", "GpuGemm", "StructuredDot"] for node in topo]) == 1
assert all(node.op == tensor.blas.gemm_inplace for node in topo if isinstance(node.op, tensor.blas.Gemm)) assert all(node.op == tensor.blas.gemm_inplace for node in topo if isinstance(node.op, tensor.blas.Gemm))
assert all(node.op.inplace for node in topo if node.op.__class__.__name__ == "GpuGemm") assert all(node.op.inplace for node in topo if node.op.__class__.__name__ == "GpuGemm")
if theano.config.cycle_detection == 'fast' and expect_fail_fast_shape_inplace and theano.config.mode != 'FAST_COMPILE':
test_specify_shape_inplace = unittest.expectedFailure(test_specify_shape_inplace)
def test_values_eq(self): def test_values_eq(self):
""" Test the type.values_eq[_approx] function""" """ Test the type.values_eq[_approx] function"""
dtype = self.dtype dtype = self.dtype
......
...@@ -5,6 +5,7 @@ import logging ...@@ -5,6 +5,7 @@ import logging
import sys import sys
import unittest import unittest
from parameterized import parameterized from parameterized import parameterized
from nose.tools import assert_raises
from six import integer_types from six import integer_types
from six.moves import StringIO from six.moves import StringIO
...@@ -445,3 +446,16 @@ class AttemptManyTimes: ...@@ -445,3 +446,16 @@ class AttemptManyTimes:
current_seed = str(int(current_seed) + 1) current_seed = str(int(current_seed) + 1)
return attempt_multiple_times return attempt_multiple_times
def assertFailure_fast(f):
    """Decorator for test cases that are expected to fail when running
    with ``THEANO_FLAGS=cycle_detection=fast``.

    When the ``cycle_detection`` config flag is ``'fast'``, the wrapped
    test is known to fail, so the wrapper asserts that it raises.
    Otherwise the test function is returned unchanged.
    """
    if theano.config.cycle_detection == 'fast':
        from functools import wraps

        # Preserve the wrapped test's name/docstring so test
        # collectors (nose/pytest) still report it correctly.
        @wraps(f)
        def test_with_assert(*args, **kwargs):
            with assert_raises(Exception):
                f(*args, **kwargs)
        return test_with_assert
    else:
        return f
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论