提交 43d91d0a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix JAX dispatch for multi-output Composite

上级 21d723bd
......@@ -63,11 +63,18 @@ def jax_funcify_Clip(op, **kwargs):
@jax_funcify.register(Composite)
def jax_funcify_Composite(op, vectorize=True, **kwargs):
def jax_funcify_Composite(op, node, vectorize=True, **kwargs):
jax_impl = jax_funcify(op.fgraph)
def composite(*args):
return jax_impl(*args)[0]
if len(node.outputs) == 1:
def composite(*args):
return jax_impl(*args)[0]
else:
def composite(*args):
return jax_impl(*args)
return jnp.vectorize(composite)
......
......@@ -63,7 +63,7 @@ def test_identity():
),
],
)
def test_jax_Composite(x, y, x_val, y_val):
def test_jax_Composite_singe_output(x, y, x_val, y_val):
x_s = aes.float64("x")
y_s = aes.float64("y")
......@@ -80,6 +80,16 @@ def test_jax_Composite(x, y, x_val, y_val):
_ = compare_jax_and_py(out_fg, test_input_vals)
def test_jax_Composite_multi_output():
x = vector("x")
x_s = aes.float64("xs")
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
fgraph = FunctionGraph([x], outs)
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
def test_erf():
x = scalar("x")
out = erf(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论