提交 482e6cc2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup TestRavelMultiIndex

上级 f485be15
......@@ -1020,44 +1020,43 @@ class TestUnravelIndex(utt.InferShapeTester):
class TestRavelMultiIndex(utt.InferShapeTester):
def test_ravel_multi_index(self):
def check(shape, index_ndim, mode, order):
multi_index = np.unravel_index(
np.arange(np.prod(shape)), shape, order=order
)
# create some invalid indices to test the mode
if mode in ("wrap", "clip"):
multi_index = (multi_index[0] - 1, *multi_index[1:])
# test with scalars and higher-dimensional indices
if index_ndim == 0:
multi_index = tuple(i[-1] for i in multi_index)
elif index_ndim == 2:
multi_index = tuple(i[:, np.newaxis] for i in multi_index)
@staticmethod
def get_test_multiindex(shape, index_ndim, mode, order):
multi_index = np.unravel_index(np.arange(np.prod(shape)), shape, order=order)
# create some invalid indices to test the mode
if mode in ("wrap", "clip"):
multi_index = (multi_index[0] - 1, *multi_index[1:])
# test with scalars and higher-dimensional indices
if index_ndim == 0:
multi_index = tuple(i[-1] for i in multi_index)
elif index_ndim == 2:
multi_index = tuple(i[:, np.newaxis] for i in multi_index)
return multi_index
def test_eval(self):
def check_eval(shape, index_ndim, mode, order):
multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
multi_index_symb = [pytensor.shared(i) for i in multi_index]
shape_symb = pytensor.shared(np.array(shape))
out_symb = ravel_multi_index(multi_index_symb, shape_symb, mode, order)
# reference result
res = out_symb.eval()
ref = np.ravel_multi_index(multi_index, shape, mode, order)
np.testing.assert_equal(ref, res)
def fn(mi, s):
return function([], ravel_multi_index(mi, s, mode, order))
# shape given as a tuple
f_array_tuple = fn(multi_index, shape)
f_symb_tuple = fn(multi_index_symb, shape)
np.testing.assert_equal(ref, f_array_tuple())
np.testing.assert_equal(ref, f_symb_tuple())
# shape given as an array
shape_array = np.array(shape)
f_array_array = fn(multi_index, shape_array)
np.testing.assert_equal(ref, f_array_array())
for mode in ("raise", "wrap", "clip"):
for order in ("C", "F"):
for index_ndim in (0, 1, 2):
check_eval((3,), index_ndim, mode, order)
check_eval((3, 4), index_ndim, mode, order)
check_eval((3, 4, 5), index_ndim, mode, order)
# shape given as an PyTensor variable
shape_symb = pytensor.shared(shape_array)
f_array_symb = fn(multi_index, shape_symb)
np.testing.assert_equal(ref, f_array_symb())
def test_shape(self):
def check_shape(shape, index_ndim, mode, order):
multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
# shape testing
shape_symb = pytensor.shared(np.array(shape))
self._compile_and_check(
[],
[ravel_multi_index(multi_index, shape_symb, mode, order)],
......@@ -1065,13 +1064,28 @@ class TestRavelMultiIndex(utt.InferShapeTester):
RavelMultiIndex,
)
for mode in ("raise", "wrap", "clip"):
for order in ("C", "F"):
for index_ndim in (0, 1, 2):
check((3,), index_ndim, mode, order)
check((3, 4), index_ndim, mode, order)
check((3, 4, 5), index_ndim, mode, order)
for index_ndim in (0, 1, 2):
check_shape((3,), index_ndim, "raise", "C")
check_shape((3, 4), index_ndim, "raise", "C")
check_shape((3, 4, 5), index_ndim, "raise", "C")
def test_constant_inputs(self):
shape = (3,)
multi_index = np.unravel_index(np.arange(np.prod(shape)), shape)
ref = np.ravel_multi_index(multi_index, shape)
# shape given as a tuple
f_tuple_shape = ravel_multi_index(multi_index, shape).eval(mode="FAST_COMPILE")
np.testing.assert_equal(ref, f_tuple_shape)
# shape given as an array
shape_array = np.array(shape)
f_array_shape = ravel_multi_index(multi_index, shape_array).eval(
mode="FAST_COMPILE"
)
np.testing.assert_equal(ref, f_array_shape)
def test_errors(self):
# must provide integers
with pytest.raises(TypeError):
ravel_multi_index((fvector(), ivector()), (3, 4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论