提交 08857dc5 authored 作者: abergeron's avatar abergeron

Merge pull request #4066 from cooijmanstim/big_batched_dot

GpuBatchedDot: streams implementation (WIP)
...@@ -48,6 +48,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester): ...@@ -48,6 +48,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester):
mode = mode_with_gpu mode = mode_with_gpu
def test_batched_dot_correctness(self): def test_batched_dot_correctness(self):
# test both implementations
for threshold in [0, 100]:
batched_dot = GpuBatchedDot(stream_threshold=threshold)
def cmp(a_shp, b_shp): def cmp(a_shp, b_shp):
...@@ -109,8 +112,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester): ...@@ -109,8 +112,9 @@ class TestBatchedDot(unittest_tools.InferShapeTester):
self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2)) self.assertRaises(RuntimeError, fail, (5,4,3), (5,2,2))
def test_batched_dot_gradient(self): def test_batched_dot_gradient(self):
for threshold in [0, 100]:
unittest_tools.verify_grad( unittest_tools.verify_grad(
batched_dot, GpuBatchedDot(stream_threshold=threshold),
[numpy.random.randn(5,7,2).astype(numpy.float32), [numpy.random.randn(5,7,2).astype(numpy.float32),
numpy.random.randn(5,2,6).astype(numpy.float32)], numpy.random.randn(5,2,6).astype(numpy.float32)],
mode=mode_with_gpu) mode=mode_with_gpu)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论