提交 3b4d7a08 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6316 from gvtulder/f-unravel-ravel-index

Add unravel_index and ravel_multi_index
......@@ -43,7 +43,7 @@ from theano.gradient import Rop, Lop, grad, numeric_grad, verify_grad, \
from theano.tensor.sort import sort, argsort
from theano.tensor.extra_ops import (DiffOp, bincount, squeeze,
repeat, bartlett, fill_diagonal, fill_diagonal_offset,
cumsum, cumprod)
cumsum, cumprod, unravel_index, ravel_multi_index)
# SpecifyShape is defined in theano.compile, but should be available in tensor
from theano.compile import SpecifyShape, specify_shape
......@@ -1182,3 +1182,168 @@ class Unique(theano.Op):
ret[1] = shape
return ret
return ret
class UnravelIndex(gof.Op):
__props__ = ('ndim', 'order')
def __init__(self, ndim, order='C'):
assert order in ('C', 'F')
if not isinstance(ndim, int) or ndim < 1:
raise ValueError('ndim must be an integer greater than 0')
self.ndim = int(ndim)
self.order = order
def make_node(self, indices, dims):
indices = basic.as_tensor_variable(indices)
dims = basic.as_tensor_variable(dims)
if indices.dtype not in basic.int_dtypes:
raise TypeError("'%s' object cannot be interpreted as an index" % str(indices.dtype))
if dims.dtype not in basic.int_dtypes:
raise TypeError("'%s' object cannot be interpreted as an index" % str(dims.dtype))
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")
return gof.Apply(
self, [indices, dims],
[basic.TensorType(dtype='int64', broadcastable=(False,) * indices.ndim)()
for i in xrange(self.ndim)])
def infer_shape(self, node, input_shapes):
return [input_shapes[0]] * len(node.outputs)
def perform(self, node, inp, out):
indices, dims = inp
res = np.unravel_index(indices, dims)
assert len(res) == len(out)
for i in xrange(len(out)):
out[i][0] = res[i]
def unravel_index(indices, dims, order='C', ndim=None):
"""
Converts a flat index or array of flat indices into a tuple
of coordinate arrays.
This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically.
Parameters
----------
indices : Theano or NumPy array
An integer array whose elements are indices into the flattened
version of an array of dimensions ``dims``.
dims : tuple of ints
The shape of the array to use for unraveling ``indices``.
order : {'C', 'F'}, optional
Determines whether the indices should be viewed as indexing in
row-major (C-style) or column-major (Fortran-style) order.
ndim : int, optional
Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined
automatically from ``dims`` itself.
Returns
-------
unraveled_coords : tuple of ndarray
Each array in the tuple has the same shape as the ``indices``
array.
See Also
--------
ravel_multi_index
"""
if ndim is None:
try:
ndim = basic.get_vector_length(dims)
except ValueError:
raise ValueError(
"The length of the provided dimension list (%s) cannot "
"be automatically determined, so Theano is not able "
"to know what the number of dimensions of the unraveled "
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims))
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
return (res,)
else:
return tuple(res)
class RavelMultiIndex(gof.Op):
__props__ = ('mode', 'order')
def __init__(self, mode='raise', order='C'):
assert mode in ('raise', 'wrap', 'clip')
assert order in ('C', 'F')
self.mode = mode
self.order = order
def make_node(self, *inp):
multi_index = [basic.as_tensor_variable(i) for i in inp[:-1]]
dims = basic.as_tensor_variable(inp[-1])
for i in multi_index:
if i.dtype not in basic.int_dtypes:
raise TypeError("'%s' object cannot be interpreted as an index" % str(i.dtype))
if dims.dtype not in basic.int_dtypes:
raise TypeError("'%s' object cannot be interpreted as an index" % str(dims.dtype))
if dims.ndim != 1:
raise TypeError("dims must be a 1D array")
return gof.Apply(
self, multi_index + [dims],
[basic.TensorType(dtype='int64', broadcastable=(False,) * multi_index[0].ndim)()])
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
def perform(self, node, inp, out):
multi_index, dims = inp[:-1], inp[-1]
out[0][0] = np.ravel_multi_index(multi_index, dims,
mode=self.mode, order=self.order)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
"""
Converts a tuple of index arrays into an array of flat
indices, applying boundary modes to the multi-index.
Parameters
----------
multi_index : tuple of Theano or NumPy arrays
A tuple of integer arrays, one array for each dimension.
dims : tuple of ints
The shape of array into which the indices from ``multi_index`` apply.
mode : {'raise', 'wrap', 'clip'}, optional
Specifies how out-of-bounds indices are handled. Can specify
either one mode or a tuple of modes, one mode per index.
* 'raise' -- raise an error (default)
* 'wrap' -- wrap around
* 'clip' -- clip to the range
In 'clip' mode, a negative index which would normally
wrap will clip to 0 instead.
order : {'C', 'F'}, optional
Determines whether the multi-index should be viewed as
indexing in row-major (C-style) or column-major
(Fortran-style) order.
Returns
-------
raveled_indices : Theano array
An array of indices into the flattened version of an array
of dimensions ``dims``.
See Also
--------
unravel_index
"""
if not isinstance(multi_index, (tuple, list)):
raise TypeError('multi_index must be a tuple or a list.')
args = tuple(multi_index) + (dims,)
return RavelMultiIndex(mode=mode, order=order)(*args)
......@@ -12,7 +12,8 @@ from theano.tensor.extra_ops import (SearchsortedOp, searchsorted,
RepeatOp, repeat, Bartlett, bartlett,
FillDiagonal, fill_diagonal,
FillDiagonalOffset, fill_diagonal_offset,
to_one_hot, Unique)
to_one_hot, Unique, unravel_index, UnravelIndex,
ravel_multi_index, RavelMultiIndex)
from theano import tensor as T
from theano import config, tensor, function
from theano.tests.unittest_tools import attr
......@@ -755,3 +756,125 @@ class test_Unique(utt.InferShapeTester):
[np.asarray(np.array([[2, 1], [3, 2], [2, 3]]),
dtype=config.floatX)],
self.op_class)
class test_unravel_index(utt.InferShapeTester):
def test_unravel_index(self):
def check(shape, index_ndim, order):
indices = np.arange(np.product(shape))
# test with scalars and higher-dimensional indices
if index_ndim == 0:
indices = indices[-1]
elif index_ndim == 2:
indices = indices[:, np.newaxis]
indices_symb = theano.shared(indices)
# reference result
ref = np.unravel_index(indices, shape)
def fn(i, d, nd=None):
if nd is None:
return function([], unravel_index(i, d, order=order))
else:
return function([], unravel_index(i, d, order=order, ndim=nd))
# shape given as a tuple
f_array_tuple = fn(indices, shape)
f_symb_tuple = fn(indices_symb, shape)
np.testing.assert_equal(ref, f_array_tuple())
np.testing.assert_equal(ref, f_symb_tuple())
# shape given as an array
shape_array = np.array(shape)
f_array_array = fn(indices, shape_array)
np.testing.assert_equal(ref, f_array_array())
# shape given as a theano variable
shape_symb = theano.shared(shape_array)
f_array_symb = fn(indices, shape_symb, len(shape))
np.testing.assert_equal(ref, f_array_symb())
# shape given as a Shape op (unravel_index will use get_vector_length
# to infer the number of dimensions)
indexed_array = theano.shared(np.random.uniform(size=shape_array))
f_array_shape = fn(indices, indexed_array.shape)
np.testing.assert_equal(ref, f_array_shape())
# shape testing
self._compile_and_check([],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)),
[], UnravelIndex)
for order in ('C', 'F'):
for index_ndim in (0, 1, 2):
check((3,), index_ndim, order)
check((3, 4), index_ndim, order)
check((3, 4, 5), index_ndim, order)
# must specify ndim if length of dims is not fixed
self.assertRaises(ValueError, unravel_index, theano.tensor.ivector(), theano.tensor.ivector())
# must provide integers
self.assertRaises(TypeError, unravel_index, theano.tensor.fvector(), (3, 4))
self.assertRaises(TypeError, unravel_index, (3, 4), (3.4, 3.2))
self.assertRaises(ValueError, unravel_index, (3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence
self.assertRaises(TypeError, unravel_index, (3, 4), 3)
self.assertRaises(TypeError, unravel_index, (3, 4), ((3, 4),))
class test_ravel_multi_index(utt.InferShapeTester):
def test_ravel_multi_index(self):
def check(shape, index_ndim, mode, order):
multi_index = np.unravel_index(np.arange(np.product(shape)), shape, order=order)
# create some invalid indices to test the mode
if mode in ('wrap', 'clip'):
multi_index = (multi_index[0] - 1,) + multi_index[1:]
# test with scalars and higher-dimensional indices
if index_ndim == 0:
multi_index = tuple(i[-1] for i in multi_index)
elif index_ndim == 2:
multi_index = tuple(i[:, np.newaxis] for i in multi_index)
multi_index_symb = [theano.shared(i) for i in multi_index]
# reference result
ref = np.ravel_multi_index(multi_index, shape, mode, order)
def fn(mi, s):
return function([], ravel_multi_index(mi, s, mode, order))
# shape given as a tuple
f_array_tuple = fn(multi_index, shape)
f_symb_tuple = fn(multi_index_symb, shape)
np.testing.assert_equal(ref, f_array_tuple())
np.testing.assert_equal(ref, f_symb_tuple())
# shape given as an array
shape_array = np.array(shape)
f_array_array = fn(multi_index, shape_array)
np.testing.assert_equal(ref, f_array_array())
# shape given as a theano variable
shape_symb = theano.shared(shape_array)
f_array_symb = fn(multi_index, shape_symb)
np.testing.assert_equal(ref, f_array_symb())
# shape testing
self._compile_and_check([],
[ravel_multi_index(multi_index, shape_symb, mode, order)],
[], RavelMultiIndex)
for mode in ('raise', 'wrap', 'clip'):
for order in ('C', 'F'):
for index_ndim in (0, 1, 2):
check((3,), index_ndim, mode, order)
check((3, 4), index_ndim, mode, order)
check((3, 4, 5), index_ndim, mode, order)
# must provide integers
self.assertRaises(TypeError, ravel_multi_index, (theano.tensor.fvector(), theano.tensor.ivector()), (3, 4))
self.assertRaises(TypeError, ravel_multi_index, ((3, 4), theano.tensor.ivector()), (3.4, 3.2))
# dims must be a 1D sequence
self.assertRaises(TypeError, ravel_multi_index, ((3, 4),), ((3, 4),))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论