Unverified 提交 55ef0e6b authored 作者: George Ho's avatar George Ho 提交者: GitHub

Use jax.lax.cond for IfElse Op (#187)

上级 201e3845
......@@ -572,10 +572,12 @@ def test_jax_ifelse():
compare_jax_and_py(x_fg, [])
x = theano.ifelse.ifelse(np.array(False), true_vals, false_vals)
x_fg = theano.gof.FunctionGraph([], [x])
a = tt.dscalar("a")
a.tag.test_value = np.array(0.2, dtype=theano.config.floatX)
x = theano.ifelse.ifelse(a < 0.5, true_vals, false_vals)
x_fg = theano.gof.FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, [])
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
def test_jax_CAReduce():
......
......@@ -571,11 +571,9 @@ def jax_funcify_IfElse(op):
n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
res = jax.lax.cond(
cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
)
return res if n_outs > 1 else res[0]
return ifelse
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论