提交 5ff8bfde authored 作者: Markus Beissinger's avatar Markus Beissinger

Added tests for isclose() and allclose().

上级 cac642d9
...@@ -47,7 +47,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as, ...@@ -47,7 +47,8 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
itensor3, Tile, switch, Diagonal, Diag, itensor3, Tile, switch, Diagonal, Diag,
nonzero, flatnonzero, nonzero_values, nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power, stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst, AllocEmpty swapaxes, choose, Choose, NoneConst, AllocEmpty,
isclose, allclose,
) )
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -4119,6 +4120,65 @@ class test_comparison(unittest.TestCase): ...@@ -4119,6 +4120,65 @@ class test_comparison(unittest.TestCase):
except TypeError: except TypeError:
assert err assert err
def test_isclose(self):
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
l = numpy.asarray(
[0., 1., -1., 0.,
numpy.nan, numpy.inf, -numpy.inf, numpy.inf],
dtype=dtype)
r = numpy.asarray(
[0., 1.00001, -1.000000000001, numpy.nan,
numpy.nan, numpy.inf, numpy.inf, 0.],
dtype=dtype)
for x, y, err in [
(shared(l.astype(dtype)), shared(r.astype(dtype)), False),
(l, shared(r.astype(dtype)), True),
(constant(l), shared(r.astype(dtype)), False),
(shared(l.astype(dtype)), r, False),
(shared(l.astype(dtype)), constant(r), False),
]:
try:
fn1 = inplace_func([], isclose(x, y, equal_nan=False))
fn2 = inplace_func([], isclose(x, y, equal_nan=True))
v1 = fn1()
v2 = fn2()
self.assertTrue(
numpy.all(v1 == numpy.isclose(l, r, equal_nan=False)),
numpy.all(v2 == numpy.isclose(l, r, equal_nan=True))
)
except TypeError:
if not dtype.startswith('complex'):
assert err
def test_allclose(self):
# equal_nan argument not in current version of numpy allclose,
# force it to False.
for dtype in ['float64', 'float32', 'complex64', 'complex128']:
l = numpy.asarray(
[0., 1., -1., 0.,
numpy.nan, numpy.inf, -numpy.inf, numpy.inf],
dtype=dtype)
r = numpy.asarray(
[0., 1.00001, -1.000000000001, numpy.nan,
numpy.nan, numpy.inf, numpy.inf, 0.],
dtype=dtype)
for x, y, err in [
(shared(l.astype(dtype)), shared(r.astype(dtype)), False),
(l, shared(r.astype(dtype)), True),
(constant(l), shared(r.astype(dtype)), False),
(shared(l.astype(dtype)), r, False),
(shared(l.astype(dtype)), constant(r), False),
]:
try:
fn = inplace_func([], allclose(x, y, equal_nan=False))
v = fn()
self.assertTrue(
numpy.all(v == numpy.allclose(l, r)),
)
except TypeError:
if not dtype.startswith('complex'):
assert err
class test_bitwise(unittest.TestCase): class test_bitwise(unittest.TestCase):
dtype = ['int8', 'int16', 'int32', 'int64', ] dtype = ['int8', 'int16', 'int32', 'int64', ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论