提交 d6da86d7 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

generalized the rest of the helper c code arguments

上级 ca5fbc8c
......@@ -2175,6 +2175,12 @@ class GpuReshape(tensor.Reshape, GpuOp):
out[0] = x.reshape(tuple(shp))
# C Code shared by GpuSubtensor and GpuIncSubtensor
_define_set_data = """
#define CudaNdarray_set_device_data2(obj, ptr, base) \
CudaNdarray_set_device_data(obj, (float *)ptr, base)
"""
class GpuSubtensor(GpuOp, tensor.Subtensor):
"""
Implement subtensor on the gpu.
......@@ -2240,10 +2246,10 @@ class GpuSubtensor(GpuOp, tensor.Subtensor):
%(fail)s;
}
cnda_mark_dev_structure_dirty(xview);
#define CudaNdarray_set_device_data2(obj, ptr, base) \
CudaNdarray_set_device_data(obj, (float *)ptr, base)
""" % locals()
get_xview = self.helper_c_code(node, name, inputs, outputs, sub,
""" % locals()
get_xview = _define_set_data + \
self.helper_c_code(node, name, inputs, outputs, sub,
self.idx_list,
c_prefix='CudaNdarray',
set_data='CudaNdarray_set_device_data2',
......@@ -2251,6 +2257,7 @@ class GpuSubtensor(GpuOp, tensor.Subtensor):
set_stride='CudaNdarray_set_stride',
update_flags="", strides_mul=4)
finish_view = """
//Set the base only now
......@@ -2453,9 +2460,16 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
return """CudaNdarray* xview = (CudaNdarray*)
CudaNdarray_New(%(view_ndim)s)""" % locals()
def get_update_flags(self):
""" Return the update_flags string to pass to helper_c_code."""
return ""
def get_helper_c_code_args(self):
""" Return a dictionary of arguments to use with helper_c_code"""
return { 'update_flags' : "",
'c_prefix' : 'CudaNdarray',
'set_data' :'CudaNdarray_set_device_data2',
'set_dim' : 'CudaNdarray_set_dim',
'set_stride' : 'CudaNdarray_set_stride',
'update_flags' : "",
'strides_mul': 4
}
def copy_into(self, view, source):
"""
......@@ -2467,6 +2481,9 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
"""
return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
return _define_set_data
def c_code_cache_version(self):
# TODO: cooperate with parent class' C code
return ()
......
......@@ -3964,29 +3964,55 @@ class Subtensor(Op):
return "%s{%s}" % (self.__class__.__name__, ", ".join(indices))
@staticmethod
def default_update_flags():
return ("PyArray_UpdateFlags(xview,"
def default_helper_c_code_args():
"""
Returns a dictionary of default arguments to
helper_c_code
"""
return {
"c_prefix" : "PyArray",
"update_flags": ("PyArray_UpdateFlags(xview,"
" NPY_ARRAY_C_CONTIGUOUS|"
"NPY_ARRAY_F_CONTIGUOUS);")
"NPY_ARRAY_F_CONTIGUOUS);"),
"set_data" : "PyArray_set_data",
"set_dim" : "PyArray_set_dim",
"set_stride" : "PyArray_set_stride",
"strides_mul" : "strides_mul" }
@staticmethod
def helper_c_code(node, name, inputs, outputs, sub, idx_list,
c_prefix="PyArray",
c_prefix=None,
update_flags=None,
set_data='PyArray_set_data',
set_dim='PyArray_set_dim',
set_stride='PyArray_set_stride',
strides_mul=1,
set_data=None,
set_dim=None,
set_stride=None,
strides_mul=None,
):
"""The parameters c_prefix, update_flags, set_data, set_dim,
"""
The parameters c_prefix, update_flags, set_data, set_dim,
set_stride and strides_mul are there to allow reusing this
function on PyArray and CudaNdarray object.
"""
default_args = Subtensor.default_helper_c_code_args()
if update_flags is None:
update_flags = Subtensor.default_update_flags()
update_flags = default_args['update_flags']
if set_data is None:
set_data = default_args['set_data']
if set_dim is None:
set_dim = default_args['set_dim']
if set_stride is None:
set_stride = default_args['set_stride']
if strides_mul is None:
strides_mul = default_args['strides_mul']
#
# two arrays are created in C code:
......@@ -4062,6 +4088,7 @@ class Subtensor(Op):
z, = outputs
rval = """
fprintf(stderr, "Enter helper_c_code\\n");
#define PyArray_set_dim(obj, idx, d) PyArray_DIMS(obj)[idx]=d
#define PyArray_set_stride(obj, idx, d) PyArray_STRIDES(obj)[idx]=d
#define PyArray_set_data(obj, ptr, base) PyArray_BYTES(obj)=ptr
......@@ -4602,14 +4629,15 @@ class IncSubtensor(Op):
}
""" % locals()
# make xview actually a view of %(z)s
get_xview = Subtensor.helper_c_code(
get_xview = self.define_set_data() + \
Subtensor.helper_c_code(
node=node,
name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
update_flags=self.get_update_flags()
**self.get_helper_c_code_args()
)
copy_into = self.copy_into("xview", y)
......@@ -4707,9 +4735,9 @@ class IncSubtensor(Op):
%(x)s->flags,
NULL)""" % locals()
def get_update_flags(self):
""" Return the update_flags string to pass to helper c_code."""
return Subtensor.default_update_flags()
def get_helper_c_code_args(self):
""" Return a dictionary of arguments to pass to helper_c_code."""
return Subtensor.default_helper_c_code_args()
def copy_into(self, view, source):
"""
......@@ -4721,6 +4749,11 @@ class IncSubtensor(Op):
"""
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def define_set_data(self):
""" Returns C code used to define any macros used in the
set data argument to the helper C code. """
return ""
def infer_shape(self, node, shapes):
return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论