提交 3774bc48 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

general version of _assign_reduce

上级 a956d035
...@@ -844,9 +844,7 @@ class GpuCAReduce(GpuOp): ...@@ -844,9 +844,7 @@ class GpuCAReduce(GpuOp):
returns C code to reduce left and right, assigning the returns C code to reduce left and right, assigning the
result to left.""" result to left."""
self._op_guard() return self.scalar_op.cuda_assign_reduce(left, right)
return left + " += " + right + ";"
def _k_reduce_buf(self, z_pos): def _k_reduce_buf(self, z_pos):
""" """
......
...@@ -816,6 +816,13 @@ class ScalarOp(Op): ...@@ -816,6 +816,13 @@ class ScalarOp(Op):
else: else:
return self.__class__.__name__ return self.__class__.__name__
def cuda_assign_reduce(self, left, right):
""" Returns CUDA code assigning the reduction of left and right
using this scalar operation to left."""
raise NotImplementedError(
str(self)+" does not implement cuda_assign_reduce")
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (4,)
...@@ -1159,6 +1166,9 @@ class Maximum(BinaryScalarOp): ...@@ -1159,6 +1166,9 @@ class Maximum(BinaryScalarOp):
gx = eq(output, x) * gz gx = eq(output, x) * gz
gy = eq(output, y) * gz gy = eq(output, y) * gz
return (gx, gy) return (gx, gy)
def cuda_assign_reduce(self, left, right):
return left + ' = max(' + left + ', '+ right + ');'
maximum = Maximum(upcast_out, name='maximum') maximum = Maximum(upcast_out, name='maximum')
...@@ -1187,7 +1197,6 @@ class Minimum(BinaryScalarOp): ...@@ -1187,7 +1197,6 @@ class Minimum(BinaryScalarOp):
gx = eq(output, x) * gz gx = eq(output, x) * gz
gy = eq(output, y) * gz gy = eq(output, y) * gz
return (gx, gy) return (gx, gy)
minimum = Minimum(upcast_out, name='minimum') minimum = Minimum(upcast_out, name='minimum')
...@@ -1222,6 +1231,10 @@ class Add(ScalarOp): ...@@ -1222,6 +1231,10 @@ class Add(ScalarOp):
for i in inputs: for i in inputs:
retval += [gz] retval += [gz]
return retval return retval
def cuda_assign_reduce(self, left, right):
return left + ' += ' + right + ';'
add = Add(upcast_out, name='add') add = Add(upcast_out, name='add')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论