提交 2202eb41 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: more diverse tests

上级 bee1a7a0
...@@ -1931,12 +1931,17 @@ DotTester = makeTester(name='DotTester', ...@@ -1931,12 +1931,17 @@ DotTester = makeTester(name='DotTester',
bad_build=dict(), bad_build=dict(),
bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)), bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)),
bad2=(rand(5, 7), rand(8, 3)))) bad2=(rand(5, 7), rand(8, 3))))
BatchedDotTester = makeTester( BatchedDotTester = makeTester(
name='BatchedDotTester', name='BatchedDotTester',
op=batched_dot, op=batched_dot,
expected=lambda xs, ys: numpy.asarray( expected=(lambda xs, ys:
list(numpy.dot(x, y) for x, y in zip(xs, ys)), numpy.asarray(
dtype=theano.scalar.upcast(xs.dtype, ys.dtype)), list(x * y if x.ndim == 0 or y.ndim == 0
else (numpy.dot(x, y) if y.ndim == 1
else numpy.tensordot(x, y, [[x.ndim - 1], [y.ndim - 2]]))
for x, y in zip(xs, ys)),
dtype=theano.scalar.upcast(xs.dtype, ys.dtype))),
checks={}, checks={},
grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)), grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)), correct2=(rand(3, 5, 7), rand(3, 7, 9)),
...@@ -1944,6 +1949,12 @@ BatchedDotTester = makeTester( ...@@ -1944,6 +1949,12 @@ BatchedDotTester = makeTester(
correct4=(rand(3, 5), rand(3, 5, 7)), correct4=(rand(3, 5), rand(3, 5, 7)),
correct5=(rand(3), rand(3, 5, 7)), correct5=(rand(3), rand(3, 5, 7)),
correct6=(rand(3, 5), rand(3)), correct6=(rand(3, 5), rand(3)),
correct7=(rand(3, 5), rand(3, 5)),
correct8=(rand(3), rand(3)),
correct9=(rand(3, 5, 7, 11), rand(3)),
correct10=(rand(3, 7, 11, 5), rand(3, 5)),
correct11=(rand(3, 7, 11, 5), rand(3, 5, 13)),
correct12=(rand(3, 7, 11, 5), rand(3, 13, 5, 17)),
mixed1=(rand(3, 5).astype('float32'), mixed1=(rand(3, 5).astype('float32'),
rand(3, 5, 7)), rand(3, 5, 7)),
mixed2=(rand(3, 5).astype('float64'), mixed2=(rand(3, 5).astype('float64'),
...@@ -1954,6 +1965,12 @@ BatchedDotTester = makeTester( ...@@ -1954,6 +1965,12 @@ BatchedDotTester = makeTester(
correct4=(rand(3, 5), rand(3, 5, 7)), correct4=(rand(3, 5), rand(3, 5, 7)),
correct5=(rand(3), rand(3, 5, 7)), correct5=(rand(3), rand(3, 5, 7)),
correct6=(rand(3, 5), rand(3)), correct6=(rand(3, 5), rand(3)),
correct7=(rand(3, 5), rand(3, 5)),
correct8=(rand(3), rand(3)),
correct9=(rand(3, 5, 7, 11), rand(3)),
correct10=(rand(3, 7, 11, 5), rand(3, 5)),
correct11=(rand(3, 7, 11, 5), rand(3, 5, 13)),
correct12=(rand(3, 7, 11, 5), rand(3, 13, 5, 17)),
mixed1=(rand(3, 5).astype('float32'), mixed1=(rand(3, 5).astype('float32'),
rand(3, 5, 7)), rand(3, 5, 7)),
mixed2=(rand(3, 5).astype('float64'), mixed2=(rand(3, 5).astype('float64'),
...@@ -1961,16 +1978,18 @@ BatchedDotTester = makeTester( ...@@ -1961,16 +1978,18 @@ BatchedDotTester = makeTester(
complex1=(randcomplex(3, 5, 7), complex1=(randcomplex(3, 5, 7),
randcomplex(3, 7)), randcomplex(3, 7)),
complex2=(rand(3, 5, 7), randcomplex(3, 7)), complex2=(rand(3, 5, 7), randcomplex(3, 7)),
complex3=(randcomplex(3, 5, 7), rand(3, 7)), complex3=(randcomplex(3, 5, 7), rand(3, 7))),
empty1=(numpy.asarray([], dtype=config.floatX), bad_build=dict(no_batch_axis2=(rand(), rand(3, 5)),
numpy.asarray([], dtype=config.floatX)), no_batch_axis3=(rand(3, 5), rand())),
empty2=(rand(3, 5, 0), rand(3, 0, 2)), bad_runtime=dict(batch_dim_mismatch1=(rand(2, 5, 7), rand(3, 7, 9)),
empty3=(rand(3, 0, 5), rand(3, 5, 0))), batch_dim_mismatch2=(rand(3, 5, 7), rand(2, 7, 9)),
bad_build=dict(), batch_dim_mismatch3=(rand(3), rand(5)),
bad_runtime=dict(bad1=(rand(3, 5, 7), rand(3, 5, 7)), bad_dim1=(rand(3, 5, 7), rand(3, 5, 7)),
bad2=(rand(3, 5, 7), rand(3, 8, 3)), bad_dim2=(rand(3, 5, 7), rand(3, 8, 3)),
bad3=(rand(2, 5, 7), rand(3, 7, 9)), bad_dim3=(rand(3, 5), rand(3, 7)),
bad4=(rand(3, 5, 7), rand(2, 7, 9)))) bad_dim4=(rand(3, 5, 7, 11), rand(3, 5)),
bad_dim5=(rand(3, 5, 7, 11), rand(3, 5, 13)),
bad_dim6=(rand(3, 5, 7, 11), rand(3, 13, 5, 17))))
def _numpy_second(x, y): def _numpy_second(x, y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论