提交 8560b2ca authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6199 from abergeron/fix_buildbot

Fix some problems in the daily buildbot in DEBUG_MODE.
...@@ -28,6 +28,7 @@ from theano.compile.function_module import ( ...@@ -28,6 +28,7 @@ from theano.compile.function_module import (
std_fgraph) std_fgraph)
from theano.compile.mode import Mode, register_mode from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard from theano.compile.ops import OutputGuard, _output_guard
from theano.configparser import change_flags
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -2227,17 +2228,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2227,17 +2228,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs, outputs, accept_inplace) inputs, outputs, accept_inplace)
fgraph.equivalence_tracker = equivalence_tracker fgraph.equivalence_tracker = equivalence_tracker
# optimize the fgraph with change_flags(compute_test_value=config.compute_test_value_opt):
compute_test_value_orig = theano.config.compute_test_value
try:
theano.config.compute_test_value = \
theano.config.compute_test_value_opt
optimizer(fgraph) optimizer(fgraph)
theano.compile.function_module.insert_deepcopy( theano.compile.function_module.insert_deepcopy(
fgraph, inputs, list(chain(outputs, additional_outputs))) fgraph, inputs, list(chain(outputs, additional_outputs)))
finally:
theano.config.compute_test_value = compute_test_value_orig
if i == 0: if i == 0:
fgraph0 = fgraph fgraph0 = fgraph
...@@ -2286,7 +2281,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2286,7 +2281,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs: for o in fgraph.outputs:
try: try:
fgraph.replace_validate(o, _output_guard(o), reason='output_guard') with change_flags(compute_test_value=config.compute_test_value_opt):
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, " raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against " "how was this output left unprotected against "
"destructive operations?" % o) "destructive operations?" % o)
......
...@@ -293,7 +293,7 @@ def test_batch_normalization_train_grad_grad(): ...@@ -293,7 +293,7 @@ def test_batch_normalization_train_grad_grad():
x_mean_val = np.random.randn(*param_shape).astype('float64') x_mean_val = np.random.randn(*param_shape).astype('float64')
x_invstd_val = np.random.randn(*param_shape).astype('float64') x_invstd_val = np.random.randn(*param_shape).astype('float64')
utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val]) utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val], abs_tol=5e-4, rel_tol=5e-4)
utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val]) utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val]) utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论