提交 80130a95 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added op_guard calls to some methods I overlooked

上级 4d4bc1f5
...@@ -923,6 +923,7 @@ class GpuCAReduce(GpuOp): ...@@ -923,6 +923,7 @@ class GpuCAReduce(GpuOp):
#Threads must be organized as: threadNum%nb_reduce correspond to the same sum #Threads must be organized as: threadNum%nb_reduce correspond to the same sum
#nb_reduce<=warpSize #nb_reduce<=warpSize
def _k_reduce_buf_multiple(self, z_pos, nb_reduce): def _k_reduce_buf_multiple(self, z_pos, nb_reduce):
self._op_guard()
return """ return """
__syncthreads(); // some kernel do multiple reduction. __syncthreads(); // some kernel do multiple reduction.
buf[threadNum] = mysum; buf[threadNum] = mysum;
...@@ -947,6 +948,7 @@ class GpuCAReduce(GpuOp): ...@@ -947,6 +948,7 @@ class GpuCAReduce(GpuOp):
is for the case where we are reducing on all axes and x is is for the case where we are reducing on all axes and x is
C contiguous. C contiguous.
""" """
self._op_guard()
print >> sio, """ print >> sio, """
{ {
if(CudaNdarray_SIZE(%(x)s)==0){ if(CudaNdarray_SIZE(%(x)s)==0){
...@@ -986,6 +988,7 @@ class GpuCAReduce(GpuOp): ...@@ -986,6 +988,7 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_1(self, sio, node, name, x, z, fail): def c_code_reduce_1(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -999,6 +1002,7 @@ class GpuCAReduce(GpuOp): ...@@ -999,6 +1002,7 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_11(self, sio, node, name, x, z, fail): def c_code_reduce_11(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -1021,6 +1025,7 @@ class GpuCAReduce(GpuOp): ...@@ -1021,6 +1025,7 @@ class GpuCAReduce(GpuOp):
:param N: the number of 1 in the pattern N=1 -> 01, N=2 -> 011 N=3 ->0111 :param N: the number of 1 in the pattern N=1 -> 01, N=2 -> 011 N=3 ->0111
Work for N=1,2,3 Work for N=1,2,3
""" """
self._op_guard()
assert N in [1, 2, 3] assert N in [1, 2, 3]
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
N_pattern = ''.join(['1'] * N) N_pattern = ''.join(['1'] * N)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论