提交 44aa0fb5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Remove unnecessary ndim property and use order property in UnravelIndex

上级 6bf0d06d
...@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1065,13 +1065,10 @@ class TestUnravelIndex(utt.InferShapeTester):
indices_symb = theano.shared(indices) indices_symb = theano.shared(indices)
# reference result # reference result
ref = np.unravel_index(indices, shape) ref = np.unravel_index(indices, shape, order=order)
def fn(i, d, nd=None): def fn(i, d):
if nd is None: return function([], unravel_index(i, d, order=order))
return function([], unravel_index(i, d, order=order))
else:
return function([], unravel_index(i, d, order=order, ndim=nd))
# shape given as a tuple # shape given as a tuple
f_array_tuple = fn(indices, shape) f_array_tuple = fn(indices, shape)
...@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1086,7 +1083,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape given as a theano variable # shape given as a theano variable
shape_symb = theano.shared(shape_array) shape_symb = theano.shared(shape_array)
f_array_symb = fn(indices, shape_symb, len(shape)) f_array_symb = fn(indices, shape_symb)
np.testing.assert_equal(ref, f_array_symb()) np.testing.assert_equal(ref, f_array_symb())
# shape given as a Shape op (unravel_index will use get_vector_length # shape given as a Shape op (unravel_index will use get_vector_length
...@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1098,7 +1095,7 @@ class TestUnravelIndex(utt.InferShapeTester):
# shape testing # shape testing
self._compile_and_check( self._compile_and_check(
[], [],
unravel_index(indices, shape_symb, order=order, ndim=len(shape)), unravel_index(indices, shape_symb, order=order),
[], [],
UnravelIndex, UnravelIndex,
) )
...@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1118,8 +1115,6 @@ class TestUnravelIndex(utt.InferShapeTester):
unravel_index(theano.tensor.fvector(), (3, 4)) unravel_index(theano.tensor.fvector(), (3, 4))
with pytest.raises(TypeError): with pytest.raises(TypeError):
unravel_index((3, 4), (3.4, 3.2)) unravel_index((3, 4), (3.4, 3.2))
with pytest.raises(ValueError):
unravel_index((3, 4), (3, 3), ndim=5.4)
# dims must be a 1D sequence # dims must be a 1D sequence
with pytest.raises(TypeError): with pytest.raises(TypeError):
......
...@@ -1292,13 +1292,10 @@ class Unique(theano.Op): ...@@ -1292,13 +1292,10 @@ class Unique(theano.Op):
class UnravelIndex(gof.Op): class UnravelIndex(gof.Op):
__props__ = ("ndim", "order") __props__ = ("order",)
def __init__(self, ndim, order="C"): def __init__(self, order="C"):
assert order in ("C", "F") assert order in ("C", "F")
if not isinstance(ndim, int) or ndim < 1:
raise ValueError("ndim must be an integer greater than 0")
self.ndim = int(ndim)
self.order = order self.order = order
def make_node(self, indices, dims): def make_node(self, indices, dims):
...@@ -1321,7 +1318,7 @@ class UnravelIndex(gof.Op): ...@@ -1321,7 +1318,7 @@ class UnravelIndex(gof.Op):
[indices, dims], [indices, dims],
[ [
basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)() basic.TensorType(dtype="int64", broadcastable=(False,) * indices.ndim)()
for i in range(self.ndim) for i in range(basic.get_vector_length(dims))
], ],
) )
...@@ -1330,7 +1327,7 @@ class UnravelIndex(gof.Op): ...@@ -1330,7 +1327,7 @@ class UnravelIndex(gof.Op):
def perform(self, node, inp, out): def perform(self, node, inp, out):
indices, dims = inp indices, dims = inp
res = np.unravel_index(indices, dims) res = np.unravel_index(indices, dims, order=self.order)
assert len(res) == len(out) assert len(res) == len(out)
for i in range(len(out)): for i in range(len(out)):
ret = theano._asarray(res[i], node.outputs[0].dtype) ret = theano._asarray(res[i], node.outputs[0].dtype)
...@@ -1341,15 +1338,11 @@ class UnravelIndex(gof.Op): ...@@ -1341,15 +1338,11 @@ class UnravelIndex(gof.Op):
out[i][0] = ret out[i][0] = ret
def unravel_index(indices, dims, order="C", ndim=None): def unravel_index(indices, dims, order="C"):
""" """
Converts a flat index or array of flat indices into a tuple Converts a flat index or array of flat indices into a tuple
of coordinate arrays. of coordinate arrays.
This method is similar to the NumPy version, except for the
additional ``ndim`` parameter. This parameter is required if
the length of ``dims`` cannot be determined automatically.
Parameters Parameters
---------- ----------
indices : Theano or NumPy array indices : Theano or NumPy array
...@@ -1360,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None): ...@@ -1360,10 +1353,6 @@ def unravel_index(indices, dims, order="C", ndim=None):
order : {'C', 'F'}, optional order : {'C', 'F'}, optional
Determines whether the indices should be viewed as indexing in Determines whether the indices should be viewed as indexing in
row-major (C-style) or column-major (Fortran-style) order. row-major (C-style) or column-major (Fortran-style) order.
ndim : int, optional
Specifies the number of dimensions, i.e., the length of
``dims``. This is required if the dimensions cannot be determined
automatically from ``dims`` itself.
Returns Returns
------- -------
...@@ -1376,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None): ...@@ -1376,20 +1365,8 @@ def unravel_index(indices, dims, order="C", ndim=None):
ravel_multi_index ravel_multi_index
""" """
if ndim is None: res = UnravelIndex(order=order)(indices, dims)
try: if not isinstance(res, (list, tuple)):
ndim = basic.get_vector_length(dims)
except ValueError:
raise ValueError(
"The length of the provided dimension list (%s) cannot "
"be automatically determined, so Theano is not able "
"to know what the number of dimensions of the unraveled "
"index will be. You can provide the 'ndim' keyword "
"argument to 'unravel_index' to avoid this problem." % str(dims)
)
res = UnravelIndex(ndim=ndim, order=order)(indices, dims)
if ndim == 1:
return (res,) return (res,)
else: else:
return tuple(res) return tuple(res)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论