提交 d2236cff authored 作者: Tim Cooijmans's avatar Tim Cooijmans

GpuBatchedDot: pass on stream_threshold to related instances

上级 4acaa2cf
......@@ -325,8 +325,8 @@ class GpuBatchedDot(GpuOp):
x, y = inp
gz, = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
xgrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(gz, y.dimshuffle(0, 2, 1))
ygrad = GpuBatchedDot(stream_threshold=self.stream_threshold)(x.dimshuffle(0, 2, 1), gz)
rval = xgrad, ygrad
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论