提交 8a41c6fb 作者: Thomas George

GpuAllocEmpty now uses params

上级 073b822d
...@@ -8,7 +8,8 @@ import theano ...@@ -8,7 +8,8 @@ import theano
from theano import Op, Apply, Type, Variable from theano import Op, Apply, Type, Variable
from theano import tensor, config from theano import tensor, config
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.scalar import bool as bool_t from theano.scalar import (bool as bool_t,
int32 as int32_t)
from theano.tensor.basic import ( from theano.tensor.basic import (
Alloc, AllocEmpty, alloc_validate_shape, Join, Split) Alloc, AllocEmpty, alloc_validate_shape, Join, Split)
...@@ -961,16 +962,19 @@ class GpuAllocEmpty(HideC, AllocEmpty): ...@@ -961,16 +962,19 @@ class GpuAllocEmpty(HideC, AllocEmpty):
Allocate uninitialized memory on the GPU. Allocate uninitialized memory on the GPU.
""" """
__props__ = ('dtype', 'context_name') __props__ = ('dtype', 'context_name', 'typecode')
_f16_ok = True _f16_ok = True
params_type = gpu_context_type params_type = ParamsType(context=gpu_context_type,
typecode=int32_t)
def __init__(self, dtype, context_name): def __init__(self, dtype, context_name):
self.dtype = dtype self.dtype = dtype
self.context_name = context_name self.context_name = context_name
self.typecode = gpuarray.dtype_to_typecode(self.dtype)
def get_params(self, node): def get_params(self, node):
return get_context(self.context_name) return self.params_type.get_params(context=get_context(self.context_name),
typecode=self.typecode)
def make_node(self, *shape): def make_node(self, *shape):
sh, bcast = alloc_validate_shape(shape) sh, bcast = alloc_validate_shape(shape)
...@@ -1015,17 +1019,16 @@ shape[%(i)s] = ((dtype_%(shp_i)s *)PyArray_DATA(%(shp_i)s))[0]; ...@@ -1015,17 +1019,16 @@ shape[%(i)s] = ((dtype_%(shp_i)s *)PyArray_DATA(%(shp_i)s))[0];
""" % dict(i=i, shp_i=shp_i)) """ % dict(i=i, shp_i=shp_i))
code.append(""" code.append("""
if (theano_prep_output(&%(zz)s, %(ndim)s, shape, %(type)s, GA_C_ORDER, if (theano_prep_output(&%(zz)s, %(ndim)s, shape, %(params)s->typecode, GA_C_ORDER,
%(ctx)s)) { %(params)s->context)) {
%(fail)s %(fail)s
} }
""" % dict(zz=zz, ndim=ndim, type=gpuarray.dtype_to_typecode(self.dtype), """ % dict(zz=zz, ndim=ndim, fail=fail, params=sub['params']))
fail=fail, ctx=sub['params']))
return ''.join(code) return ''.join(code)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def do_constant_folding(self, node): def do_constant_folding(self, node):
return False return False
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论