提交 90ca8271 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Made corrections.

上级 4fca63f6
...@@ -118,7 +118,7 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None): ...@@ -118,7 +118,7 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5, gap=None):
value = numpy.random.random(shape) value = numpy.random.random(shape)
elif len(gap) == 2: elif len(gap) == 2:
a, b = gap a, b = gap
value = a + numpy.random.random(shape) * b - a value = a + numpy.random.random(shape) * (b - a)
else: else:
value = numpy.random.random(shape) * gap[0] value = numpy.random.random(shape) * gap[0]
...@@ -2242,7 +2242,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2242,7 +2242,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
:param op: Op to test. :param op: Op to test.
:expected_f: Function use to compare. This function must act :expected_f: Function use to compare. This function must act
on dense matrix. If the op the structured on dense matrix. If the op is structured
see the `structure_function` decorator to make see the `structure_function` decorator to make
this function structured. this function structured.
:param gap: Tuple for the range of the random sample. When :param gap: Tuple for the range of the random sample. When
...@@ -2261,10 +2261,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2261,10 +2261,7 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
""" """
if test_dtypes is None: if test_dtypes is None:
test_dtypes = [d for d in sparse.all_dtypes test_dtypes = sparse.all_dtypes
if not (d == 'int' or
d == 'int8' or
d in sparse.complex_dtypes)]
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
__name__ = op.__name__.capitalize() + 'Tester' __name__ = op.__name__.capitalize() + 'Tester'
...@@ -2329,7 +2326,7 @@ def structure_function(f, index=0): ...@@ -2329,7 +2326,7 @@ def structure_function(f, index=0):
`index` parameter. `index` parameter.
""" """
def structured_function(*args, **kwargs): def structured_function(*args):
pattern = args[index] pattern = args[index]
evaluated = f(*args) evaluated = f(*args)
evaluated[pattern == 0] = 0 evaluated[pattern == 0] = 0
...@@ -2338,9 +2335,11 @@ def structure_function(f, index=0): ...@@ -2338,9 +2335,11 @@ def structure_function(f, index=0):
StructuredSigmoidTester = elemwise_checker( StructuredSigmoidTester = elemwise_checker(
sparse.structured_sigmoid, sparse.structured_sigmoid,
structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x)))) structure_function(lambda x: 1.0 / (1.0 + numpy.exp(-x))),
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
StructuredLogTester = elemwise_checker( StructuredExpTester = elemwise_checker(
sparse.structured_exp, sparse.structured_exp,
structure_function(numpy.exp)) structure_function(numpy.exp))
...@@ -2354,13 +2353,11 @@ StructuredPowTester = elemwise_checker( ...@@ -2354,13 +2353,11 @@ StructuredPowTester = elemwise_checker(
StructuredMinimumTester = elemwise_checker( StructuredMinimumTester = elemwise_checker(
lambda x: structured_minimum(x, 2), lambda x: structured_minimum(x, 2),
structure_function(lambda x: numpy.minimum(x, 2)), structure_function(lambda x: numpy.minimum(x, 2)))
grad_test=False)
StructuredMaximumTester = elemwise_checker( StructuredMaximumTester = elemwise_checker(
lambda x: structured_maximum(x, 2), lambda x: structured_maximum(x, 2),
structure_function(lambda x: numpy.maximum(x, 2)), structure_function(lambda x: numpy.maximum(x, 2)))
grad_test=False)
StructuredAddTester = elemwise_checker( StructuredAddTester = elemwise_checker(
lambda x: structured_add(x, 2), lambda x: structured_add(x, 2),
...@@ -2409,17 +2406,23 @@ RintTester = elemwise_checker( ...@@ -2409,17 +2406,23 @@ RintTester = elemwise_checker(
SgnTester = elemwise_checker( SgnTester = elemwise_checker(
sparse.sgn, sparse.sgn,
numpy.sign, numpy.sign,
grad_test=False) grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
CeilTester = elemwise_checker( CeilTester = elemwise_checker(
sparse.ceil, sparse.ceil,
numpy.ceil, numpy.ceil,
grad_test=False) grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
FloorTester = elemwise_checker( FloorTester = elemwise_checker(
sparse.floor, sparse.floor,
numpy.floor, numpy.floor,
grad_test=False) grad_test=False,
test_dtypes=[m for m in sparse.all_dtypes
if not m in sparse.complex_dtypes])
Log1pTester = elemwise_checker( Log1pTester = elemwise_checker(
sparse.log1p, sparse.log1p,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论