提交 6b197857 authored 作者: Thomas George's avatar Thomas George

AllocEmpty now uses param

上级 8a41c6fb
......@@ -970,7 +970,10 @@ class GpuAllocEmpty(HideC, AllocEmpty):
def __init__(self, dtype, context_name):
self.dtype = dtype
self.context_name = context_name
self.typecode = gpuarray.dtype_to_typecode(self.dtype)
@property
def typecode(self):
return gpuarray.dtype_to_typecode(self.dtype)
def get_params(self, node):
return self.params_type.get_params(context=get_context(self.context_name),
......
......@@ -17,6 +17,7 @@ from theano import gof
from theano.gof import Apply, Constant, Op, Variable, ParamsType
from theano.gof.type import Generic
from theano.scalar import int32 as int32_t
from theano.tensor import elemwise
from theano.tensor.var import (AsTensorError, TensorVariable,
TensorConstant, TensorConstantSignature,
......@@ -6632,13 +6633,18 @@ class Choose(Op):
class AllocEmpty(gof.Op):
"""Implement Alloc on the cpu, but without initializing memory."""
__props__ = ("dtype",)
__props__ = ("dtype", "typecode")
params_type = ParamsType(typecode=int32_t)
# specify the type of the data
def __init__(self, dtype):
assert isinstance(dtype, str), dtype
self.dtype = dtype.lower()
@property
def typecode(self):
return np.dtype(self.dtype).num
def make_node(self, *shape):
shape, bcast = alloc_validate_shape(shape)
otype = TensorType(dtype=self.dtype, broadcastable=bcast)
......@@ -6668,11 +6674,11 @@ class AllocEmpty(gof.Op):
out[0] = np.empty(sh, dtype=self.dtype)
def c_code(self, node, name, inputs, out_, sub):
dtype = "NPY_" + self.dtype.upper()
out, = out_
fail = sub['fail']
shps = inputs
nd = len(shps)
params = sub['params']
str = "npy_intp dims[%(nd)s];\n" % locals()
for idx, sh in enumerate(shps):
str += "dims[%(idx)s] =" \
......@@ -6691,7 +6697,7 @@ class AllocEmpty(gof.Op):
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s,
dims,
%(dtype)s,
%(params)s->typecode,
0);
if (!%(out)s)
{
......@@ -6706,7 +6712,7 @@ class AllocEmpty(gof.Op):
return [node.inputs]
def c_code_cache_version(self):
return (3,)
return (4,)
def do_constant_folding(self, node):
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论