Unverified 提交 72d83bce 作者: Pascal Lamblin 提交者: GitHub

Merge pull request #6587 from abergeron/cudnn_71

Follow API changes in cudnn 7.1 for RNN params
......@@ -2223,7 +2223,9 @@ class _RNNSplitParams(DnnBase):
void *w;
void *o;
ptrdiff_t off;
#if CUDNN_VERSION < 7100
size_t bshp;
#endif
cudnnStatus_t err;
cudnnDataType_t dt;
cudnnTensorFormat_t tf;
......@@ -2316,13 +2318,21 @@ class _RNNSplitParams(DnnBase):
%(fail)s;
}
// We assume that the typecode matches
assert(dims[2] == 1);
assert(dims[1] == 1);
#if CUDNN_VERSION < 7100
assert(dims[2] == 1 && "bias");
assert(dims[1] == 1 && "bias");
%(b)s = pygpu_view(%(w)s, Py_None);
%(b)s->ga.offset += off;
%(b)s->ga.dimensions[0] = dims[0];
GpuArray_fix_flags(&%(b)s->ga);
bshp = dims[0];
#else
assert(dims[0] == 1 && "bias");
assert(dims[2] == 1 && "bias");
%(b)s = pygpu_view(%(w)s, Py_None);
%(b)s->ga.offset += off;
%(b)s->ga.dimensions[0] = dims[1];
#endif
GpuArray_fix_flags(&%(b)s->ga);
err = cudnnGetRNNLinLayerMatrixParams(%(handle)s, %(desc)s, %(layer)s, xdesc, wdesc, w, %(id)s, odesc, &o);
if (err != CUDNN_STATUS_SUCCESS) {
......@@ -2345,14 +2355,23 @@ class _RNNSplitParams(DnnBase):
%(fail)s;
}
assert(dims[1] == 1);
assert(dims[2] == 1);
#if CUDNN_VERSION < 7100
assert(dims[1] == 1 && "matrix");
assert(dims[2] == 1 && "matrix");
// We assume that the typecode matches
%(m)s = pygpu_reshape(%(w)s, 2, nshp, GA_F_ORDER, 1, -1);
%(m)s->ga.offset += off;
assert(dims[0] %% bshp == 0);
%(m)s->ga.dimensions[0] = dims[0] / bshp;
%(m)s->ga.dimensions[1] = bshp;
#else
assert(dims[0] == 1 && "matrix");
// We assume that the typecode matches
%(m)s = pygpu_reshape(%(w)s, 2, nshp, GA_F_ORDER, 1, -1);
%(m)s->ga.offset += off;
%(m)s->ga.dimensions[1] = dims[1];
%(m)s->ga.dimensions[0] = dims[2];
#endif
%(m)s->ga.strides[1] = %(m)s->ga.dimensions[0] * gpuarray_get_elsize(%(m)s->ga.typecode);
GpuArray_fix_flags(&%(m)s->ga);
""" % kw2
......@@ -2368,7 +2387,7 @@ class _RNNSplitParams(DnnBase):
return code
def c_code_cache_version(self):
return (4, version())
return (5, version())
def _split_rnn_params(w, desc, layer, input_size, dtype, rnn_mode):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论