提交 482e6cc2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup TestRavelMultiIndex

上级 f485be15
...@@ -1020,11 +1020,9 @@ class TestUnravelIndex(utt.InferShapeTester): ...@@ -1020,11 +1020,9 @@ class TestUnravelIndex(utt.InferShapeTester):
class TestRavelMultiIndex(utt.InferShapeTester): class TestRavelMultiIndex(utt.InferShapeTester):
def test_ravel_multi_index(self): @staticmethod
def check(shape, index_ndim, mode, order): def get_test_multiindex(shape, index_ndim, mode, order):
multi_index = np.unravel_index( multi_index = np.unravel_index(np.arange(np.prod(shape)), shape, order=order)
np.arange(np.prod(shape)), shape, order=order
)
# create some invalid indices to test the mode # create some invalid indices to test the mode
if mode in ("wrap", "clip"): if mode in ("wrap", "clip"):
multi_index = (multi_index[0] - 1, *multi_index[1:]) multi_index = (multi_index[0] - 1, *multi_index[1:])
...@@ -1033,31 +1031,32 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -1033,31 +1031,32 @@ class TestRavelMultiIndex(utt.InferShapeTester):
multi_index = tuple(i[-1] for i in multi_index) multi_index = tuple(i[-1] for i in multi_index)
elif index_ndim == 2: elif index_ndim == 2:
multi_index = tuple(i[:, np.newaxis] for i in multi_index) multi_index = tuple(i[:, np.newaxis] for i in multi_index)
multi_index_symb = [pytensor.shared(i) for i in multi_index]
# reference result return multi_index
ref = np.ravel_multi_index(multi_index, shape, mode, order)
def fn(mi, s): def test_eval(self):
return function([], ravel_multi_index(mi, s, mode, order)) def check_eval(shape, index_ndim, mode, order):
multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
multi_index_symb = [pytensor.shared(i) for i in multi_index]
shape_symb = pytensor.shared(np.array(shape))
out_symb = ravel_multi_index(multi_index_symb, shape_symb, mode, order)
# shape given as a tuple res = out_symb.eval()
f_array_tuple = fn(multi_index, shape) ref = np.ravel_multi_index(multi_index, shape, mode, order)
f_symb_tuple = fn(multi_index_symb, shape) np.testing.assert_equal(ref, res)
np.testing.assert_equal(ref, f_array_tuple())
np.testing.assert_equal(ref, f_symb_tuple())
# shape given as an array for mode in ("raise", "wrap", "clip"):
shape_array = np.array(shape) for order in ("C", "F"):
f_array_array = fn(multi_index, shape_array) for index_ndim in (0, 1, 2):
np.testing.assert_equal(ref, f_array_array()) check_eval((3,), index_ndim, mode, order)
check_eval((3, 4), index_ndim, mode, order)
check_eval((3, 4, 5), index_ndim, mode, order)
# shape given as an PyTensor variable def test_shape(self):
shape_symb = pytensor.shared(shape_array) def check_shape(shape, index_ndim, mode, order):
f_array_symb = fn(multi_index, shape_symb) multi_index = self.get_test_multiindex(shape, index_ndim, mode, order)
np.testing.assert_equal(ref, f_array_symb())
# shape testing shape_symb = pytensor.shared(np.array(shape))
self._compile_and_check( self._compile_and_check(
[], [],
[ravel_multi_index(multi_index, shape_symb, mode, order)], [ravel_multi_index(multi_index, shape_symb, mode, order)],
...@@ -1065,13 +1064,28 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -1065,13 +1064,28 @@ class TestRavelMultiIndex(utt.InferShapeTester):
RavelMultiIndex, RavelMultiIndex,
) )
for mode in ("raise", "wrap", "clip"):
for order in ("C", "F"):
for index_ndim in (0, 1, 2): for index_ndim in (0, 1, 2):
check((3,), index_ndim, mode, order) check_shape((3,), index_ndim, "raise", "C")
check((3, 4), index_ndim, mode, order) check_shape((3, 4), index_ndim, "raise", "C")
check((3, 4, 5), index_ndim, mode, order) check_shape((3, 4, 5), index_ndim, "raise", "C")
def test_constant_inputs(self):
shape = (3,)
multi_index = np.unravel_index(np.arange(np.prod(shape)), shape)
ref = np.ravel_multi_index(multi_index, shape)
# shape given as a tuple
f_tuple_shape = ravel_multi_index(multi_index, shape).eval(mode="FAST_COMPILE")
np.testing.assert_equal(ref, f_tuple_shape)
# shape given as an array
shape_array = np.array(shape)
f_array_shape = ravel_multi_index(multi_index, shape_array).eval(
mode="FAST_COMPILE"
)
np.testing.assert_equal(ref, f_array_shape)
def test_errors(self):
# must provide integers # must provide integers
with pytest.raises(TypeError): with pytest.raises(TypeError):
ravel_multi_index((fvector(), ivector()), (3, 4)) ravel_multi_index((fvector(), ivector()), (3, 4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论