提交 e8bd0d7d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in JAX implementation of Second

上级 5e6b356d
......@@ -181,7 +181,8 @@ def jax_funcify_Composite(op, node, vectorize=True, **kwargs):
@jax_funcify.register(Second)
def jax_funcify_Second(op, **kwargs):
def second(x, y):
return jnp.broadcast_to(y, x.shape)
_, y = jnp.broadcast_arrays(x, y)
return y
return second
......
......@@ -25,6 +25,7 @@ from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
from pytensor.link.jax.dispatch import jax_funcify
def test_second():
......@@ -40,6 +41,25 @@ def test_second():
fgraph = FunctionGraph([a1, b], [out])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
a2 = matrix("a2", shape=(1, None), dtype="float64")
b2 = matrix("b2", shape=(None, 1), dtype="int32")
out = at.second(a2, b2)
fgraph = FunctionGraph([a2, b2], [out])
compare_jax_and_py(
fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")]
)
def test_second_constant_scalar():
b = scalar("b", dtype="int")
out = at.second(0.0, b)
fgraph = FunctionGraph([b], [out])
# Test dispatch directly as useless second is removed during compilation
fn = jax_funcify(fgraph)
[res] = fn(1)
assert res == 1
assert res.dtype == out.dtype
def test_identity():
a = scalar("a")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论