提交 5d90f41b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Better fix for verify_grad

上级 be3731f5
...@@ -17,7 +17,7 @@ from theano.compat import izip ...@@ -17,7 +17,7 @@ from theano.compat import izip
from six.moves import xrange, reduce from six.moves import xrange, reduce
from theano.gof.null_type import NullType, null_type from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values from theano.gof.op import get_debug_values
from theano.compile import ViewOp, FAST_RUN, DebugMode from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode
__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow" __authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
__copyright__ = "(c) 2011, Universite de Montreal" __copyright__ = "(c) 2011, Universite de Montreal"
...@@ -1551,7 +1551,10 @@ class numeric_grad(object): ...@@ -1551,7 +1551,10 @@ class numeric_grad(object):
return (max_arg, max_pos, abs_errs[max_arg], rel_errs[max_arg]) return (max_arg, max_pos, abs_errs[max_arg], rel_errs[max_arg])
def mode_not_debug(mode): def mode_not_slow(mode):
if mode == 'FAST_COMPILE':
return FAST_RUN
mode = get_mode(mode)
if isinstance(mode, DebugMode): if isinstance(mode, DebugMode):
opt = mode.optimizer opt = mode.optimizer
return FAST_RUN.clone(optimizer=opt) return FAST_RUN.clone(optimizer=opt)
...@@ -1686,7 +1689,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1686,7 +1689,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
cost = theano.tensor.sum(t_r * o_output) cost = theano.tensor.sum(t_r * o_output)
if no_debug_ref: if no_debug_ref:
mode_for_cost = mode_not_debug(mode) mode_for_cost = mode_not_slow(mode)
else: else:
mode_for_cost = mode mode_for_cost = mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论