提交 460debbf authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test in FAST_COMPILE mode

上级 60c176f2
......@@ -22,11 +22,13 @@ from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest
imported_scipy_special = False
mode_no_scipy = get_default_mode()
try:
import scipy.special
imported_scipy_special = True
except ImportError:
pass
if config.mode=="FAST_COMPILE":
mode_no_scipy = "FAST_RUN"
### seed random number generator so that unittests are deterministic ###
utt.seed_rng()
......@@ -57,11 +59,11 @@ def safe_make_node(op, *inputs):
else:
return node.owner
def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}):
def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}, mode = None):
if grad is True:
grad = good
_op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op, expected, checks, good, bad_build, bad_runtime, grad
_op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad, _mode = op, expected, checks, good, bad_build, bad_runtime, grad, mode
class Checker(unittest.TestCase):
......@@ -72,6 +74,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
bad_build = _bad_build
bad_runtime = _bad_runtime
grad = _grad
mode = _mode
def test_good(self):
for testname, inputs in self.good.items():
......@@ -88,7 +91,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
raise type, exc_value, traceback
try:
f = inplace_func(inputrs, node.outputs)
f = inplace_func(inputrs, node.outputs, mode = mode)
except:
type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while trying to make a Function" \
......@@ -172,7 +175,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs]
try:
utt.verify_grad(self.op, inputs)
utt.verify_grad(self.op, inputs, mode=self.mode)
except:
type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while computing the gradient on the following inputs: %s" \
......@@ -645,11 +648,13 @@ else:
ErfTester = makeBroadcastTester(op = erf,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy)
ErfInplaceTester = makeBroadcastTester(op = inplace.erf_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy,
inplace = True)
if imported_scipy_special:
......@@ -668,11 +673,13 @@ else:
ErfcTester = makeBroadcastTester(op = erfc,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal)
grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy)
ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy,
inplace = True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论