提交 f63afd20 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2965 from briancheung/batched_dot_grad

Added small gradient calculation for batched_dot
...@@ -199,6 +199,20 @@ class BatchedDotOp(GpuOp): ...@@ -199,6 +199,20 @@ class BatchedDotOp(GpuOp):
} while (0) } while (0)
""" """
def grad(self, inp, grads):
x, y = inp
gz, = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
rval = xgrad, ygrad
for elem in rval:
assert elem.dtype.find('float') != -1
return rval
batched_dot = BatchedDotOp() batched_dot = BatchedDotOp()
class GpuDot22(GpuOp): class GpuDot22(GpuOp):
......
...@@ -43,6 +43,7 @@ mode_without_gpu.check_py_code = False ...@@ -43,6 +43,7 @@ mode_without_gpu.check_py_code = False
def my_rand(*shape): def my_rand(*shape):
return theano._asarray(numpy.random.rand(*shape), dtype='float32') return theano._asarray(numpy.random.rand(*shape), dtype='float32')
class TestBatchedDot(TestCase): class TestBatchedDot(TestCase):
def test_batched_dot_correctness(self): def test_batched_dot_correctness(self):
...@@ -106,6 +107,14 @@ class TestBatchedDot(TestCase): ...@@ -106,6 +107,14 @@ class TestBatchedDot(TestCase):
# Shape mismatch # Shape mismatch
self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2)) self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2))
def test_batched_dot_gradient(self):
theano.tests.unittest_tools.verify_grad(
batched_dot,
[numpy.random.randn(5,7,2).astype(numpy.float32),
numpy.random.randn(5,2,6).astype(numpy.float32)],
mode=mode_with_gpu)
def test_dot22(): def test_dot22():
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
a0 = my_rand(*a_shp) a0 = my_rand(*a_shp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论