提交 6850329e authored 作者: Marc-Alexandre Cote's avatar Marc-Alexandre Cote

Added the cumsum function similar to numpy's one.

上级 e8f6cb73
...@@ -8,6 +8,73 @@ tensor = basic ...@@ -8,6 +8,73 @@ tensor = basic
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
class CumsumOp(theano.Op):
    """Op computing the cumulative sum of a tensor along an axis.

    ``perform`` delegates to ``numpy.cumsum``.  See the ``cumsum``
    function for the user-facing wrapper and full documentation.
    """

    def __init__(self, axis=None):
        # axis=None means the input is flattened before summing,
        # matching numpy.cumsum's default behavior.
        self.axis = axis

    def __eq__(self, other):
        # Two CumsumOps are interchangeable iff they sum along the
        # same axis; __hash__ below must stay consistent with this.
        return (type(self) == type(other) and
                self.axis == other.axis)

    def __hash__(self):
        return hash(type(self)) ^ hash(self.axis)

    def make_node(self, x):
        x = basic.as_tensor_variable(x)
        out_type = x.type()
        if self.axis is None:
            # Flattened cumsum always yields a 1-d vector.
            out_type = theano.tensor.vector(dtype=x.dtype)
        return theano.Apply(self, [x], [out_type])

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        z = output_storage[0]
        z[0] = np.cumsum(x, axis=self.axis)

    def grad(self, inputs, output_gradients):
        [gi] = output_gradients
        if self.axis is None:
            # d cumsum(x)_k / d x_j is 1 for k >= j, so the gradient
            # w.r.t. x_j sums gi from position j onward: reverse,
            # cumsum, reverse again, then restore the input's shape.
            return [cumsum(gi[::-1])[::-1].reshape(inputs[0].shape)]

        # Same reverse-cumsum-reverse trick, applied along
        # ``self.axis`` only.
        reverse_slicing = [slice(None, None, None)] * gi.ndim
        reverse_slicing[self.axis] = slice(None, None, -1)
        reverse_slicing = tuple(reverse_slicing)
        return [cumsum(gi[reverse_slicing], self.axis)[reverse_slicing]]

    def infer_shape(self, node, shapes):
        if self.axis is None:
            # Flattened output: length is the product of input dims.
            return [(np.prod(shapes[0]),)]
        return shapes

    def __str__(self):
        return self.__class__.__name__
def cumsum(x, axis=None):
    """Return the cumulative sum of the elements along a given axis.

    Wraps ``numpy.cumsum``.

    :param x: Input tensor variable.
    :param axis: The axis along which the cumulative sum is computed.
        The default (None) is to compute the cumsum over the flattened
        array.

    .. versionadded:: 0.6.1
    """
    op = CumsumOp(axis=axis)
    return op(x)
class DiffOp(theano.Op): class DiffOp(theano.Op):
# See function diff for docstring # See function diff for docstring
def __init__(self, n=1, axis=-1): def __init__(self, n=1, axis=-1):
......
...@@ -3,7 +3,7 @@ import numpy ...@@ -3,7 +3,7 @@ import numpy
import theano import theano
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.extra_ops import (BinCountOp, bincount, DiffOp, diff, from theano.tensor.extra_ops import (CumsumOp, cumsum, BinCountOp, bincount, DiffOp, diff,
squeeze, RepeatOp, repeat, Bartlett, bartlett, squeeze, RepeatOp, repeat, Bartlett, bartlett,
FillDiagonal, fill_diagonal) FillDiagonal, fill_diagonal)
from theano import tensor as T from theano import tensor as T
...@@ -13,6 +13,48 @@ from theano import config, tensor, function ...@@ -13,6 +13,48 @@ from theano import config, tensor, function
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
numpy_16 = bool(numpy_ver >= [1, 6]) numpy_16 = bool(numpy_ver >= [1, 6])
class TestCumsumOp(utt.InferShapeTester):
    """Tests for CumsumOp: forward value, infer_shape, and gradient.

    Note: uses ``numpy`` (the module imported at the top of this file);
    the original body referenced an undefined ``np`` alias.
    """

    def setUp(self):
        super(TestCumsumOp, self).setUp()
        self.op_class = CumsumOp
        self.op = CumsumOp()

    def test_cumsumOp(self):
        x = T.tensor3('x')
        a = numpy.random.random((30, 50, 20)).astype(config.floatX)

        # axis=None flattens the input, like numpy.cumsum.
        f = theano.function([x], cumsum(x))
        assert numpy.allclose(numpy.cumsum(a), f(a))

        for axis in range(len(a.shape)):
            f = theano.function([x], cumsum(x, axis=axis))
            assert numpy.allclose(numpy.cumsum(a, axis=axis), f(a))

    def test_infer_shape(self):
        x = T.tensor3('x')
        a = numpy.random.random((30, 50, 20)).astype(config.floatX)

        # axis=None case: flattened output shape.
        self._compile_and_check([x],
                                [self.op(x)],
                                [a],
                                self.op_class)

        for axis in range(len(a.shape)):
            self._compile_and_check([x],
                                    [cumsum(x, axis=axis)],
                                    [a],
                                    self.op_class)

    def test_grad(self):
        # Small tensor keeps verify_grad's finite differences fast.
        a = numpy.random.random((3, 5, 2)).astype(config.floatX)

        utt.verify_grad(self.op, [a])  # axis=None
        for axis in range(len(a.shape)):
            utt.verify_grad(CumsumOp(axis=axis), [a])
class TestBinCountOp(utt.InferShapeTester): class TestBinCountOp(utt.InferShapeTester):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论