提交 9841c0db authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Use gemm_batched from python code in the gradient.ù

上级 437b1a5f
...@@ -371,12 +371,29 @@ class SparseBlockOuterSS(GpuOp): ...@@ -371,12 +371,29 @@ class SparseBlockOuterSS(GpuOp):
if not self.inplace: if not self.inplace:
o = o.copy() o = o.copy()
dd = (x.shape[0] * y.shape[0],)
xHostB = numpy.empty(dd, dtype='intp')
yHostB = numpy.empty(dd, dtype='intp')
outHostB = numpy.empty(dd, dtype='intp')
k = 0
for j in range(y.shape[0]): for j in range(y.shape[0]):
out_id = yIdx[j] out_id = yIdx[j]
for i in range(x.shape[0]): for i in range(x.shape[0]):
inp_id = xIdx[i] inp_id = xIdx[i]
ger(numpy.float32(1.0), y[j], outHostB[k] = o[inp_id, out_id].gpudata
x[i], o[inp_id, out_id]) xHostB[k] = x[i].gpudata
yHostB[k] = y[j].gpudata
k += 1
xB = pycuda.gpuarray.to_gpu(xHostB)
yB = pycuda.gpuarray.to_gpu(yHostB)
outB = pycuda.gpuarray.to_gpu(outHostB)
gemm_batched('n', 't', y.shape[1], x.shape[1], 1,
yB, y.strides[0], xB, x.strides[0],
outB, o.strides[2],
beta=numpy.asarray(1.0, dtype='float32'))
out[0] = o out[0] = o
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论