提交 8560b2ca authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6199 from abergeron/fix_buildbot

Fix some problems in the daily buildbot in DEBUG_MODE.
......@@ -28,6 +28,7 @@ from theano.compile.function_module import (
std_fgraph)
from theano.compile.mode import Mode, register_mode
from theano.compile.ops import OutputGuard, _output_guard
from theano.configparser import change_flags
__docformat__ = "restructuredtext en"
......@@ -2227,17 +2228,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs, outputs, accept_inplace)
fgraph.equivalence_tracker = equivalence_tracker
# optimize the fgraph
compute_test_value_orig = theano.config.compute_test_value
try:
theano.config.compute_test_value = \
theano.config.compute_test_value_opt
with change_flags(compute_test_value=config.compute_test_value_opt):
optimizer(fgraph)
theano.compile.function_module.insert_deepcopy(
fgraph, inputs, list(chain(outputs, additional_outputs)))
finally:
theano.config.compute_test_value = compute_test_value_orig
if i == 0:
fgraph0 = fgraph
......@@ -2286,7 +2281,8 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.attach_feature(gof.DestroyHandler())
for o in fgraph.outputs:
try:
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
with change_flags(compute_test_value=config.compute_test_value_opt):
fgraph.replace_validate(o, _output_guard(o), reason='output_guard')
raise Exception("Output variable %s required output_guard, "
"how was this output left unprotected against "
"destructive operations?" % o)
......
......@@ -293,7 +293,7 @@ def test_batch_normalization_train_grad_grad():
x_mean_val = np.random.randn(*param_shape).astype('float64')
x_invstd_val = np.random.randn(*param_shape).astype('float64')
utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
utt.verify_grad(bn_grad_wrt_inputs_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val], abs_tol=5e-4, rel_tol=5e-4)
utt.verify_grad(bn_grad_wrt_scale_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
utt.verify_grad(bn_grad_wrt_bias_f, [x_val, dy_val, scale_val, x_mean_val, x_invstd_val])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论