提交 ae0d79ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not use deprecated variadic in jax dispatch of multiply

上级 948fbc7d
......@@ -1898,7 +1898,7 @@ class Mul(ScalarOp):
commutative = True
associative = True
nfunc_spec = ("multiply", 2, 1)
nfunc_variadic = "product"
nfunc_variadic = "prod"
def impl(self, *inputs):
    """Reference implementation: multiply every input together.

    Accepts any number of scalar inputs and reduces them with
    numpy's product, mirroring the variadic behavior of Mul.
    """
    combined = np.prod(inputs)
    return combined
......
......@@ -13,7 +13,7 @@ from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.type import matrix, tensor, vector, vectors
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise
......@@ -129,3 +129,11 @@ def test_logsumexp_benchmark(size, axis, benchmark):
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
def test_multiple_input_multiply():
    """Regression test: JAX dispatch of Mul must handle more than two
    inputs without relying on the deprecated variadic path."""
    x, y, z = vectors("xyz")
    graph = FunctionGraph(outputs=[at.mul(x, y, z)], clone=False)
    # One-element vectors keep the comparison cheap while still
    # exercising the three-input multiply.
    compare_jax_and_py(graph, [[1.5], [2.5], [3.5]])
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论