提交 8b24ff5d authored 作者: notoraptor's avatar notoraptor

Wrap Op params for theano.gpuarray.rng_mrg.GPUA_mrg_uniform:

- inplace: boolean scalar - ndim: integer scalr - otypenum: integer scalar - otype_is_float32: boolean scalar - otypecode: integer scalar - context: gpu_context_type
上级 0d925cf3
...@@ -7,16 +7,15 @@ http://www.iro.umontreal.ca/~simardr/ssj/indexe.html ...@@ -7,16 +7,15 @@ http://www.iro.umontreal.ca/~simardr/ssj/indexe.html
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy as np
from theano import Apply, tensor from theano import Apply, tensor
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.sandbox.rng_mrg import mrg_uniform_base, mrg_uniform from theano.sandbox.rng_mrg import mrg_uniform_base, mrg_uniform
from theano.tensor import as_tensor_variable, get_vector_length from theano.tensor import as_tensor_variable, get_vector_length
from theano.scalar import int32 as int_t
from .basic_ops import (GpuKernelBase, Kernel, infer_context_name, from .basic_ops import (GpuKernelBase, Kernel, infer_context_name,
host_from_gpu, as_gpuarray_variable) host_from_gpu, as_gpuarray_variable)
from .type import GpuArrayType from .type import GpuArrayType, gpu_context_type
from .fp16_help import write_w from .fp16_help import write_w
from .opt import register_opt, register_opt2 from .opt import register_opt, register_opt2
...@@ -24,6 +23,9 @@ from .opt import register_opt, register_opt2 ...@@ -24,6 +23,9 @@ from .opt import register_opt, register_opt2
class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
# GpuArray version # GpuArray version
_f16_ok = True _f16_ok = True
params_type = mrg_uniform_base.params_type.extended(otypecode=int_t, context=gpu_context_type)
otypecode = property(lambda self: self.output_type.typecode)
def make_node(self, rstate, size): def make_node(self, rstate, size):
# error checking slightly redundant here, since # error checking slightly redundant here, since
...@@ -39,6 +41,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -39,6 +41,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
[rstate, size], [rstate, size],
[rstate.type(), output_type]) [rstate.type(), output_type])
def get_params(self, node):
return self.params_type.get_params(self, context=node.inputs[0].type.context)
@classmethod @classmethod
def new(cls, rstate, ndim, dtype, size): def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size) v_size = as_tensor_variable(size)
...@@ -168,40 +173,34 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -168,40 +173,34 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
] ]
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
rstate, size = inp
o_rstate, o_sample = out
inplace = int(self.inplace)
ndim = self.output_type.ndim
o_type_num = np.asarray(0, dtype=self.output_type.dtype).dtype.num
fail = sub['fail']
ctx = sub['params']
kname = self.gpu_kernels(node, nodename)[0].objvar
otypecode = str(self.output_type.typecode)
return """ return """
npy_int64 M1 = 2147483647; //2^31 - 1 npy_int64 M1 = 2147483647; //2^31 - 1
// The +1 is to avoid odims[0] which fails on windows
size_t odims[%(ndim)s+1];
size_t n_elements = 1; size_t n_elements = 1;
unsigned int n_streams; unsigned int n_streams;
int must_alloc_sample = ((NULL == %(o_sample)s) int must_alloc_sample = ((NULL == %(o_sample)s)
|| !pygpu_GpuArray_Check((PyObject*)%(o_sample)s) || !pygpu_GpuArray_Check((PyObject*)%(o_sample)s)
|| !(%(o_sample)s->ga.flags & GA_C_CONTIGUOUS) || !(%(o_sample)s->ga.flags & GA_C_CONTIGUOUS)
|| (PyGpuArray_NDIM(%(o_sample)s) != %(ndim)s)); || (PyGpuArray_NDIM(%(o_sample)s) != %(params)s->ndim));
size_t* odims = (size_t*)malloc(%(params)s->ndim * sizeof(size_t));
if (odims == NULL) {
PyErr_NoMemory();
%(just_fail)s
}
if (PyArray_NDIM(%(size)s) != 1) if (PyArray_NDIM(%(size)s) != 1)
{ {
PyErr_SetString(PyExc_ValueError, "size must be vector"); PyErr_SetString(PyExc_ValueError, "size must be vector");
%(fail)s %(fail)s
} }
if (PyArray_DIMS(%(size)s)[0] != %(ndim)s) if (PyArray_DIMS(%(size)s)[0] != %(params)s->ndim)
{ {
PyErr_Format(PyExc_ValueError, "size must have length %%i (not %%li)", PyErr_Format(PyExc_ValueError, "size must have length %%i (not %%li)",
%(ndim)s, PyArray_DIMS(%(size)s)[0]); %(params)s->ndim, PyArray_DIMS(%(size)s)[0]);
%(fail)s %(fail)s
} }
for (int i = 0; i < %(ndim)s; ++i) for (int i = 0; i < %(params)s->ndim; ++i)
{ {
odims[i] = *(dtype_%(size)s *)PyArray_GETPTR1(%(size)s, i); odims[i] = *(dtype_%(size)s *)PyArray_GETPTR1(%(size)s, i);
n_elements *= odims[i]; n_elements *= odims[i];
...@@ -219,8 +218,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -219,8 +218,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
if (must_alloc_sample) if (must_alloc_sample)
{ {
Py_XDECREF(%(o_sample)s); Py_XDECREF(%(o_sample)s);
%(o_sample)s = pygpu_empty(%(ndim)s, odims, %(otypecode)s, GA_C_ORDER, %(o_sample)s = pygpu_empty(%(params)s->ndim, odims, %(params)s->otypecode, GA_C_ORDER,
%(ctx)s, Py_None); %(params)s->context, Py_None);
if(!%(o_sample)s) if(!%(o_sample)s)
{ {
%(fail)s; %(fail)s;
...@@ -233,7 +232,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -233,7 +232,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
} }
Py_XDECREF(%(o_rstate)s); Py_XDECREF(%(o_rstate)s);
if (%(inplace)s) if (%(params)s->inplace)
{ {
Py_INCREF(%(rstate)s); Py_INCREF(%(rstate)s);
%(o_rstate)s = %(rstate)s; %(o_rstate)s = %(rstate)s;
...@@ -285,10 +284,22 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -285,10 +284,22 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
%(fail)s %(fail)s
} }
} }
""" % locals()
free(odims);
""" % dict(rstate=inp[0], size=inp[1],
o_rstate=out[0], o_sample=out[1],
kname=self.gpu_kernels(node, nodename)[0].objvar,
params=sub['params'],
just_fail=sub['fail'],
fail="""
{
free(odims);
%(fail)s
}
""" % dict(fail=sub['fail']))
def c_code_cache_version(self): def c_code_cache_version(self):
return (14,) return (15,)
@register_opt2([mrg_uniform], 'fast_compile') @register_opt2([mrg_uniform], 'fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论