提交 43d91d0a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix JAX dispatch for multi-output Composite

上级 21d723bd
...@@ -63,11 +63,18 @@ def jax_funcify_Clip(op, **kwargs): ...@@ -63,11 +63,18 @@ def jax_funcify_Clip(op, **kwargs):
@jax_funcify.register(Composite) @jax_funcify.register(Composite)
def jax_funcify_Composite(op, vectorize=True, **kwargs): def jax_funcify_Composite(op, node, vectorize=True, **kwargs):
jax_impl = jax_funcify(op.fgraph) jax_impl = jax_funcify(op.fgraph)
def composite(*args): if len(node.outputs) == 1:
return jax_impl(*args)[0]
def composite(*args):
return jax_impl(*args)[0]
else:
def composite(*args):
return jax_impl(*args)
return jnp.vectorize(composite) return jnp.vectorize(composite)
......
...@@ -63,7 +63,7 @@ def test_identity(): ...@@ -63,7 +63,7 @@ def test_identity():
), ),
], ],
) )
def test_jax_Composite(x, y, x_val, y_val): def test_jax_Composite_singe_output(x, y, x_val, y_val):
x_s = aes.float64("x") x_s = aes.float64("x")
y_s = aes.float64("y") y_s = aes.float64("y")
...@@ -80,6 +80,16 @@ def test_jax_Composite(x, y, x_val, y_val): ...@@ -80,6 +80,16 @@ def test_jax_Composite(x, y, x_val, y_val):
_ = compare_jax_and_py(out_fg, test_input_vals) _ = compare_jax_and_py(out_fg, test_input_vals)
def test_jax_Composite_multi_output():
x = vector("x")
x_s = aes.float64("xs")
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
fgraph = FunctionGraph([x], outs)
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
def test_erf(): def test_erf():
x = scalar("x") x = scalar("x")
out = erf(x) out = erf(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论