提交 ed4e0679 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: test and fix grad

上级 e68999e8
...@@ -3449,12 +3449,12 @@ class BatchedDot(Op): ...@@ -3449,12 +3449,12 @@ class BatchedDot(Op):
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1) ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is matrix, y is vector, grad is vector # x is matrix, y is vector, grad is vector
elif xdim == 2 and ydim == 1: elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1) xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is matrix, y is matrix, grad is matrix # x is matrix, y is matrix, grad is matrix
elif xdim == ydim == 2: elif xdim == ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
......
...@@ -1931,13 +1931,22 @@ DotTester = makeTester(name='DotTester', ...@@ -1931,13 +1931,22 @@ DotTester = makeTester(name='DotTester',
bad_build=dict(), bad_build=dict(),
bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)), bad_runtime=dict(bad1=(rand(5, 7), rand(5, 7)),
bad2=(rand(5, 7), rand(8, 3)))) bad2=(rand(5, 7), rand(8, 3))))
BatchedDotTester = makeTester( BatchedDotTester = makeTester(
name='BatchedDotTester', name='BatchedDotTester',
op=batched_dot, op=batched_dot,
expected=lambda xs, ys: numpy.asarray(list( expected=lambda xs, ys: numpy.asarray(list(
numpy.dot(x, y) for x, y in zip(xs, ys))), numpy.dot(x, y) for x, y in zip(xs, ys))),
checks={}, checks={},
grad=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)),
correct3=(rand(3, 5, 7), rand(3, 7)),
correct4=(rand(3, 5), rand(3, 5, 7)),
correct5=(rand(3), rand(3, 5, 7)),
correct6=(rand(3, 5), rand(3)),
mixed1=(rand(3, 5).astype('float32'),
rand(3, 5, 7)),
mixed2=(rand(3, 5).astype('float64'),
rand(3, 5, 7))),
good=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)), good=dict(correct1=(rand(3, 5, 7), rand(3, 7, 5)),
correct2=(rand(3, 5, 7), rand(3, 7, 9)), correct2=(rand(3, 5, 7), rand(3, 7, 9)),
correct3=(rand(3, 5, 7), rand(3, 7)), correct3=(rand(3, 5, 7), rand(3, 7)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论