提交 ff91168e authored 作者: Mathieu Germain's avatar Mathieu Germain

added nose_parameterized dependency

上级 35c8fdf1
...@@ -49,7 +49,7 @@ instructions below for detailed installation steps): ...@@ -49,7 +49,7 @@ instructions below for detailed installation steps):
The following libraries and software are optional: The following libraries and software are optional:
`nose <http://nose.readthedocs.org/en/latest/>`_ >= 1.3.0 `nose <http://nose.readthedocs.org/en/latest/>`_ >= 1.3.0 and `nose-parameterized <https://pypi.python.org/pypi/nose-parameterized/>`_ >= 0.5.0
Recommended, to run Theano's test-suite. Recommended, to run Theano's test-suite.
`Sphinx <http://sphinx.pocoo.org/>`_ >= 0.5.1, `pygments <http://pygments.org/>`_ `Sphinx <http://sphinx.pocoo.org/>`_ >= 0.5.1, `pygments <http://pygments.org/>`_
......
...@@ -164,6 +164,9 @@ def do_setup(): ...@@ -164,6 +164,9 @@ def do_setup():
packages=find_packages(), packages=find_packages(),
# 1.7.0 give too much warning related to numpy.diagonal. # 1.7.0 give too much warning related to numpy.diagonal.
install_requires=['numpy>=1.7.1', 'scipy>=0.11', 'six>=1.9.0'], install_requires=['numpy>=1.7.1', 'scipy>=0.11', 'six>=1.9.0'],
extras_require={
'test': ['nose>=1.3.0', 'nose-parameterized>=0.5.0']
},
package_data={ package_data={
'': ['*.txt', '*.rst', '*.cu', '*.cuh', '*.c', '*.sh', '*.pkl', '': ['*.txt', '*.rst', '*.cu', '*.cuh', '*.c', '*.sh', '*.pkl',
'*.h', '*.cpp', 'ChangeLog'], '*.h', '*.cpp', 'ChangeLog'],
......
...@@ -4,6 +4,7 @@ from functools import wraps ...@@ -4,6 +4,7 @@ from functools import wraps
import logging import logging
import sys import sys
import unittest import unittest
from nose_parameterized import parameterized
from six import integer_types from six import integer_types
from six.moves import StringIO from six.moves import StringIO
...@@ -31,6 +32,13 @@ except ImportError: ...@@ -31,6 +32,13 @@ except ImportError:
_logger = logging.getLogger("theano.tests.unittest_tools") _logger = logging.getLogger("theano.tests.unittest_tools")
def custom_name_func(testcase_func, param_num, param):
return "%s_%s" % (
testcase_func.__name__,
parameterized.to_safe_name("_".join(str(x) for x in param.args)),
)
def fetch_seed(pseed=None): def fetch_seed(pseed=None):
""" """
Returns the seed to use for running the unit tests. Returns the seed to use for running the unit tests.
...@@ -96,6 +104,7 @@ verify_grad.E_grad = T.verify_grad.E_grad ...@@ -96,6 +104,7 @@ verify_grad.E_grad = T.verify_grad.E_grad
class TestOptimizationMixin(object): class TestOptimizationMixin(object):
def assertFunctionContains(self, f, op, min=1, max=sys.maxsize): def assertFunctionContains(self, f, op, min=1, max=sys.maxsize):
toposort = f.maker.fgraph.toposort() toposort = f.maker.fgraph.toposort()
matches = [node for node in toposort if node.op == op] matches = [node for node in toposort if node.op == op]
...@@ -172,6 +181,7 @@ class T_OpContractMixin(object): ...@@ -172,6 +181,7 @@ class T_OpContractMixin(object):
class InferShapeTester(unittest.TestCase): class InferShapeTester(unittest.TestCase):
def setUp(self): def setUp(self):
seed_rng() seed_rng()
# Take into account any mode that may be defined in a child class # Take into account any mode that may be defined in a child class
...@@ -311,6 +321,7 @@ def str_diagnostic(expected, value, rtol, atol): ...@@ -311,6 +321,7 @@ def str_diagnostic(expected, value, rtol, atol):
class WrongValue(Exception): class WrongValue(Exception):
def __init__(self, expected_val, val, rtol, atol): def __init__(self, expected_val, val, rtol, atol):
Exception.__init__(self) # to be compatible with python2.4 Exception.__init__(self) # to be compatible with python2.4
self.val1 = expected_val self.val1 = expected_val
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论