提交 460debbf authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test in FAST_COMPILE mode

上级 60c176f2
...@@ -22,11 +22,13 @@ from numpy.testing import dec ...@@ -22,11 +22,13 @@ from numpy.testing import dec
from numpy.testing.noseclasses import KnownFailureTest from numpy.testing.noseclasses import KnownFailureTest
imported_scipy_special = False imported_scipy_special = False
mode_no_scipy = get_default_mode()
try: try:
import scipy.special import scipy.special
imported_scipy_special = True imported_scipy_special = True
except ImportError: except ImportError:
pass if config.mode=="FAST_COMPILE":
mode_no_scipy = "FAST_RUN"
### seed random number generator so that unittests are deterministic ### ### seed random number generator so that unittests are deterministic ###
utt.seed_rng() utt.seed_rng()
...@@ -57,11 +59,11 @@ def safe_make_node(op, *inputs): ...@@ -57,11 +59,11 @@ def safe_make_node(op, *inputs):
else: else:
return node.owner return node.owner
def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}): def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}, mode = None):
if grad is True: if grad is True:
grad = good grad = good
_op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad = op, expected, checks, good, bad_build, bad_runtime, grad _op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad, _mode = op, expected, checks, good, bad_build, bad_runtime, grad, mode
class Checker(unittest.TestCase): class Checker(unittest.TestCase):
...@@ -72,6 +74,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r ...@@ -72,6 +74,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
bad_build = _bad_build bad_build = _bad_build
bad_runtime = _bad_runtime bad_runtime = _bad_runtime
grad = _grad grad = _grad
mode = _mode
def test_good(self): def test_good(self):
for testname, inputs in self.good.items(): for testname, inputs in self.good.items():
...@@ -88,7 +91,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r ...@@ -88,7 +91,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
raise type, exc_value, traceback raise type, exc_value, traceback
try: try:
f = inplace_func(inputrs, node.outputs) f = inplace_func(inputrs, node.outputs, mode = mode)
except: except:
type, exc_value, traceback = sys.exc_info() type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while trying to make a Function" \ err_msg = "Test %s::%s: Error occurred while trying to make a Function" \
...@@ -172,7 +175,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r ...@@ -172,7 +175,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
inputs = [copy(input) for input in inputs] inputs = [copy(input) for input in inputs]
inputrs = [value(input) for input in inputs] inputrs = [value(input) for input in inputs]
try: try:
utt.verify_grad(self.op, inputs) utt.verify_grad(self.op, inputs, mode=self.mode)
except: except:
type, exc_value, traceback = sys.exc_info() type, exc_value, traceback = sys.exc_info()
err_msg = "Test %s::%s: Error occurred while computing the gradient on the following inputs: %s" \ err_msg = "Test %s::%s: Error occurred while computing the gradient on the following inputs: %s" \
...@@ -645,11 +648,13 @@ else: ...@@ -645,11 +648,13 @@ else:
ErfTester = makeBroadcastTester(op = erf, ErfTester = makeBroadcastTester(op = erf,
expected = expected, expected = expected,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy)
ErfInplaceTester = makeBroadcastTester(op = inplace.erf_inplace, ErfInplaceTester = makeBroadcastTester(op = inplace.erf_inplace,
expected = expected, expected = expected,
good = _good_broadcast_unary_normal_no_int, good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy,
inplace = True) inplace = True)
if imported_scipy_special: if imported_scipy_special:
...@@ -668,11 +673,13 @@ else: ...@@ -668,11 +673,13 @@ else:
ErfcTester = makeBroadcastTester(op = erfc, ErfcTester = makeBroadcastTester(op = erfc,
expected = expected, expected = expected,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal) grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy)
ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace, ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
expected = expected, expected = expected,
good = _good_broadcast_unary_normal_no_int, good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal, grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy,
inplace = True) inplace = True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论