提交 d6da86d7 作者: Ian Goodfellow

generalized the rest of the helper c code arguments

上级 ca5fbc8c
...@@ -2175,6 +2175,12 @@ class GpuReshape(tensor.Reshape, GpuOp): ...@@ -2175,6 +2175,12 @@ class GpuReshape(tensor.Reshape, GpuOp):
out[0] = x.reshape(tuple(shp)) out[0] = x.reshape(tuple(shp))
# C Code shared by GpuSubtensor and GpuIncSubtensor
_define_set_data = """
#define CudaNdarray_set_device_data2(obj, ptr, base) \
CudaNdarray_set_device_data(obj, (float *)ptr, base)
"""
class GpuSubtensor(GpuOp, tensor.Subtensor): class GpuSubtensor(GpuOp, tensor.Subtensor):
""" """
Implement subtensor on the gpu. Implement subtensor on the gpu.
...@@ -2240,10 +2246,10 @@ class GpuSubtensor(GpuOp, tensor.Subtensor): ...@@ -2240,10 +2246,10 @@ class GpuSubtensor(GpuOp, tensor.Subtensor):
%(fail)s; %(fail)s;
} }
cnda_mark_dev_structure_dirty(xview); cnda_mark_dev_structure_dirty(xview);
#define CudaNdarray_set_device_data2(obj, ptr, base) \ """ % locals()
CudaNdarray_set_device_data(obj, (float *)ptr, base)
""" % locals() get_xview = _define_set_data + \
get_xview = self.helper_c_code(node, name, inputs, outputs, sub, self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list, self.idx_list,
c_prefix='CudaNdarray', c_prefix='CudaNdarray',
set_data='CudaNdarray_set_device_data2', set_data='CudaNdarray_set_device_data2',
...@@ -2251,6 +2257,7 @@ class GpuSubtensor(GpuOp, tensor.Subtensor): ...@@ -2251,6 +2257,7 @@ class GpuSubtensor(GpuOp, tensor.Subtensor):
set_stride='CudaNdarray_set_stride', set_stride='CudaNdarray_set_stride',
update_flags="", strides_mul=4) update_flags="", strides_mul=4)
finish_view = """ finish_view = """
//Set the base only now //Set the base only now
...@@ -2453,9 +2460,16 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2453,9 +2460,16 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
return """CudaNdarray* xview = (CudaNdarray*) return """CudaNdarray* xview = (CudaNdarray*)
CudaNdarray_New(%(view_ndim)s)""" % locals() CudaNdarray_New(%(view_ndim)s)""" % locals()
def get_update_flags(self): def get_helper_c_code_args(self):
""" Return the update_flags string to pass to helper_c_code.""" """ Return a dictionary of arguments to use with helper_c_code"""
return "" return { 'update_flags' : "",
'c_prefix' : 'CudaNdarray',
'set_data' :'CudaNdarray_set_device_data2',
'set_dim' : 'CudaNdarray_set_dim',
'set_stride' : 'CudaNdarray_set_stride',
'update_flags' : "",
'strides_mul': 4
}
def copy_into(self, view, source): def copy_into(self, view, source):
""" """
...@@ -2467,6 +2481,9 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2467,6 +2481,9 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
""" """
return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals() return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
return _define_set_data
def c_code_cache_version(self): def c_code_cache_version(self):
# TODO: cooperate with parent class' C code # TODO: cooperate with parent class' C code
return () return ()
......
...@@ -3964,29 +3964,55 @@ class Subtensor(Op): ...@@ -3964,29 +3964,55 @@ class Subtensor(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices)) return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
@staticmethod @staticmethod
def default_update_flags(): def default_helper_c_code_args():
return ("PyArray_UpdateFlags(xview," """
Returns a dictionary of default arguments to
helper_c_code
"""
return {
"c_prefix" : "PyArray",
"update_flags": ("PyArray_UpdateFlags(xview,"
" NPY_ARRAY_C_CONTIGUOUS|" " NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);") "NPY_ARRAY_F_CONTIGUOUS);"),
"set_data" : "PyArray_set_data",
"set_dim" : "PyArray_set_dim",
"set_stride" : "PyArray_set_stride",
"strides_mul" : "strides_mul" }
@staticmethod @staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list, def helper_c_code(node, name, inputs, outputs, sub, idx_list,
c_prefix="PyArray", c_prefix=None,
update_flags=None, update_flags=None,
set_data='PyArray_set_data', set_data=None,
set_dim='PyArray_set_dim', set_dim=None,
set_stride='PyArray_set_stride', set_stride=None,
strides_mul=1, strides_mul=None,
): ):
"""The parameters c_prefix, update_flags, set_data, set_dim, """
The parameters c_prefix, update_flags, set_data, set_dim,
set_stride and strides_mul are there to allow reusing this set_stride and strides_mul are there to allow reusing this
function on PyArray and CudaNdarray object. function on PyArray and CudaNdarray object.
""" """
default_args = Subtensor.default_helper_c_code_args()
if update_flags is None: if update_flags is None:
update_flags = Subtensor.default_update_flags() update_flags = default_args['update_flags']
if set_data is None:
set_data = default_args['set_data']
if set_dim is None:
set_dim = default_args['set_dim']
if set_stride is None:
set_stride = default_args['set_stride']
if strides_mul is None:
strides_mul = default_args['strides_mul']
# #
# two arrays are created in C code: # two arrays are created in C code:
...@@ -4062,6 +4088,7 @@ class Subtensor(Op): ...@@ -4062,6 +4088,7 @@ class Subtensor(Op):
z, = outputs z, = outputs
rval = """ rval = """
fprintf(stderr, "Enter helper_c_code\\n");
#define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d #define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d
#define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d #define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d
#define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr #define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr
...@@ -4602,14 +4629,15 @@ class IncSubtensor(Op): ...@@ -4602,14 +4629,15 @@ class IncSubtensor(Op):
} }
""" % locals() """ % locals()
# make xview actually a view of %(z)s # make xview actually a view of %(z)s
get_xview = Subtensor.helper_c_code( get_xview = self.define_set_data() + \
Subtensor.helper_c_code(
node=node, node=node,
name=name, name=name,
inputs=outputs[:1] + inputs[2:], inputs=outputs[:1] + inputs[2:],
outputs=outputs, outputs=outputs,
sub=sub, sub=sub,
idx_list=self.idx_list, idx_list=self.idx_list,
update_flags=self.get_update_flags() **self.get_helper_c_code_args()
) )
copy_into = self.copy_into("xview", y) copy_into = self.copy_into("xview", y)
...@@ -4707,9 +4735,9 @@ class IncSubtensor(Op): ...@@ -4707,9 +4735,9 @@ class IncSubtensor(Op):
%(x)s->flags, %(x)s->flags,
NULL)""" % locals() NULL)""" % locals()
def get_update_flags(self): def get_helper_c_code_args(self):
""" Return the update_flags string to pass to helper c_code.""" """ Return a dictionary of arguments to pass to helper_c_code."""
return Subtensor.default_update_flags() return Subtensor.default_helper_c_code_args()
def copy_into(self, view, source): def copy_into(self, view, source):
""" """
...@@ -4721,6 +4749,11 @@ class IncSubtensor(Op): ...@@ -4721,6 +4749,11 @@ class IncSubtensor(Op):
""" """
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals() return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
""" Returns C code used to define any macros used in the
set data argument to the helper C code. """
return ""
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论