提交 4c62e54f authored 作者: Frederic's avatar Frederic

Define dtype_var in gpuarray c code.

上级 c58d0090
......@@ -138,7 +138,13 @@ class GpuArrayType(Type):
return numpy.dtype(self.dtype).itemsize
def c_declare(self, name, sub):
return "PyGpuArrayObject *%s;" % (name,)
dtype = theano.tensor.TensorType(
dtype=self.dtype,
broadcastable=self.broadcastable).dtype_specs()[1]
return """
PyGpuArrayObject *%(name)s;
typedef %(dtype)s dtype_%(name)s;
""" % locals()
def c_init(self, name, sub):
return "%s = NULL;" % (name,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论