提交 d8db448d authored 作者: Frederic's avatar Frederic

Reuse the new copymod function to clean up a little random tests value used.

上级 2b598f80
...@@ -562,31 +562,28 @@ _good_broadcast_div_mod_normal_float_no_complex = dict( ...@@ -562,31 +562,28 @@ _good_broadcast_div_mod_normal_float_no_complex = dict(
column=(rand(2, 3), rand(2, 1)), column=(rand(2, 3), rand(2, 1)),
dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)), dtype_mixup_1=(rand(2, 3), randint_nonzero(2, 3)),
dtype_mixup_2=(randint_nonzero(2, 3), rand(2, 3)), dtype_mixup_2=(randint_nonzero(2, 3), rand(2, 3)),
# Fix problem with integers and uintegers and add them.
# Them remove there specific addition to CeilIntDivTester tests.
# integer=(randint(2, 3), randint_nonzero(2, 3)),
# uinteger=(randint(2, 3).astype("uint8"),
# randint_nonzero(2, 3).astype("uint8")),
#empty2=(numpy.asarray([0]), numpy.asarray([])), #empty2=(numpy.asarray([0]), numpy.asarray([])),
) )
_good_broadcast_div_mod_normal_float_inplace = dict(
_good_broadcast_div_mod_normal_float_inplace = copymod(
_good_broadcast_div_mod_normal_float_no_complex,
empty1=(numpy.asarray([]), numpy.asarray([1])), empty1=(numpy.asarray([]), numpy.asarray([1])),
complex1=(randcomplex(2, 3), randcomplex(2, 3)), complex1=(randcomplex(2, 3), randcomplex(2, 3)),
complex2=(randcomplex(2, 3), rand(2, 3)), complex2=(randcomplex(2, 3), rand(2, 3)),
# Inplace on the first element. Must have the same type. # Inplace on the first element. Must have the same type.
#complex3=(rand(2, 3) ,randcomplex(2, 3)), #complex3=(rand(2, 3) ,randcomplex(2, 3)),
**_good_broadcast_div_mod_normal_float_no_complex
)
_good_broadcast_div_mod_normal_float = dict(empty2 = (numpy.asarray([0]), numpy.asarray([])),
**_good_broadcast_div_mod_normal_float_inplace
) )
def no_complex(d): _good_broadcast_div_mod_normal_float = copymod(
"""Remove pairs from dictionary d when the value contains complex data.""" _good_broadcast_div_mod_normal_float_inplace,
return dict((k, v) for k, v in d.iteritems() empty2=(numpy.asarray([0]), numpy.asarray([]))
if all(str(x.dtype) not in tensor.complex_dtypes for x in v)) )
# 'No-complex' versions, with empty2
_good_broadcast_div_mod_normal_float_no_complex2 = no_complex(
_good_broadcast_div_mod_normal_float)
_good_broadcast_div_mod_normal_float_inplace_no_complex = no_complex(
_good_broadcast_div_mod_normal_float_inplace)
_grad_broadcast_div_mod_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)), _grad_broadcast_div_mod_normal = dict(same_shapes = (rand(2, 3), rand(2, 3)),
scalar = (rand(2, 3), rand(1, 1)), scalar = (rand(2, 3), rand(1, 1)),
...@@ -637,18 +634,22 @@ CeilIntDivTester = makeBroadcastTester( ...@@ -637,18 +634,22 @@ CeilIntDivTester = makeBroadcastTester(
# grad_rtol=div_grad_rtol, # grad_rtol=div_grad_rtol,
) )
ModTester = makeBroadcastTester(
ModTester = makeBroadcastTester(op = tensor.mod, op=tensor.mod,
expected = lambda x, y: numpy.asarray(x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)), expected=lambda x, y: numpy.asarray(
good = _good_broadcast_div_mod_normal_float_no_complex2, x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
# integers = (randint(2, 3), randint_nonzero(2, 3)), good=copymod(_good_broadcast_div_mod_normal_float,
# dtype_mixup_1 = (rand(2, 3), randint_nonzero(2, 3)), ['complex1', 'complex2']),
# dtype_mixup_2 = (randint_nonzero(2, 3), rand(2, 3))),
) )
ModInplaceTester = makeBroadcastTester(op = inplace.mod_inplace,
expected = lambda x, y: numpy.asarray(x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
good = _good_broadcast_div_mod_normal_float_inplace_no_complex, ModInplaceTester = makeBroadcastTester(
inplace = True) op=inplace.mod_inplace,
expected=lambda x, y: numpy.asarray(
x % y, dtype=theano.scalar.basic.upcast(x.dtype, y.dtype)),
good=copymod(_good_broadcast_div_mod_normal_float_inplace,
["complex1", "complex2"]),
inplace=True)
_good_broadcast_pow_normal_float = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))), _good_broadcast_pow_normal_float = dict(same_shapes = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (2, 3))),
scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))), scalar = (rand_ranged(1, 5, (2, 3)), rand_ranged(-3, 3, (1, 1))),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论