提交 44aa0fb5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Remove unnecessary ndim property and use order property in UnravelIndex

上级 6bf0d06d
......@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester):
indices_symb = theano.shared(indices)
# reference result
ref = np.unravel_index(indices, shape)
ref = np.unravel_index(indices, shape, order=order)
def fn(i, d, nd=None):
if nd is None:
return function([], unravel_index(i, d, order=order))
else:
return function([], unravel_index(i, d, order=order, ndim=nd))
def fn(i, d):
return function([], unravel_index(i, d, order=order))
# shape given as a tuple
f_array_tuple = fn(indices, shape)
......@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape given as a theano variable
shape_symb = theano.shared(shape_array)
f_array_symb = fn(indices, shape_symb, len(shape))
f_array_symb = fn(indices, shape_symb)
np.testing.assert_equal(ref, f_array_symb())
# shape given as a Shape op (unravel_index will use get_vector_length
......@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape testing
self._compile_and_check(
[],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)),
unravel_index(indices, shape_symb, order=order),
[],
UnravelIndex,
)
......@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester):
unravel_index(theano.tensor.fvector(), (3, 4))
with pytest.raises(TypeError):
unravel_index((3, 4), (3.4, 3.2))
with pytest.raises(ValueError):
unravel_index((3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence
with pytest.raises(TypeError):
......
......@@ -1292,13 +1292,10 @@ class Unique(theano.Op):
class UnravelIndex(gof.Op):
__props__ = ("ndim", "order")
__props__ = ("order",)
def __init__(self, ndim, order="C"):
def __init__(self, order="C"):
assert order in ("C", "F")
if not isinstance(ndim, int) or ndim < 1:
raise ValueError("ndim must be an integer greater than 0")
self.ndim = int(ndim)
self.order = order
def make_node(self, indices, dims):
......@@ -1321,7 +1318,7 @@ class UnravelIndex(gof.Op):
[indices, dims],
[
basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)()
for i in range(self.ndim)
for i in range(basic.get_vector_length(dims))
],
)
......@@ -1330,7 +1327,7 @@ class UnravelIndex(gof.Op):
def perform(self, node, inp, out):
indices, dims = inp
res = np.unravel_index(indices, dims)
res = np.unravel_index(indices, dims, order=self.order)
assert len(res) == len(out)
for i in range(len(out)):
ret = theano._asarray(res[i], node.outputs[0].dtype)
......@@ -1341,15 +1338,11 @@ class UnravelIndex(gof.Op):
out[i][0] = ret
def unravel_index(indices, dims, order="C", ndim=None):
def unravel_index(indices, dims, order="C"):
"""
Converts a flat index or array of flat indices into a tuple
of coordinate arrays.
This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically.
Parameters
----------
indices : Theano or NumPy array
......@@ -1360,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None):
order : {'C', 'F'}, optional
Determines whether the indices should be viewed as indexing in
row-major (C-style) or column-major (Fortran-style) order.
ndim : int, optional
Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined
automatically from ``dims`` itself.
Returns
-------
......@@ -1376,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None):
ravel_multi_index
"""
if ndim is None:
try:
ndim = basic.get_vector_length(dims)
except ValueError:
raise ValueError(
"The length of the provided dimension list (%s) cannot "
"be automatically determined, so Theano is not able "
"to know what the number of dimensions of the unraveled "
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims)
)
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
res = UnravelIndex(order=order)(indices, dims)
if not isinstance(res, (list, tuple)):
return (res,)
else:
return tuple(res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论