提交 f85f469d authored 作者: Frederic's avatar Frederic

GpuIncSubtensor continued, not finished.

上级 3143c789
...@@ -168,17 +168,17 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -168,17 +168,17 @@ class GpuIncSubtensor(HideC, IncSubtensor):
""" """
def make_node(self, x, y, *inputs): def make_node(self, x, y, *inputs):
x = as_cuda_ndarray_variable(x) x = as_gpuarray_variable(x)
y = as_cuda_ndarray_variable(y) y = as_gpuarray_variable(y)
rval = tensor.IncSubtensor.make_node(self, x, y, *inputs) rval = tensor.IncSubtensor.make_node(self, x, y, *inputs)
return Apply(self, [x, y] + rval.inputs[2:], [x.type()]) return gof.Apply(self, [x, y] + rval.inputs[2:], [x.type()])
def do_type_checking(self, node): def do_type_checking(self, node):
""" Should raise NotImplementedError if c_code does not support """ Should raise NotImplementedError if c_code does not support
the types involved in this node. the types involved in this node.
""" """
if not isinstance(node.inputs[0].type, CudaNdarrayType): if not isinstance(node.inputs[0].type, GpuArrayType):
raise NotImplementedError() raise NotImplementedError()
def copy_of_x(self, x): def copy_of_x(self, x):
...@@ -191,13 +191,13 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -191,13 +191,13 @@ class GpuIncSubtensor(HideC, IncSubtensor):
Base class uses `PyArrayObject *`, subclasses may override for Base class uses `PyArrayObject *`, subclasses may override for
different types of arrays. different types of arrays.
""" """
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals() return """pygpu_copy(%(x)s, GA_ANY_ORDER)""" % locals()
def decl_view(self): def decl_view(self):
return "CudaNdarray* zview = NULL;" return "PyGpuArray* zview = NULL;"
def make_view_array(self, x, view_ndim): def make_view_array(self, x, view_ndim):
""" """//TODO
:param x: a string identifying an array to be viewed :param x: a string identifying an array to be viewed
:param view_ndim: a string specifying the number of dimensions :param view_ndim: a string specifying the number of dimensions
to have in the view to have in the view
...@@ -230,8 +230,8 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -230,8 +230,8 @@ class GpuIncSubtensor(HideC, IncSubtensor):
def get_helper_c_code_args(self): def get_helper_c_code_args(self):
""" Return a dictionary of arguments to use with helper_c_code""" """ Return a dictionary of arguments to use with helper_c_code"""
return {'c_prefix': 'CudaNdarray', return {'c_prefix': 'PyGpuArray',
'strides_mul': 4 'strides_mul': 1
} }
def copy_into(self, view, source): def copy_into(self, view, source):
...@@ -242,10 +242,10 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -242,10 +242,10 @@ class GpuIncSubtensor(HideC, IncSubtensor):
returns a C code expression to copy source into view, and returns a C code expression to copy source into view, and
return 0 on success return 0 on success
""" """
return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals() return """GpuArray_move(%(view)s, %(source)s)""" % locals()
def set_view_base(self, x, fail): def set_view_base(self, x, fail):
return """ return """//TODO
//Set the base only now //Set the base only now
if(CudaNdarray_set_device_data(zview, CudaNdarray_DEV_DATA(zview), if(CudaNdarray_set_device_data(zview, CudaNdarray_DEV_DATA(zview),
...@@ -258,7 +258,7 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -258,7 +258,7 @@ class GpuIncSubtensor(HideC, IncSubtensor):
}""" % locals() }""" % locals()
def add_to_zview(self, x, fail): def add_to_zview(self, x, fail):
#TODO
return """ return """
PyObject * add_result = CudaNdarray_inplace_add((PyObject *) zview, PyObject * add_result = CudaNdarray_inplace_add((PyObject *) zview,
(PyObject *) py_%(x)s); (PyObject *) py_%(x)s);
...@@ -275,6 +275,7 @@ class GpuIncSubtensor(HideC, IncSubtensor): ...@@ -275,6 +275,7 @@ class GpuIncSubtensor(HideC, IncSubtensor):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return ()
parent_version = super(GpuIncSubtensor, self).c_code_cache_version() parent_version = super(GpuIncSubtensor, self).c_code_cache_version()
if parent_version: if parent_version:
return parent_version + (0,) return parent_version + (0,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论