Unverified 提交 55ef0e6b authored 作者: George Ho's avatar George Ho 提交者: GitHub

Use jax.lax.cond for IfElse Op (#187)

上级 201e3845
...@@ -572,10 +572,12 @@ def test_jax_ifelse(): ...@@ -572,10 +572,12 @@ def test_jax_ifelse():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals) a = tt.dscalar("a")
x_fg = theano.gof.FunctionGraph([], [x]) a.tag.test_value = np.array(0.2, dtype=theano.config.floatX)
x = theano.ifelse.ifelse(a < 0.5, true_vals, false_vals)
x_fg = theano.gof.FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
def test_jax_CAReduce(): def test_jax_CAReduce():
......
...@@ -571,11 +571,9 @@ def jax_funcify_IfElse(op): ...@@ -571,11 +571,9 @@ def jax_funcify_IfElse(op):
n_outs = op.n_outs n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs): def ifelse(cond, *args, n_outs=n_outs):
if cond: res = jax.lax.cond(
res = args[:n_outs] cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
else: )
res = args[n_outs:]
return res if n_outs > 1 else res[0] return res if n_outs > 1 else res[0]
return ifelse return ifelse
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论