提交 e85ae3bd authored 作者: abergeron's avatar abergeron

Merge pull request #2716 from MarcCote/fix_cumsum_negative_axis

Added support for negative axis in GpuCumsumOp
...@@ -25,8 +25,8 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -25,8 +25,8 @@ class GpuCumsum(CumsumOp, GpuOp):
self.max_grid_size1 = None self.max_grid_size1 = None
self.max_grid_size2 = None self.max_grid_size2 = None
# We must reuse the same method, not reimplement and call it. # We must reuse the same method, not reimplement and call it.
# Otherwise DebugMode will print many warnings. # Otherwise DebugMode will print many warnings.
perform = Op.perform perform = Op.perform
def make_node(self, x): def make_node(self, x):
...@@ -37,7 +37,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -37,7 +37,7 @@ class GpuCumsum(CumsumOp, GpuOp):
if x.ndim > GpuCumsum.SUPPORTED_NDIMS: if x.ndim > GpuCumsum.SUPPORTED_NDIMS:
raise NotImplementedError('Only cumsum on 1D, 2D and 3D array are supported right now!') raise NotImplementedError('Only cumsum on 1D, 2D and 3D array are supported right now!')
if self.axis >= x.ndim: if self.axis >= x.ndim or self.axis < -x.ndim:
raise ValueError('axis(={1}) out of bounds'.format(self.axis)) raise ValueError('axis(={1}) out of bounds'.format(self.axis))
return theano.Apply(self, [x], [x.type()]) return theano.Apply(self, [x], [x.type()])
...@@ -69,7 +69,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -69,7 +69,7 @@ class GpuCumsum(CumsumOp, GpuOp):
return "%s{%s}" % (self.__class__.__name__, self.axis) return "%s{%s}" % (self.__class__.__name__, self.axis)
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
...@@ -352,6 +352,8 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -352,6 +352,8 @@ class GpuCumsum(CumsumOp, GpuOp):
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
x, = inames x, = inames
z, = onames z, = onames
# We assume array has been already flattened if needed.
axis = self.axis if self.axis is not None else 0 axis = self.axis if self.axis is not None else 0
fail = sub['fail'] fail = sub['fail']
...@@ -368,6 +370,12 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -368,6 +370,12 @@ class GpuCumsum(CumsumOp, GpuOp):
const int* shape = CudaNdarray_HOST_DIMS(%(x)s); const int* shape = CudaNdarray_HOST_DIMS(%(x)s);
bool needAllocation = !%(z)s || CudaNdarray_NDIM(%(x)s) != CudaNdarray_NDIM(%(z)s); bool needAllocation = !%(z)s || CudaNdarray_NDIM(%(x)s) != CudaNdarray_NDIM(%(z)s);
int axis = %(axis)s;
if (axis < 0) {
// Convert negative axis to positive axis.
axis += CudaNdarray_NDIM(%(x)s);
}
// If output is already allocated, check if its shape matches the input's one. // If output is already allocated, check if its shape matches the input's one.
if (!needAllocation) { if (!needAllocation) {
for (int i= 0; i < CudaNdarray_NDIM(%(x)s); ++i) { for (int i= 0; i < CudaNdarray_NDIM(%(x)s); ++i) {
...@@ -387,7 +395,7 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -387,7 +395,7 @@ class GpuCumsum(CumsumOp, GpuOp):
} }
{ // Namespace for kernel calls // { // Namespace for kernel calls //
if (cumSum_%(nodename)s(%(x)s, %(z)s, %(axis)s, %(max_threads_dim0)s, %(max_grid_size1)s, %(max_grid_size2)s) == -1){ if (cumSum_%(nodename)s(%(x)s, %(z)s, axis, %(max_threads_dim0)s, %(max_grid_size1)s, %(max_grid_size2)s) == -1){
%(fail)s; %(fail)s;
} }
...@@ -408,11 +416,10 @@ class GpuCumsum(CumsumOp, GpuOp): ...@@ -408,11 +416,10 @@ class GpuCumsum(CumsumOp, GpuOp):
def values_eq_approx_high_tol(a, b): def values_eq_approx_high_tol(a, b):
"""This fct is needed to don't have DebugMode raise useless """This fct is needed to don't have DebugMode raise useless
error due to ronding error. error due to rounding error.
This happen with big input size due to change in the order of This happen with big input size due to change in the order of
operation. operation.
""" """
rtol = None rtol = None
if a.size > 100000: if a.size > 100000:
...@@ -443,6 +450,7 @@ def use_gpu_cumsum(node): ...@@ -443,6 +450,7 @@ def use_gpu_cumsum(node):
# ``gpu_cumsum`` assume array has been flattened if needed. # ``gpu_cumsum`` assume array has been flattened if needed.
if axis is None: if axis is None:
axis = 0 axis = 0
ret = host_from_gpu(GpuCumsum(axis)(x)) ret = host_from_gpu(GpuCumsum(axis)(x))
ret.values_eq_approx = values_eq_approx_high_tol ret.values_eq_approx = values_eq_approx_high_tol
return [ret] return [ret]
...@@ -47,7 +47,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -47,7 +47,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
def test_Strides1D(self): def test_Strides1D(self):
x = T.fvector('x') x = T.fvector('x')
for axis in [0, None]: for axis in [0, None, -1]:
a = np.random.random((42,)).astype("float32") a = np.random.random((42,)).astype("float32")
cumsum_function = theano.function([x], cumsum(x, axis=axis), cumsum_function = theano.function([x], cumsum(x, axis=axis),
mode=self.mode) mode=self.mode)
...@@ -70,7 +70,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -70,7 +70,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
def test_Strides2D(self): def test_Strides2D(self):
x = T.fmatrix('x') x = T.fmatrix('x')
for axis in [0, 1, None]: for axis in [0, 1, None, -1, -2]:
a = np.random.random((42, 30)).astype("float32") a = np.random.random((42, 30)).astype("float32")
cumsum_function = theano.function([x], cumsum(x, axis=axis), cumsum_function = theano.function([x], cumsum(x, axis=axis),
mode=self.mode) mode=self.mode)
...@@ -93,7 +93,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -93,7 +93,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
def test_Strides3D(self): def test_Strides3D(self):
x = T.ftensor3('x') x = T.ftensor3('x')
for axis in [0, 1, 2, None]: for axis in [0, 1, 2, None, -1, -2, -3]:
a = np.random.random((42, 30, 25)).astype("float32") a = np.random.random((42, 30, 25)).astype("float32")
cumsum_function = theano.function([x], cumsum(x, axis=axis), cumsum_function = theano.function([x], cumsum(x, axis=axis),
mode=self.mode) mode=self.mode)
...@@ -139,7 +139,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -139,7 +139,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
block_max_size = self.max_threads_dim0 * 2 block_max_size = self.max_threads_dim0 * 2
x = T.fmatrix('x') x = T.fmatrix('x')
for shape_axis, axis in zip([0, 1, 0], [0, 1, None]): for shape_axis, axis in zip([0, 1, 0, 1, 0], [0, 1, None, -1, -2]):
f = theano.function([x], cumsum(x, axis=axis), mode=self.mode) f = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)] if isinstance(n.op, GpuCumsum)]
...@@ -178,7 +178,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp): ...@@ -178,7 +178,7 @@ class TestGpuCumsum(theano.tensor.tests.test_extra_ops.TestCumsumOp):
block_max_size = self.max_threads_dim0 * 2 block_max_size = self.max_threads_dim0 * 2
x = T.ftensor3('x') x = T.ftensor3('x')
for shape_axis, axis in zip([0, 1, 2, 0], [0, 1, 2, None]): for shape_axis, axis in zip([0, 1, 2, 0, 2, 1, 0], [0, 1, 2, None, -1, -2, -3]):
f = theano.function([x], cumsum(x, axis=axis), mode=self.mode) f = theano.function([x], cumsum(x, axis=axis), mode=self.mode)
assert [n for n in f.maker.fgraph.toposort() assert [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, GpuCumsum)] if isinstance(n.op, GpuCumsum)]
......
...@@ -29,7 +29,7 @@ class CumsumOp(theano.Op): ...@@ -29,7 +29,7 @@ class CumsumOp(theano.Op):
if self.axis is None: if self.axis is None:
out_type = theano.tensor.vector(dtype=x.dtype) # Flatten out_type = theano.tensor.vector(dtype=x.dtype) # Flatten
elif self.axis >= x.ndim: elif self.axis >= x.ndim or self.axis < -x.ndim:
raise ValueError('axis(={0}) out of bounds'.format(self.axis)) raise ValueError('axis(={0}) out of bounds'.format(self.axis))
return theano.Apply(self, [x], [out_type]) return theano.Apply(self, [x], [out_type])
...@@ -151,7 +151,7 @@ class CumprodOp(theano.Op): ...@@ -151,7 +151,7 @@ class CumprodOp(theano.Op):
if self.axis is None: if self.axis is None:
out_type = theano.tensor.vector(dtype=x.dtype) # Flatten out_type = theano.tensor.vector(dtype=x.dtype) # Flatten
elif self.axis >= x.ndim: elif self.axis >= x.ndim or self.axis < -x.ndim:
raise ValueError('axis(={0}) out of bounds'.format(self.axis)) raise ValueError('axis(={0}) out of bounds'.format(self.axis))
return theano.Apply(self, [x], [out_type]) return theano.Apply(self, [x], [out_type])
......
...@@ -32,12 +32,13 @@ class TestCumsumOp(utt.InferShapeTester): ...@@ -32,12 +32,13 @@ class TestCumsumOp(utt.InferShapeTester):
a = np.random.random((3, 5, 2)).astype(config.floatX) a = np.random.random((3, 5, 2)).astype(config.floatX)
# Test axis out of bounds # Test axis out of bounds
self.assertRaises(ValueError, cumsum, x, axis=4) self.assertRaises(ValueError, cumsum, x, axis=3)
self.assertRaises(ValueError, cumsum, x, axis=-4)
f = theano.function([x], cumsum(x)) f = theano.function([x], cumsum(x))
assert np.allclose(np.cumsum(a), f(a)) # Test axis=None assert np.allclose(np.cumsum(a), f(a)) # Test axis=None
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
f = theano.function([x], cumsum(x, axis=axis)) f = theano.function([x], cumsum(x, axis=axis))
assert np.allclose(np.cumsum(a, axis=axis), f(a)) assert np.allclose(np.cumsum(a, axis=axis), f(a))
...@@ -51,7 +52,7 @@ class TestCumsumOp(utt.InferShapeTester): ...@@ -51,7 +52,7 @@ class TestCumsumOp(utt.InferShapeTester):
[a], [a],
self.op_class) self.op_class)
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
self._compile_and_check([x], self._compile_and_check([x],
[cumsum(x, axis=axis)], [cumsum(x, axis=axis)],
[a], [a],
...@@ -62,7 +63,7 @@ class TestCumsumOp(utt.InferShapeTester): ...@@ -62,7 +63,7 @@ class TestCumsumOp(utt.InferShapeTester):
utt.verify_grad(self.op, [a]) # Test axis=None utt.verify_grad(self.op, [a]) # Test axis=None
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
utt.verify_grad(self.op_class(axis=axis), [a], eps=4e-4) utt.verify_grad(self.op_class(axis=axis), [a], eps=4e-4)
...@@ -77,10 +78,14 @@ class TestCumprodOp(utt.InferShapeTester): ...@@ -77,10 +78,14 @@ class TestCumprodOp(utt.InferShapeTester):
x = T.tensor3('x') x = T.tensor3('x')
a = np.random.random((3, 5, 2)).astype(config.floatX) a = np.random.random((3, 5, 2)).astype(config.floatX)
# Test axis out of bounds
self.assertRaises(ValueError, cumprod, x, axis=3)
self.assertRaises(ValueError, cumprod, x, axis=-4)
f = theano.function([x], cumprod(x)) f = theano.function([x], cumprod(x))
assert np.allclose(np.cumprod(a), f(a)) # Test axis=None assert np.allclose(np.cumprod(a), f(a)) # Test axis=None
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
f = theano.function([x], cumprod(x, axis=axis)) f = theano.function([x], cumprod(x, axis=axis))
assert np.allclose(np.cumprod(a, axis=axis), f(a)) assert np.allclose(np.cumprod(a, axis=axis), f(a))
...@@ -94,7 +99,7 @@ class TestCumprodOp(utt.InferShapeTester): ...@@ -94,7 +99,7 @@ class TestCumprodOp(utt.InferShapeTester):
[a], [a],
self.op_class) self.op_class)
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
self._compile_and_check([x], self._compile_and_check([x],
[cumprod(x, axis=axis)], [cumprod(x, axis=axis)],
[a], [a],
...@@ -105,7 +110,7 @@ class TestCumprodOp(utt.InferShapeTester): ...@@ -105,7 +110,7 @@ class TestCumprodOp(utt.InferShapeTester):
utt.verify_grad(self.op, [a]) # Test axis=None utt.verify_grad(self.op, [a]) # Test axis=None
for axis in range(len(a.shape)): for axis in range(-len(a.shape), len(a.shape)):
utt.verify_grad(self.op_class(axis=axis), [a]) utt.verify_grad(self.op_class(axis=axis), [a])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论