提交 eb5565b0 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

unravel_index infers the length of Shape.

上级 bdceabfd
...@@ -2630,8 +2630,7 @@ def unravel_index(indices, dims, order='C', ndim=None): ...@@ -2630,8 +2630,7 @@ def unravel_index(indices, dims, order='C', ndim=None):
This method is similar to the NumPy version, except for the This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically, for the length of ``dims`` cannot be determined automatically.
example because ``dims`` is a Theano vector.
For example: For example:
...@@ -2652,7 +2651,7 @@ def unravel_index(indices, dims, order='C', ndim=None): ...@@ -2652,7 +2651,7 @@ def unravel_index(indices, dims, order='C', ndim=None):
ndim : int, optional ndim : int, optional
Specifies the number of dimensions, i.e., the length of Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined ``dims``. This is required if the dimensions cannot be determined
from ``dims`` itself, for example, if ``dims`` is a Theano vector. automatically from ``dims`` itself.
Returns Returns
------- -------
...@@ -2666,15 +2665,15 @@ def unravel_index(indices, dims, order='C', ndim=None): ...@@ -2666,15 +2665,15 @@ def unravel_index(indices, dims, order='C', ndim=None):
""" """
if ndim is None: if ndim is None:
if isinstance(dims, (tuple, list)): try:
ndim = len(dims) ndim = get_vector_length(dims)
elif isinstance(dims, np.ndarray) and dims.ndim == 1: except ValueError:
ndim = dims.shape[0] raise ValueError(
else: "The length of the provided dimension list (%s) cannot "
raise TypeError('unravel_index was called with a dimension ' "be automatically determined, so Theano is not able "
'list with an unspecified length (dim = %s). ' "to know what the number of dimensions of the unraveled "
'Use the ndim parameter of unravel_index to ' "index will be. You can provide the 'ndim' keyword "
'set the number of dimensions. ' % str(dims)) "argument to 'unravel_index' to avoid this problem." % str(dims))
res = UnravelIndex(order=order)(indices, dims, ndim) res = UnravelIndex(order=order)(indices, dims, ndim)
if ndim == 1: if ndim == 1:
......
...@@ -2785,6 +2785,14 @@ class test_unravel_index(utt.InferShapeTester): ...@@ -2785,6 +2785,14 @@ class test_unravel_index(utt.InferShapeTester):
np.testing.assert_equal(ref, f_array_symb()) np.testing.assert_equal(ref, f_array_symb())
np.testing.assert_equal(ref, f_symb_symb()) np.testing.assert_equal(ref, f_symb_symb())
# shape given as a Shape op (unravel_index will use get_vector_length
# to infer the number of dimensions)
indexed_array = theano.shared(np.random.uniform(size=shape_array))
f_array_shape = fn(indices, indexed_array.shape)
f_symb_shape = fn(indices_symb, indexed_array.shape)
np.testing.assert_equal(ref, f_array_shape())
np.testing.assert_equal(ref, f_symb_shape())
# shape testing # shape testing
self._compile_and_check([], self._compile_and_check([],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)), unravel_index(indices, shape_symb, order=order, ndim=len(shape)),
...@@ -2797,7 +2805,7 @@ class test_unravel_index(utt.InferShapeTester): ...@@ -2797,7 +2805,7 @@ class test_unravel_index(utt.InferShapeTester):
check((3, 4, 5), index_ndim, order) check((3, 4, 5), index_ndim, order)
# must specify ndim if length of dims is not fixed # must specify ndim if length of dims is not fixed
self.assertRaises(TypeError, unravel_index, ivector(), ivector()) self.assertRaises(ValueError, unravel_index, ivector(), ivector())
# must provide integers # must provide integers
self.assertRaises(TypeError, unravel_index, fvector(), (3, 4)) self.assertRaises(TypeError, unravel_index, fvector(), (3, 4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论