提交 1422638a authored 作者: nouiz's avatar nouiz

Merge pull request #571 from lamblin/test_div_grad

Do not test div on denominator too close to 0.
...@@ -345,6 +345,17 @@ def rand(*shape): ...@@ -345,6 +345,17 @@ def rand(*shape):
return r * 2 - 1 return r * 2 - 1
def rand_nonzero(shape, eps=3e-4):
"""Like rand, but the absolute value has to be at least eps"""
# covers [0, 1)
r = numpy.asarray(numpy.random.rand(*shape), dtype=config.floatX)
# covers [0, (1 - eps) / 2) U [(1 + eps) / 2, 1)
r = r * (1 - eps) + eps * (r >= 0.5)
# covers [-1, -eps) U [eps, 1)
r = r * 2 - 1
return r
def randint(*shape): def randint(*shape):
return numpy.random.random_integers(-5, 5, shape) return numpy.random.random_integers(-5, 5, shape)
...@@ -355,6 +366,10 @@ def randcomplex(*shape): ...@@ -355,6 +366,10 @@ def randcomplex(*shape):
return numpy.complex128(2 * r - 1) return numpy.complex128(2 * r - 1)
def randcomplex_nonzero(shape, eps=1e-4):
return numpy.complex128(rand_nonzero(shape, eps))
def randint_nonzero(*shape): def randint_nonzero(*shape):
r = numpy.random.random_integers(-5, 4, shape) r = numpy.random.random_integers(-5, 4, shape)
return r + (r == 0) * 5 return r + (r == 0) * 5
...@@ -559,6 +574,7 @@ MulInplaceTester = makeBroadcastTester(op = inplace.mul_inplace, ...@@ -559,6 +574,7 @@ MulInplaceTester = makeBroadcastTester(op = inplace.mul_inplace,
grad = _grad_broadcast_binary_normal, grad = _grad_broadcast_binary_normal,
inplace = True) inplace = True)
def copymod(dct, without=[], **kwargs): def copymod(dct, without=[], **kwargs):
"""Return dct but with the keys named by args removed, and with """Return dct but with the keys named by args removed, and with
kwargs added. kwargs added.
...@@ -572,12 +588,12 @@ def copymod(dct, without=[], **kwargs): ...@@ -572,12 +588,12 @@ def copymod(dct, without=[], **kwargs):
return rval return rval
_good_broadcast_div_mod_normal_float_no_complex = dict( _good_broadcast_div_mod_normal_float_no_complex = dict(
same_shapes=(rand(2, 3), rand(2, 3)), same_shapes=(rand(2, 3), rand_nonzero((2, 3))),
scalar=(rand(2, 3), rand(1, 1)), scalar=(rand(2, 3), rand_nonzero((1, 1))),
row=(rand(2, 3), rand(1, 3)), row=(rand(2, 3), rand_nonzero((1, 3))),
column=(rand(2, 3), rand(2, 1)), column=(rand(2, 3), rand_nonzero((2, 1))),
dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)), dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)),
dtype_mixup_2=(randint_nonzero(2, 3), rand(2, 3)), dtype_mixup_2=(randint_nonzero(2, 3), rand_nonzero((2, 3))),
integer=(randint(2, 3), randint_nonzero(2, 3)), integer=(randint(2, 3), randint_nonzero(2, 3)),
uinteger=(randint(2, 3).astype("uint8"), uinteger=(randint(2, 3).astype("uint8"),
randint_nonzero(2, 3).astype("uint8")), randint_nonzero(2, 3).astype("uint8")),
...@@ -588,8 +604,8 @@ _good_broadcast_div_mod_normal_float_no_complex = dict( ...@@ -588,8 +604,8 @@ _good_broadcast_div_mod_normal_float_no_complex = dict(
_good_broadcast_div_mod_normal_float_inplace = copymod( _good_broadcast_div_mod_normal_float_inplace = copymod(
_good_broadcast_div_mod_normal_float_no_complex, _good_broadcast_div_mod_normal_float_no_complex,
empty1=(numpy.asarray([]), numpy.asarray([1])), empty1=(numpy.asarray([]), numpy.asarray([1])),
complex1=(randcomplex(2, 3), randcomplex(2, 3)), complex1=(randcomplex(2, 3), randcomplex_nonzero((2, 3))),
complex2=(randcomplex(2, 3), rand(2, 3)), complex2=(randcomplex(2, 3), rand_nonzero((2, 3))),
# Inplace on the first element. Must have the same type. # Inplace on the first element. Must have the same type.
#complex3=(rand(2, 3) ,randcomplex(2, 3)), #complex3=(rand(2, 3) ,randcomplex(2, 3)),
) )
...@@ -600,18 +616,19 @@ _good_broadcast_div_mod_normal_float = copymod( ...@@ -600,18 +616,19 @@ _good_broadcast_div_mod_normal_float = copymod(
) )
_grad_broadcast_div_mod_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)), _grad_broadcast_div_mod_normal = dict(
scalar = (rand(2, 3), rand(1, 1)), same_shapes=(rand(2, 3), rand_nonzero((2, 3))),
row = (rand(2, 3), rand(1, 3)), scalar=(rand(2, 3), rand_nonzero((1, 1))),
column = (rand(2, 3), rand(2, 1)), row=(rand(2, 3), rand_nonzero((1, 3))),
#complex1 = (randcomplex(2,3),randcomplex(2,3)), column=(rand(2, 3), rand_nonzero((2, 1))),
#complex2 = (randcomplex(2,3),rand(2,3)), #complex1=(randcomplex(2, 3), randcomplex_nonzero((2, 3))),
#complex3 = (rand(2,3),randcomplex(2,3)), #complex2=(randcomplex(2, 3), rand_nonzero((2, 3))),
#dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)), #complex3=(rand(2, 3), randcomplex_nonzero((2, 3))),
#dtype_mixup_2 = (randint_nonzero(2, 3), rand(2, 3)), #dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)),
#empty1 = (numpy.asarray([]), numpy.asarray([1.])), #dtype_mixup_2=(randint_nonzero(2, 3), rand_nonzero((2, 3))),
#empty2 = (numpy.asarray([0]), numpy.asarray([])), #empty1=(numpy.asarray([]), numpy.asarray([1.])),
) #empty2=(numpy.asarray([0]), numpy.asarray([])),
)
div_grad_rtol=None div_grad_rtol=None
if config.floatX=='float32': if config.floatX=='float32':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论