提交 ab154880 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix flags for split matrices.

上级 2ec9998a
...@@ -2144,12 +2144,11 @@ class _RNNSplitParams(DnnBase): ...@@ -2144,12 +2144,11 @@ class _RNNSplitParams(DnnBase):
assert(dims[1] == 1); assert(dims[1] == 1);
assert(dims[2] == 1); assert(dims[2] == 1);
// We assume that the typecode matches // We assume that the typecode matches
%(m)s = pygpu_reshape(%(w)s, 2, nshp, GA_C_ORDER, 1, -1); %(m)s = pygpu_reshape(%(w)s, 2, nshp, GA_F_ORDER, 1, -1);
%(m)s->ga.offset = off; %(m)s->ga.offset = off;
assert(dims[0] %% bshp == 0); assert(dims[0] %% bshp == 0);
%(m)s->ga.dimensions[0] = dims[0] / bshp; %(m)s->ga.dimensions[0] = dims[0] / bshp;
%(m)s->ga.dimensions[1] = bshp; %(m)s->ga.dimensions[1] = bshp;
%(m)s->ga.strides[0] = %(m)s->ga.strides[1];
%(m)s->ga.strides[1] = %(m)s->ga.dimensions[0] * gpuarray_get_elsize(%(m)s->ga.typecode); %(m)s->ga.strides[1] = %(m)s->ga.dimensions[0] * gpuarray_get_elsize(%(m)s->ga.typecode);
""" % kw2 """ % kw2
...@@ -2164,7 +2163,7 @@ class _RNNSplitParams(DnnBase): ...@@ -2164,7 +2163,7 @@ class _RNNSplitParams(DnnBase):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def _split_rnn_params(w, desc, layer, input_size, dtype, rnn_mode): def _split_rnn_params(w, desc, layer, input_size, dtype, rnn_mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论