提交 95afdbc3 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

IncSubtensor: generalized copy_into functionality

上级 607f2a3e
...@@ -2457,6 +2457,16 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2457,6 +2457,16 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
""" Return the update_flags string to pass to helper_c_code.""" """ Return the update_flags string to pass to helper_c_code."""
return "" return ""
def copy_into(self, view, source):
"""
view: string, C code expression for an array
source: string, C code expression for an array
returns a C code expression to copy source into view, and
return 0 on success
"""
return """CudaNdarray_CopyFromCudaNdarray(%(view)s, %(source)s)""" % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
# TODO: cooperate with parent class' C code # TODO: cooperate with parent class' C code
return () return ()
......
...@@ -4608,10 +4608,12 @@ class IncSubtensor(Op): ...@@ -4608,10 +4608,12 @@ class IncSubtensor(Op):
update_flags=self.get_update_flags() update_flags=self.get_update_flags()
) )
copy_into = self.copy_into("xview", y)
make_modification = """ make_modification = """
if (%(op_is_set)s) if (%(op_is_set)s)
{ {
if (PyArray_CopyInto(xview, %(y)s)) // does broadcasting if (%(copy_into)s) // does broadcasting
{ {
Py_DECREF(xview); Py_DECREF(xview);
%(fail)s; %(fail)s;
...@@ -4704,6 +4706,16 @@ class IncSubtensor(Op): ...@@ -4704,6 +4706,16 @@ class IncSubtensor(Op):
""" Return the update_flags string to pass to helper c_code.""" """ Return the update_flags string to pass to helper c_code."""
return Subtensor.default_update_flags() return Subtensor.default_update_flags()
def copy_into(self, view, source):
"""
view: string, C code expression for an array
source: string, C code expression for an array
returns a C code expression to copy source into view, and
return 0 on success
"""
return """PyArray_CopyInto(%(view)s, %(source)s)""" % locals()
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论