提交 8b24ff5d authored 作者: notoraptor's avatar notoraptor

Wrap Op params for theano.gpuarray.rng_mrg.GPUA_mrg_uniform:

- inplace: boolean scalar - ndim: integer scalar - otypenum: integer scalar - otype_is_float32: boolean scalar - otypecode: integer scalar - context: gpu_context_type
上级 0d925cf3
......@@ -7,16 +7,15 @@ http://www.iro.umontreal.ca/~simardr/ssj/indexe.html
"""
from __future__ import absolute_import, print_function, division
import numpy as np
from theano import Apply, tensor
from theano.gof import local_optimizer
from theano.sandbox.rng_mrg import mrg_uniform_base, mrg_uniform
from theano.tensor import as_tensor_variable, get_vector_length
from theano.scalar import int32 as int_t
from .basic_ops import (GpuKernelBase, Kernel, infer_context_name,
host_from_gpu, as_gpuarray_variable)
from .type import GpuArrayType
from .type import GpuArrayType, gpu_context_type
from .fp16_help import write_w
from .opt import register_opt, register_opt2
......@@ -24,6 +23,9 @@ from .opt import register_opt, register_opt2
class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
# GpuArray version
_f16_ok = True
params_type = mrg_uniform_base.params_type.extended(otypecode=int_t, context=gpu_context_type)
otypecode = property(lambda self: self.output_type.typecode)
def make_node(self, rstate, size):
# error checking slightly redundant here, since
......@@ -39,6 +41,9 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
[rstate, size],
[rstate.type(), output_type])
def get_params(self, node):
    """Build the Params struct for this Op's C code.

    The wrapped params (inplace, ndim, otypenum, otype_is_float32,
    otypecode) are read from ``self`` via the extended ``params_type``;
    the GPU context is taken from the first input (the rstate variable)
    so the kernel runs on the same device as its state.
    """
    rstate_context = node.inputs[0].type.context
    return self.params_type.get_params(self, context=rstate_context)
@classmethod
def new(cls, rstate, ndim, dtype, size):
v_size = as_tensor_variable(size)
......@@ -168,40 +173,34 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
]
def c_code(self, node, nodename, inp, out, sub):
rstate, size = inp
o_rstate, o_sample = out
inplace = int(self.inplace)
ndim = self.output_type.ndim
o_type_num = np.asarray(0, dtype=self.output_type.dtype).dtype.num
fail = sub['fail']
ctx = sub['params']
kname = self.gpu_kernels(node, nodename)[0].objvar
otypecode = str(self.output_type.typecode)
return """
npy_int64 M1 = 2147483647; //2^31 - 1
// The +1 is to avoid odims[0] which fails on windows
size_t odims[%(ndim)s+1];
size_t n_elements = 1;
unsigned int n_streams;
int must_alloc_sample = ((NULL == %(o_sample)s)
|| !pygpu_GpuArray_Check((PyObject*)%(o_sample)s)
|| !(%(o_sample)s->ga.flags & GA_C_CONTIGUOUS)
|| (PyGpuArray_NDIM(%(o_sample)s) != %(ndim)s));
|| (PyGpuArray_NDIM(%(o_sample)s) != %(params)s->ndim));
size_t* odims = (size_t*)malloc(%(params)s->ndim * sizeof(size_t));
if (odims == NULL) {
PyErr_NoMemory();
%(just_fail)s
}
if (PyArray_NDIM(%(size)s) != 1)
{
PyErr_SetString(PyExc_ValueError, "size must be vector");
%(fail)s
}
if (PyArray_DIMS(%(size)s)[0] != %(ndim)s)
if (PyArray_DIMS(%(size)s)[0] != %(params)s->ndim)
{
PyErr_Format(PyExc_ValueError, "size must have length %%i (not %%li)",
%(ndim)s, PyArray_DIMS(%(size)s)[0]);
%(params)s->ndim, PyArray_DIMS(%(size)s)[0]);
%(fail)s
}
for (int i = 0; i < %(ndim)s; ++i)
for (int i = 0; i < %(params)s->ndim; ++i)
{
odims[i] = *(dtype_%(size)s *)PyArray_GETPTR1(%(size)s, i);
n_elements *= odims[i];
......@@ -219,8 +218,8 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
if (must_alloc_sample)
{
Py_XDECREF(%(o_sample)s);
%(o_sample)s = pygpu_empty(%(ndim)s, odims, %(otypecode)s, GA_C_ORDER,
%(ctx)s, Py_None);
%(o_sample)s = pygpu_empty(%(params)s->ndim, odims, %(params)s->otypecode, GA_C_ORDER,
%(params)s->context, Py_None);
if(!%(o_sample)s)
{
%(fail)s;
......@@ -233,7 +232,7 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
}
Py_XDECREF(%(o_rstate)s);
if (%(inplace)s)
if (%(params)s->inplace)
{
Py_INCREF(%(rstate)s);
%(o_rstate)s = %(rstate)s;
......@@ -285,10 +284,22 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
%(fail)s
}
}
""" % locals()
free(odims);
""" % dict(rstate=inp[0], size=inp[1],
o_rstate=out[0], o_sample=out[1],
kname=self.gpu_kernels(node, nodename)[0].objvar,
params=sub['params'],
just_fail=sub['fail'],
fail="""
{
free(odims);
%(fail)s
}
""" % dict(fail=sub['fail']))
def c_code_cache_version(self):
    """Return the version tag for the generated C code cache.

    Bumped 14 -> 15 because this commit changed the generated C code
    (params struct instead of interpolated constants); a stale cached
    module compiled from the old template must not be reused.
    """
    # NOTE(review): the diff residue showed both `return (14,)` and
    # `return (15,)`; only the new value is kept — a dead first return
    # would silently pin the cache to the old version.
    return (15,)
@register_opt2([mrg_uniform], 'fast_compile')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论