提交 8112576b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Stop using deprecated `numpy.product`

上级 22e9233e
...@@ -1895,7 +1895,7 @@ class Mul(ScalarOp): ...@@ -1895,7 +1895,7 @@ class Mul(ScalarOp):
nfunc_variadic = "product" nfunc_variadic = "product"
def impl(self, *inputs): def impl(self, *inputs):
return np.product(inputs) return np.prod(inputs)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(z,) = outputs (z,) = outputs
......
...@@ -55,7 +55,7 @@ def test_extra_ops(): ...@@ -55,7 +55,7 @@ def test_extra_ops():
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
indices = np.arange(np.product((3, 4))) indices = np.arange(np.prod((3, 4)))
out = at_extra_ops.unravel_index(indices, (3, 4), order="C") out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out) fgraph = FunctionGraph([], out)
compare_jax_and_py( compare_jax_and_py(
...@@ -100,7 +100,7 @@ def test_extra_ops_omni(): ...@@ -100,7 +100,7 @@ def test_extra_ops_omni():
fgraph = FunctionGraph([], [out]) fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
out = at_extra_ops.ravel_multi_index(multi_index, (3, 4)) out = at_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out]) fgraph = FunctionGraph([], [out])
compare_jax_and_py( compare_jax_and_py(
......
...@@ -925,7 +925,7 @@ class TestUnique(utt.InferShapeTester): ...@@ -925,7 +925,7 @@ class TestUnique(utt.InferShapeTester):
class TestUnravelIndex(utt.InferShapeTester): class TestUnravelIndex(utt.InferShapeTester):
def test_unravel_index(self): def test_unravel_index(self):
def check(shape, index_ndim, order): def check(shape, index_ndim, order):
indices = np.arange(np.product(shape)) indices = np.arange(np.prod(shape))
# test with scalars and higher-dimensional indices # test with scalars and higher-dimensional indices
if index_ndim == 0: if index_ndim == 0:
indices = indices[-1] indices = indices[-1]
...@@ -996,7 +996,7 @@ class TestRavelMultiIndex(utt.InferShapeTester): ...@@ -996,7 +996,7 @@ class TestRavelMultiIndex(utt.InferShapeTester):
def test_ravel_multi_index(self): def test_ravel_multi_index(self):
def check(shape, index_ndim, mode, order): def check(shape, index_ndim, mode, order):
multi_index = np.unravel_index( multi_index = np.unravel_index(
np.arange(np.product(shape)), shape, order=order np.arange(np.prod(shape)), shape, order=order
) )
# create some invalid indices to test the mode # create some invalid indices to test the mode
if mode in ("wrap", "clip"): if mode in ("wrap", "clip"):
......
...@@ -1151,7 +1151,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1151,7 +1151,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
for inplace in (False, True): for inplace in (False, True):
for data_shape in ((10,), (4, 5), (1, 2, 3), (4, 5, 6, 7)): for data_shape in ((10,), (4, 5), (1, 2, 3), (4, 5, 6, 7)):
data_n_dims = len(data_shape) data_n_dims = len(data_shape)
data_size = np.product(data_shape) data_size = np.prod(data_shape)
# Corresponding numeric variable. # Corresponding numeric variable.
data_num_init = np.arange(data_size, dtype=self.dtype) data_num_init = np.arange(data_size, dtype=self.dtype)
data_num_init = data_num_init.reshape(data_shape) data_num_init = data_num_init.reshape(data_shape)
...@@ -1203,7 +1203,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1203,7 +1203,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
# The param dtype is needed when inc_shape is empty. # The param dtype is needed when inc_shape is empty.
# By default, it would return a float and rng.uniform # By default, it would return a float and rng.uniform
# with NumPy 1.10 will raise a Deprecation warning. # with NumPy 1.10 will raise a Deprecation warning.
inc_size = np.product(inc_shape, dtype="int") inc_size = np.prod(inc_shape, dtype="int")
# Corresponding numeric variable. # Corresponding numeric variable.
inc_num = rng.uniform(size=inc_size).astype(self.dtype) inc_num = rng.uniform(size=inc_size).astype(self.dtype)
inc_num = inc_num.reshape(inc_shape) inc_num = inc_num.reshape(inc_shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论