提交 ff91168e authored 作者: Mathieu Germain's avatar Mathieu Germain

added nose_parameterized dependency

上级 35c8fdf1
......@@ -49,7 +49,7 @@ instructions below for detailed installation steps):
The following libraries and software are optional:
`nose <http://nose.readthedocs.org/en/latest/>`_ >= 1.3.0
`nose <http://nose.readthedocs.org/en/latest/>`_ >= 1.3.0 and `nose-parameterized <https://pypi.python.org/pypi/nose-parameterized/>`_ >= 0.5.0
Recommended, to run Theano's test-suite.
`Sphinx <http://sphinx.pocoo.org/>`_ >= 0.5.1, `pygments <http://pygments.org/>`_
......
......@@ -164,6 +164,9 @@ def do_setup():
packages=find_packages(),
# 1.7.0 give too much warning related to numpy.diagonal.
install_requires=['numpy>=1.7.1', 'scipy>=0.11', 'six>=1.9.0'],
extras_require={
'test': ['nose>=1.3.0', 'nose-parameterized>=0.5.0']
},
package_data={
'': ['*.txt', '*.rst', '*.cu', '*.cuh', '*.c', '*.sh', '*.pkl',
'*.h', '*.cpp', 'ChangeLog'],
......
......@@ -4,6 +4,7 @@ from functools import wraps
import logging
import sys
import unittest
from nose_parameterized import parameterized
from six import integer_types
from six.moves import StringIO
......@@ -31,6 +32,13 @@ except ImportError:
_logger = logging.getLogger("theano.tests.unittest_tools")
def custom_name_func(testcase_func, param_num, param):
return "%s_%s" % (
testcase_func.__name__,
parameterized.to_safe_name("_".join(str(x) for x in param.args)),
)
def fetch_seed(pseed=None):
"""
Returns the seed to use for running the unit tests.
......@@ -96,6 +104,7 @@ verify_grad.E_grad = T.verify_grad.E_grad
class TestOptimizationMixin(object):
def assertFunctionContains(self, f, op, min=1, max=sys.maxsize):
toposort = f.maker.fgraph.toposort()
matches = [node for node in toposort if node.op == op]
......@@ -172,6 +181,7 @@ class T_OpContractMixin(object):
class InferShapeTester(unittest.TestCase):
def setUp(self):
seed_rng()
# Take into account any mode that may be defined in a child class
......@@ -311,6 +321,7 @@ def str_diagnostic(expected, value, rtol, atol):
class WrongValue(Exception):
def __init__(self, expected_val, val, rtol, atol):
Exception.__init__(self) # to be compatible with python2.4
self.val1 = expected_val
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论