提交 90ca8271 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Made corrections.

上级 4fca63f6
......@@ -118,7 +118,7 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None):
value = numpy.random.random(shape)
elif len(gap) == 2:
a, b = gap
value = a + numpy.random.random(shape) * b - a
value = a + numpy.random.random(shape) * (b - a)
else:
value = numpy.random.random(shape) * gap[0]
......@@ -2242,7 +2242,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
:param op: Op to test.
:expected_f: Function use to compare. This function must act
on dense matrix. If the op the structured
on dense matrix. If the op is structured
see the `structure_function` decorator to make
this function structured.
:param gap: Tuple for the range of the random sample. When
......@@ -2261,10 +2261,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
"""
if test_dtypes is None:
test_dtypes = [d for d in sparse.all_dtypes
if not (d == 'int' or
d == 'int8' or
d in sparse.complex_dtypes)]
test_dtypes = sparse.all_dtypes
class Tester(unittest.TestCase):
__name__ = op.__name__.capitalize() + 'Tester'
......@@ -2329,7 +2326,7 @@ def structure_function(f, index=0):
`index` parameter.
"""
def structured_function(*args, **kwargs):
def structured_function(*args):
pattern = args[index]
evaluated = f(*args)
evaluated[pattern == 0] = 0
......@@ -2338,9 +2335,11 @@ def structure_function(f, index=0):
StructuredSigmoidTester = elemwise_checker(
sparse.structured_sigmoid,
structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x))))
structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x))),
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
StructuredLogTester = elemwise_checker(
StructuredExpTester = elemwise_checker(
sparse.structured_exp,
structure_function(numpy.exp))
......@@ -2354,13 +2353,11 @@ StructuredPowTester = elemwise_checker(
StructuredMinimumTester = elemwise_checker(
lambda x: structured_minimum(x, 2),
structure_function(lambda x: numpy.minimum(x, 2)),
grad_test=False)
structure_function(lambda x: numpy.minimum(x, 2)))
StructuredMaximumTester = elemwise_checker(
lambda x: structured_maximum(x, 2),
structure_function(lambda x: numpy.maximum(x, 2)),
grad_test=False)
structure_function(lambda x: numpy.maximum(x, 2)))
StructuredAddTester = elemwise_checker(
lambda x: structured_add(x, 2),
......@@ -2409,17 +2406,23 @@ RintTester = elemwise_checker(
SgnTester = elemwise_checker(
sparse.sgn,
numpy.sign,
grad_test=False)
grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
CeilTester = elemwise_checker(
sparse.ceil,
numpy.ceil,
grad_test=False)
grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
FloorTester = elemwise_checker(
sparse.floor,
numpy.floor,
grad_test=False)
grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
Log1pTester = elemwise_checker(
sparse.log1p,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论