提交 607f2a3e authored 作者: Ian Goodfellow's avatar Ian Goodfellow

IncSubtensor: generalized selection of update flags

上级 b3c674e1
...@@ -2431,12 +2431,32 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp): ...@@ -2431,12 +2431,32 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
raise NotImplementedError() raise NotImplementedError()
def copy_of_x(self, x): def copy_of_x(self, x):
"""
x: a string giving the name of a C variable pointing to an array
Returns C code expression to make a copy of x.
Base class uses PyArrayObject *, subclasses may override for
different types of arrays.
"""
return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals() return """(CudaNdarray*) CudaNdarray_Copy(%(x)s)""" % locals()
def make_view_buffer(self, x, view_ndim): def make_view_buffer(self, x, view_ndim):
"""
x: a string identifying an array to be viewed
view_ndim: a string specifying the number of dimensions
to have in the view
This doesn't need to actually set up the view with the
right indexing; we'll do that manually later.
"""
return """CudaNdarray* xview = (CudaNdarray*) return """CudaNdarray* xview = (CudaNdarray*)
CudaNdarray_New(%(view_ndim)s)""" % locals() CudaNdarray_New(%(view_ndim)s)""" % locals()
def get_update_flags(self):
""" Return the update_flags string to pass to helper_c_code."""
return ""
def c_code_cache_version(self): def c_code_cache_version(self):
# TODO: cooperate with parent class' C code # TODO: cooperate with parent class' C code
return () return ()
......
...@@ -4598,9 +4598,15 @@ class IncSubtensor(Op): ...@@ -4598,9 +4598,15 @@ class IncSubtensor(Op):
} }
""" % locals() """ % locals()
# make xview actually a view of %(z)s # make xview actually a view of %(z)s
get_xview = Subtensor.helper_c_code(node, name, get_xview = Subtensor.helper_c_code(
outputs[:1] + inputs[2:], node=node,
outputs, sub, self.idx_list) name=name,
inputs=outputs[:1] + inputs[2:],
outputs=outputs,
sub=sub,
idx_list=self.idx_list,
update_flags=self.get_update_flags()
)
make_modification = """ make_modification = """
if (%(op_is_set)s) if (%(op_is_set)s)
...@@ -4694,6 +4700,10 @@ class IncSubtensor(Op): ...@@ -4694,6 +4700,10 @@ class IncSubtensor(Op):
%(x)s->flags, %(x)s->flags,
NULL)""" % locals() NULL)""" % locals()
def get_update_flags(self):
""" Return the update_flags string to pass to helper c_code."""
return Subtensor.default_update_flags()
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0]] return [shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论