提交 ffb2d4fa authored 作者: notoraptor's avatar notoraptor

Wrap op param for test op theano.gpuarray.tests.test_cgpukernelbase.GpuEye.

上级 808f855b
......@@ -4,10 +4,12 @@ from six.moves import xrange
import theano
from theano import tensor, config, Apply, Op
from theano.scalar import int32 as int_t
from theano.gof import ParamsType
from theano.gradient import grad_undefined
from ..basic_ops import CGpuKernelBase
from ..type import GpuArrayType, get_context
from ..type import GpuArrayType, get_context, gpu_context_type
# This is an implementation to test that CGpuKernelBase works and also
......@@ -18,6 +20,7 @@ class GpuEye(CGpuKernelBase, Op):
"""
__props__ = ('dtype', 'context_name')
params_type = ParamsType(typecode=int_t, context=gpu_context_type)
def __init__(self, dtype=None, context_name=None):
if dtype is None:
......@@ -28,7 +31,9 @@ class GpuEye(CGpuKernelBase, Op):
'APPLY_SPECIFIC(tstgpueye)')
def get_params(self, node):
return get_context(self.context_name)
from pygpu.gpuarray import dtype_to_typecode
return self.params_type.get_params(typecode=dtype_to_typecode(self.dtype),
context=get_context(self.context_name))
def c_headers(self):
return ['<gpuarray/types.h>', '<gpuarray/kernel.h>']
......@@ -52,11 +57,6 @@ class GpuEye(CGpuKernelBase, Op):
return [grad_undefined(self, i, inp[i])
for i in xrange(2)]
def get_op_params(self):
from pygpu.gpuarray import dtype_to_typecode
return [('TYPECODE', str(dtype_to_typecode(self.dtype)))]
def test_cgpukernelbase():
# Import inside the function to prevent the back-end from being
......@@ -69,4 +69,5 @@ def test_cgpukernelbase():
r = f()
assert r.dtype == 'int32'
assert (np.asarray(r) == np.eye(4, 5, dtype='int32')).all()
......@@ -18,7 +18,7 @@ KERNEL void eye(GLOBAL_MEM DTYPE_OUTPUT_0 *a, ga_size a_off, ga_size n, ga_size
#section support_code_struct
int APPLY_SPECIFIC(tstgpueye)(PyArrayObject *n, PyArrayObject *m,
PyGpuArrayObject **z, PyGpuContextObject *ctx) {
PyGpuArrayObject **z, PARAMS_TYPE* params) {
size_t dims[2] = {0, 0};
size_t ls, gs;
void *args[3];
......@@ -29,9 +29,9 @@ int APPLY_SPECIFIC(tstgpueye)(PyArrayObject *n, PyArrayObject *m,
Py_XDECREF(*z);
*z = pygpu_zeros(2, dims,
TYPECODE,
params->typecode,
GA_C_ORDER,
ctx, Py_None);
params->context, Py_None);
if (*z == NULL)
return -1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论