提交 88638cb5 authored 作者: Frederic's avatar Frederic

Add direct GPU min/max reduce for patterns 1, 11, 111, 1111, and by reshape for 11111*.

Also add a test that the other GPU min/max reductions are done by reshape.
上级 1780d27f
...@@ -1052,7 +1052,6 @@ class GpuCAReduce(GpuOp): ...@@ -1052,7 +1052,6 @@ class GpuCAReduce(GpuOp):
is for the case where we are reducing on all axes and x is is for the case where we are reducing on all axes and x is
C contiguous. C contiguous.
""" """
self._op_guard()
print >> sio, """ print >> sio, """
{ {
if(CudaNdarray_SIZE(%(x)s)==0){ if(CudaNdarray_SIZE(%(x)s)==0){
...@@ -1092,7 +1091,6 @@ class GpuCAReduce(GpuOp): ...@@ -1092,7 +1091,6 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_1(self, sio, node, name, x, z, fail): def c_code_reduce_1(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -1106,7 +1104,6 @@ class GpuCAReduce(GpuOp): ...@@ -1106,7 +1104,6 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_11(self, sio, node, name, x, z, fail): def c_code_reduce_11(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -1450,7 +1447,6 @@ class GpuCAReduce(GpuOp): ...@@ -1450,7 +1447,6 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_111(self, sio, node, name, x, z, fail): def c_code_reduce_111(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -1514,7 +1510,6 @@ class GpuCAReduce(GpuOp): ...@@ -1514,7 +1510,6 @@ class GpuCAReduce(GpuOp):
""" % locals() """ % locals()
def c_code_reduce_1111(self, sio, node, name, x, z, fail): def c_code_reduce_1111(self, sio, node, name, x, z, fail):
self._op_guard()
makecall = self._makecall(node, name, x, z, fail) makecall = self._makecall(node, name, x, z, fail)
print >> sio, """ print >> sio, """
{ {
...@@ -1595,10 +1590,20 @@ class GpuCAReduce(GpuOp): ...@@ -1595,10 +1590,20 @@ class GpuCAReduce(GpuOp):
sio = StringIO() sio = StringIO()
nd_in = len(self.reduce_mask) nd_in = len(self.reduce_mask)
if all(i == 1 for i in self.reduce_mask): if all(i == 1 for i in self.reduce_mask):
self._op_guard() if not isinstance(self.scalar_op, (scal.Add,
scal.Maximum,
scal.Minimum)):
raise NotImplementedError()
#this kernel is ok for up to a few thousand elements, but #this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor # it only runs on ONE multiprocessor
reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {}) reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {})
reduce_fct = self._assign_reduce(node, nodename, "myresult",
"A[i0]",
{})
if isinstance(self.scalar_op, scal.Add):
reduce_init = "0.f;"
else:
reduce_init = "A[0]"
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_ccontig_%(nodename)s( static __global__ void kernel_reduce_ccontig_%(nodename)s(
const unsigned int d0, const unsigned int d0,
...@@ -1608,7 +1613,7 @@ class GpuCAReduce(GpuOp): ...@@ -1608,7 +1613,7 @@ class GpuCAReduce(GpuOp):
const int threadCount = blockDim.x; const int threadCount = blockDim.x;
const int threadNum = threadIdx.x; const int threadNum = threadIdx.x;
extern __shared__ float buf[]; extern __shared__ float buf[];
float myresult = 0.0f; float myresult = %(reduce_init)s;
if (warpSize != 32) if (warpSize != 32)
{ {
...@@ -1617,16 +1622,26 @@ class GpuCAReduce(GpuOp): ...@@ -1617,16 +1622,26 @@ class GpuCAReduce(GpuOp):
for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x) for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x)
{ {
myresult += A[i0]; %(reduce_fct)s
} }
%(reducebuf)s %(reducebuf)s
} }
""" % locals() """ % locals()
if self.reduce_mask == (1,): if self.reduce_mask == (1,):
self._op_guard() if not isinstance(self.scalar_op, (scal.Add,
scal.Maximum,
scal.Minimum)):
raise NotImplementedError()
#this kernel is ok for up to a few thousand elements, but #this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor # it only runs on ONE multiprocessor
reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {}) reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {})
reduce_fct = self._assign_reduce(node, nodename, "myresult",
"A[i0 * sA0]",
{})
if isinstance(self.scalar_op, scal.Add):
reduce_init = "0.f;"
else:
reduce_init = "A[0]"
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_1_%(nodename)s( static __global__ void kernel_reduce_1_%(nodename)s(
const unsigned int d0, const unsigned int d0,
...@@ -1636,7 +1651,7 @@ class GpuCAReduce(GpuOp): ...@@ -1636,7 +1651,7 @@ class GpuCAReduce(GpuOp):
const int threadCount = blockDim.x; const int threadCount = blockDim.x;
const int threadNum = threadIdx.x; const int threadNum = threadIdx.x;
extern __shared__ float buf[]; extern __shared__ float buf[];
float myresult = 0.0f; float myresult = %(reduce_init)s;
if (warpSize != 32) if (warpSize != 32)
{ {
...@@ -1645,17 +1660,26 @@ class GpuCAReduce(GpuOp): ...@@ -1645,17 +1660,26 @@ class GpuCAReduce(GpuOp):
for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x) for (int i0 = threadIdx.x; i0 < d0; i0 += blockDim.x)
{ {
float Ai = A[i0 * sA0]; %(reduce_fct)s
myresult += Ai;
} }
%(reducebuf)s %(reducebuf)s
} }
""" % locals() """ % locals()
if self.reduce_mask == (1, 1): if self.reduce_mask == (1, 1):
self._op_guard() if not isinstance(self.scalar_op, (scal.Add,
scal.Maximum,
scal.Minimum)):
raise NotImplementedError()
#this kernel is ok for up to a few thousand elements, but #this kernel is ok for up to a few thousand elements, but
# it only runs on ONE multiprocessor # it only runs on ONE multiprocessor
reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {}) reducebuf = self._k_reduce_buf('Z[0]', node, nodename, sub = {})
reduce_fct = self._assign_reduce(node, nodename, "myresult",
"A[i0 * sA0 + i1 * sA1]",
{})
if isinstance(self.scalar_op, scal.Add):
reduce_init = "0.f;"
else:
reduce_init = "A[0]"
print >> sio, """ print >> sio, """
static __global__ void kernel_reduce_11_%(nodename)s( static __global__ void kernel_reduce_11_%(nodename)s(
const int d0, const int d0,
...@@ -1666,7 +1690,7 @@ class GpuCAReduce(GpuOp): ...@@ -1666,7 +1690,7 @@ class GpuCAReduce(GpuOp):
const int threadCount = blockDim.x * blockDim.y; const int threadCount = blockDim.x * blockDim.y;
const int threadNum = threadIdx.y*blockDim.x + threadIdx.x; const int threadNum = threadIdx.y*blockDim.x + threadIdx.x;
extern __shared__ float buf[]; extern __shared__ float buf[];
float myresult = 0.0f; float myresult = %(reduce_init)s;
if (warpSize != 32) if (warpSize != 32)
{ {
...@@ -1677,8 +1701,7 @@ class GpuCAReduce(GpuOp): ...@@ -1677,8 +1701,7 @@ class GpuCAReduce(GpuOp):
{ {
for (int i1 = threadIdx.x; i1 < d1; i1 += blockDim.x) for (int i1 = threadIdx.x; i1 < d1; i1 += blockDim.x)
{ {
float Ai = A[i0 * sA0 + i1 * sA1]; %(reduce_fct)s;
myresult += Ai;
} }
} }
%(reducebuf)s %(reducebuf)s
...@@ -2003,23 +2026,33 @@ class GpuCAReduce(GpuOp): ...@@ -2003,23 +2026,33 @@ class GpuCAReduce(GpuOp):
} }
""" % locals() """ % locals()
if self.reduce_mask == (1, 1, 1): if self.reduce_mask == (1, 1, 1):
self._op_guard() if not isinstance(self.scalar_op, (scal.Add,
scal.Maximum,
scal.Minimum)):
raise NotImplementedError()
reducebuf = self._k_reduce_buf('Z[0]', node, reducebuf = self._k_reduce_buf('Z[0]', node,
nodename, sub={}) nodename, sub={})
decl = self._k_decl(node, nodename) decl = self._k_decl(node, nodename)
init = self._k_init(node, nodename) init = self._k_init(node, nodename)
reduce_fct = self._assign_reduce(node, nodename, "myresult",
"A[i0 * sA0 + i1 * sA1 + i2 * sA2]",
{})
if isinstance(self.scalar_op, scal.Add):
reduce_init = "0.f;"
else:
reduce_init = "A[0]"
print >> sio, """ print >> sio, """
%(decl)s %(decl)s
{ {
%(init)s %(init)s
myresult = 0; myresult = %(reduce_init)s;
for (int i0 = threadIdx.z; i0 < d0; i0 += blockDim.z) for (int i0 = threadIdx.z; i0 < d0; i0 += blockDim.z)
{ {
for (int i1 = threadIdx.y; i1 < d1; i1 += blockDim.y) for (int i1 = threadIdx.y; i1 < d1; i1 += blockDim.y)
{ {
for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x) for (int i2 = threadIdx.x; i2 < d2; i2 += blockDim.x)
{ {
myresult += A[i0 * sA0 + i1 * sA1 + i2 * sA2]; %(reduce_fct)s;
} }
} }
} }
...@@ -2125,16 +2158,26 @@ class GpuCAReduce(GpuOp): ...@@ -2125,16 +2158,26 @@ class GpuCAReduce(GpuOp):
} }
""" % locals() """ % locals()
if self.reduce_mask == (1, 1, 1, 1): if self.reduce_mask == (1, 1, 1, 1):
self._op_guard() if not isinstance(self.scalar_op, (scal.Add,
scal.Maximum,
scal.Minimum)):
raise NotImplementedError()
reducebuf = self._k_reduce_buf('Z[0]', node, nodename, reducebuf = self._k_reduce_buf('Z[0]', node, nodename,
sub = {}) sub = {})
decl = self._k_decl(node, nodename) decl = self._k_decl(node, nodename)
init = self._k_init(node, nodename) init = self._k_init(node, nodename)
reduce_fct = self._assign_reduce(node, nodename, "myresult",
"A[i0 * sA0 + i1 * sA1 + i2 * sA2 + i3 * sA3]",
{})
if isinstance(self.scalar_op, scal.Add):
reduce_init = "0.f;"
else:
reduce_init = "A[0]"
print >> sio, """ print >> sio, """
%(decl)s %(decl)s
{ {
%(init)s %(init)s
myresult = 0; myresult = %(reduce_init)s;
for (int i0 = 0; i0 < d0; i0++) for (int i0 = 0; i0 < d0; i0++)
for (int i1 = threadIdx.z; i1 < d1; i1 += blockDim.z) for (int i1 = threadIdx.z; i1 < d1; i1 += blockDim.z)
{ {
...@@ -2142,7 +2185,7 @@ class GpuCAReduce(GpuOp): ...@@ -2142,7 +2185,7 @@ class GpuCAReduce(GpuOp):
{ {
for (int i3 = threadIdx.x; i3 < d3; i3 += blockDim.x) for (int i3 = threadIdx.x; i3 < d3; i3 += blockDim.x)
{ {
myresult += A[i0 * sA0 + i1 * sA1 + i2 * sA2 + i3 * sA3]; %(reduce_fct)s;
} }
} }
} }
......
...@@ -127,9 +127,17 @@ def test_careduce(): ...@@ -127,9 +127,17 @@ def test_careduce():
#GpuCAReduce{maximum/minimum} support only those patterns #GpuCAReduce{maximum/minimum} support only those patterns
if scalar_op in [theano.scalar.maximum, if scalar_op in [theano.scalar.maximum,
theano.scalar.minimum] and pat not in [ theano.scalar.minimum] and pat not in [
(0, 1), (0, 1, 1), (0, 1, 1), (1, 0), (1,), (1, 1), (0, 1), (1, 0),
(0, 1, 0), (0, 1, 1), (1, 1, 1),
(1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0),
(0, 0, 1, 0), (0, 0, 0, 1)]: (0, 0, 1, 0), (0, 0, 0, 1),
(1, 1, 1, 1), (1, 1, 1, 1, 1),
(0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0),
(0, 0, 1, 1), # by reshape
# (0, 1, 0, 1), #not supported for max/min
(0, 1, 1, 1), # by reshape
#(1, 0, 1, 1) #not supported for max/min
]:
continue continue
a = tensor.TensorType('float32', (False,) * len(shape))() a = tensor.TensorType('float32', (False,) * len(shape))()
...@@ -141,7 +149,8 @@ def test_careduce(): ...@@ -141,7 +149,8 @@ def test_careduce():
f = theano.function([a], b, mode=mode_with_gpu) f = theano.function([a], b, mode=mode_with_gpu)
f2 = theano.function([a], b, mode=mode_without_gpu) f2 = theano.function([a], b, mode=mode_without_gpu)
assert tcn.GpuCAReduce in [x.op.__class__ assert tcn.GpuCAReduce in [x.op.__class__
for x in f.maker.fgraph.toposort()] for x in f.maker.fgraph.toposort()], (
scalar_op, pat)
assert op.__class__ in [x.op.__class__ assert op.__class__ in [x.op.__class__
for x in f2.maker.fgraph.toposort()] for x in f2.maker.fgraph.toposort()]
f_caused_value_error = False f_caused_value_error = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论