提交 565a6d91 作者: Arnaud Bergeron

Address comments from review

上级 8a8fe4a1
...@@ -108,8 +108,16 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W, ...@@ -108,8 +108,16 @@ int APPLY_SPECIFIC(blockgemv)(PyGpuArrayObject *o, PyGpuArrayObject *W,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode), inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode), 1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0); PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else if (out->ga.typecode == GA_HALF) {
err = blas_ops->sgemvBatch(cb_fortran, transA,
PyGpuArray_DIMS(out)[2],
PyGpuArray_DIMS(h)[2], 1,
W_list, offW, lda,
inp_list, offInp, PyGpuArray_STRIDES(h)[2] / gpuarray_get_elsize(h->ga.typecode),
1, out_list, offOut, PyGpuArray_STRIDES(out)[2] / gpuarray_get_elsize(out->ga.typecode),
PyGpuArray_DIMS(out)[1] * PyGpuArray_DIMS(h)[1] * PyGpuArray_DIMS(out)[0], 0);
} else { } else {
err = GA_DEVSUP_ERROR; err = GA_INVALID_ERROR;
} }
free(W_list); free(W_list);
......
...@@ -97,8 +97,15 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x, ...@@ -97,8 +97,15 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x,
y_list, offY, str_y, x_list, offX, str_x, y_list, offY, str_y, x_list, offX, str_x,
o_list, offOut, str_out, o_list, offOut, str_out,
PyGpuArray_DIMS(x)[0] * PyGpuArray_DIMS(x)[1] * PyGpuArray_DIMS(y)[1], 0); PyGpuArray_DIMS(x)[0] * PyGpuArray_DIMS(x)[1] * PyGpuArray_DIMS(y)[1], 0);
} else if (out->ga.typecode == GA_HALF) {
err = blas_ops->hgerBatch(cb_fortran,
PyGpuArray_DIMS(y)[2], PyGpuArray_DIMS(x)[2],
*(float *)PyArray_GETPTR1(alpha, 0),
y_list, offY, str_y, x_list, offX, str_x,
o_list, offOut, str_out,
PyGpuArray_DIMS(x)[0] * PyGpuArray_DIMS(x)[1] * PyGpuArray_DIMS(y)[1], 0);
} else { } else {
err = GA_DEVSUP_ERROR; err = GA_INVALID_ERROR;
} }
free(o_list); free(o_list);
free(offOut); free(offOut);
...@@ -107,7 +114,7 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x, ...@@ -107,7 +114,7 @@ int APPLY_SPECIFIC(blockger)(PyGpuArrayObject *o, PyGpuArrayObject *x,
free(y_list); free(y_list);
free(offY); free(offY);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "sgerBatch failed"); PyErr_SetString(PyExc_RuntimeError, "gerBatch failed");
return -1; return -1;
} }
*_out = out; *_out = out;
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import numpy import numpy
from theano import Apply, tensor from theano import Apply, tensor
from theano.gof import COp from theano.gof import COp
from theano.tensor import discrete_dtypes from theano.tensor import discrete_dtypes, as_tensor_variable
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
...@@ -54,6 +54,8 @@ class GpuSparseBlockGemv(COp): ...@@ -54,6 +54,8 @@ class GpuSparseBlockGemv(COp):
o = as_gpuarray_variable(o, ctx) o = as_gpuarray_variable(o, ctx)
W = as_gpuarray_variable(W, ctx) W = as_gpuarray_variable(W, ctx)
h = as_gpuarray_variable(h, ctx) h = as_gpuarray_variable(h, ctx)
inputIdx = as_tensor_variable(inputIdx)
outputIdx = as_tensor_variable(outputIdx)
assert o.ndim == 3 assert o.ndim == 3
assert W.ndim == 4 assert W.ndim == 4
assert h.ndim == 3 assert h.ndim == 3
...@@ -123,6 +125,8 @@ class GpuSparseBlockOuter(COp): ...@@ -123,6 +125,8 @@ class GpuSparseBlockOuter(COp):
o = as_gpuarray_variable(o, ctx) o = as_gpuarray_variable(o, ctx)
x = as_gpuarray_variable(x, ctx) x = as_gpuarray_variable(x, ctx)
y = as_gpuarray_variable(y, ctx) y = as_gpuarray_variable(y, ctx)
xIdx = as_tensor_variable(xIdx)
yIdx = as_tensor_variable(yIdx)
if alpha is None: if alpha is None:
alpha = one alpha = one
return Apply(self, [o, x, y, xIdx, yIdx, alpha], return Apply(self, [o, x, y, xIdx, yIdx, alpha],
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论