提交 482e6cc2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup TestRavelMultiIndex

上级 f485be15
...@@ -1020,44 +1020,43 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1020,44 +1020,43 @@ class TestUnravelIndex(utt.InferShapeTester):
class TestRavelMultiIndex(utt.InferShapeTester): class TestRavelMultiIndex(utt.InferShapeTester):
def test_ravel_multi_index(self): @staticmethod
def check(shape, index_ndim, mode, order): def get_test_multiindex(shape, index_ndim, mode, order):
multi_index = np.unravel_index( multi_index = np.unravel_index(np.arange(np.prod(shape)), shape, order=order)
np.arange(np.prod(shape)), shape, order=order # create some invalid indices to test the mode
) if mode in ("wrap", "clip"):
# create some invalid indices to test the mode multi_index = (multi_index[0] - 1, *multi_index[1:])
if mode in ("wrap", "clip"): # test with scalars and higher-dimensional indices
multi_index = (multi_index[0] - 1, *multi_index[1:]) if index_ndim == 0:
# test with scalars and higher-dimensional indices multi_index = tuple(i[-1] for i in multi_index)
if index_ndim == 0: elif index_ndim == 2:
multi_index = tuple(i[-1] for i in multi_index) multi_index = tuple(i[:, np.newaxis] for i in multi_index)
elif index_ndim == 2:
multi_index = tuple(i[:, np.newaxis] for i in multi_index) return multi_index
def test_eval(self):
def check_eval(shape, index_ndim, mode, order):
multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
multi_index_symb = [pytensor.shared(i) for i in multi_index] multi_index_symb = [pytensor.shared(i) for i in multi_index]
shape_symb = pytensor.shared(np.array(shape))
out_symb = ravel_multi_index(multi_index_symb, shape_symb, mode, order)
# reference result res = out_symb.eval()
ref = np.ravel_multi_index(multi_index, shape, mode, order) ref = np.ravel_multi_index(multi_index, shape, mode, order)
np.testing.assert_equal(ref, res)
def fn(mi, s): for mode in ("raise", "wrap", "clip"):
return function([], ravel_multi_index(mi, s, mode, order)) for order in ("C", "F"):
for index_ndim in (0, 1, 2):
# shape given as a tuple check_eval((3,), index_ndim, mode, order)
f_array_tuple = fn(multi_index, shape) check_eval((3, 4), index_ndim, mode, order)
f_symb_tuple = fn(multi_index_symb, shape) check_eval((3, 4, 5), index_ndim, mode, order)
np.testing.assert_equal(ref, f_array_tuple())
np.testing.assert_equal(ref, f_symb_tuple())
# shape given as an array
shape_array = np.array(shape)
f_array_array = fn(multi_index, shape_array)
np.testing.assert_equal(ref, f_array_array())
# shape given as an PyTensor variable def test_shape(self):
shape_symb = pytensor.shared(shape_array) def check_shape(shape, index_ndim, mode, order):
f_array_symb = fn(multi_index, shape_symb) multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
np.testing.assert_equal(ref, f_array_symb())
# shape testing shape_symb = pytensor.shared(np.array(shape))
self._compile_and_check( self._compile_and_check(
[], [],
[ravel_multi_index(multi_index, shape_symb, mode, order)], [ravel_multi_index(multi_index, shape_symb, mode, order)],
...@@ -1065,13 +1064,28 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -1065,13 +1064,28 @@ class TestRavelMultiIndex(utt.InferShapeTester):
RavelMultiIndex, RavelMultiIndex,
) )
for mode in ("raise", "wrap", "clip"): for index_ndim in (0, 1, 2):
for order in ("C", "F"): check_shape((3,), index_ndim, "raise", "C")
for index_ndim in (0, 1, 2): check_shape((3, 4), index_ndim, "raise", "C")
check((3,), index_ndim, mode, order) check_shape((3, 4, 5), index_ndim, "raise", "C")
check((3, 4), index_ndim, mode, order)
check((3, 4, 5), index_ndim, mode, order) def test_constant_inputs(self):
shape = (3,)
multi_index = np.unravel_index(np.arange(np.prod(shape)), shape)
ref = np.ravel_multi_index(multi_index, shape)
# shape given as a tuple
f_tuple_shape = ravel_multi_index(multi_index, shape).eval(mode="FAST_COMPILE")
np.testing.assert_equal(ref, f_tuple_shape)
# shape given as an array
shape_array = np.array(shape)
f_array_shape = ravel_multi_index(multi_index, shape_array).eval(
mode="FAST_COMPILE"
)
np.testing.assert_equal(ref, f_array_shape)
def test_errors(self):
# must provide integers # must provide integers
with pytest.raises(TypeError): with pytest.raises(TypeError):
ravel_multi_index((fvector(), ivector()), (3, 4)) ravel_multi_index((fvector(), ivector()), (3, 4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论