提交 d2236cff authored 作者: Tim Cooijmans's avatar Tim Cooijmans

GpuBatchedDot: pass on stream_threshold to related instances

上级 4acaa2cf
...@@ -325,8 +325,8 @@ class GpuBatchedDot(GpuOp): ...@@ -325,8 +325,8 @@ class GpuBatchedDot(GpuOp):
x, y = inp x, y = inp
gz, = grads gz, = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) xgrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) ygrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(x.dimshuffle(0, 2, 1), gz)
rval = xgrad, ygrad rval = xgrad, ygrad
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论