提交 d8db448d authored 作者: Frederic's avatar Frederic

Reuse the new copymod function to clean up a little random tests value used.

上级 2b598f80
...@@ -562,31 +562,28 @@ _good_broadcast_div_mod_normal_float_no_complex = dict( ...@@ -562,31 +562,28 @@ _good_broadcast_div_mod_normal_float_no_complex = dict(
column=(rand(2, 3), rand(2, 1)), column=(rand(2, 3), rand(2, 1)),
dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)), dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)),
dtype_mixup_2=(randint_nonzero(2, 3), rand(2, 3)), dtype_mixup_2=(randint_nonzero(2, 3), rand(2, 3)),
# Fix problem with integers and uintegers and add them.
# Them remove there specific addition to CeilIntDivTester tests.
# integer=(randint(2, 3), randint_nonzero(2, 3)),
# uinteger=(randint(2, 3).astype("uint8"),
# randint_nonzero(2, 3).astype("uint8")),
#empty2=(numpy.asarray([0]), numpy.asarray([])), #empty2=(numpy.asarray([0]), numpy.asarray([])),
) )
_good_broadcast_div_mod_normal_float_inplace = dict(
_good_broadcast_div_mod_normal_float_inplace = copymod(
_good_broadcast_div_mod_normal_float_no_complex,
empty1=(numpy.asarray([]), numpy.asarray([1])), empty1=(numpy.asarray([]), numpy.asarray([1])),
complex1=(randcomplex(2, 3), randcomplex(2, 3)), complex1=(randcomplex(2, 3), randcomplex(2, 3)),
complex2=(randcomplex(2, 3), rand(2, 3)), complex2=(randcomplex(2, 3), rand(2, 3)),
# Inplace on the first element. Must have the same type. # Inplace on the first element. Must have the same type.
#complex3=(rand(2, 3) ,randcomplex(2, 3)), #complex3=(rand(2, 3) ,randcomplex(2, 3)),
**_good_broadcast_div_mod_normal_float_no_complex )
)
_good_broadcast_div_mod_normal_float = dict(empty2 = (numpy.asarray([0]), numpy.asarray([])),
**_good_broadcast_div_mod_normal_float_inplace
)
def no_complex(d):
"""Remove pairs from dictionary d when the value contains complex data."""
return dict((k, v) for k, v in d.iteritems()
if all(str(x.dtype) not in tensor.complex_dtypes for x in v))
_good_broadcast_div_mod_normal_float = copymod(
_good_broadcast_div_mod_normal_float_inplace,
empty2=(numpy.asarray([0]), numpy.asarray([]))
)
# 'No-complex' versions, with empty2
_good_broadcast_div_mod_normal_float_no_complex2 = no_complex(
_good_broadcast_div_mod_normal_float)
_good_broadcast_div_mod_normal_float_inplace_no_complex = no_complex(
_good_broadcast_div_mod_normal_float_inplace)
_grad_broadcast_div_mod_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)), _grad_broadcast_div_mod_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -637,18 +634,22 @@ CeilIntDivTester = makeBroadcastTester( ...@@ -637,18 +634,22 @@ CeilIntDivTester = makeBroadcastTester(
# grad_rtol=div_grad_rtol, # grad_rtol=div_grad_rtol,
) )
ModTester = makeBroadcastTester(
op=tensor.mod,
expected=lambda x, y: numpy.asarray(
x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
good=copymod(_good_broadcast_div_mod_normal_float,
['complex1', 'complex2']),
)
ModTester = makeBroadcastTester(op = tensor.mod,
expected = lambda x, y: numpy.asarray(x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)), ModInplaceTester = makeBroadcastTester(
good = _good_broadcast_div_mod_normal_float_no_complex2, op=inplace.mod_inplace,
# integers = (randint(2, 3), randint_nonzero(2, 3)), expected=lambda x, y: numpy.asarray(
# dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)), x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
# dtype_mixup_2 = (randint_nonzero(2, 3), rand(2, 3))), good=copymod(_good_broadcast_div_mod_normal_float_inplace,
) ["complex1", "complex2"]),
ModInplaceTester = makeBroadcastTester(op = inplace.mod_inplace, inplace=True)
expected = lambda x, y: numpy.asarray(x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
good = _good_broadcast_div_mod_normal_float_inplace_no_complex,
inplace = True)
_good_broadcast_pow_normal_float = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))), _good_broadcast_pow_normal_float = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))),
scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))), scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论