Commit 1f3b41c8 authored by Brandon T. Willard

Use Elemwise NumPy information to convert CAReduce nodes to JAX

Parent 80d8a0ef
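Editor's note: `nfunc_spec` is optional Theano `Op` metadata of the form `(numpy_function_name, n_inputs, n_outputs)` naming the Op's NumPy equivalent. This commit makes the JAX converter consult that metadata first, and only fall back to a generic `jax.lax.reduce` when it is absent. A minimal sketch of the fast path (the spec tuple below is hypothetical, standing in for whatever a max-reduction `Op` would carry):

import jax.numpy as jnp

# Hypothetical `nfunc_spec` for a max reduction; the converter simply looks
# the named function up on `jax.numpy` and calls it with the reduction axes.
nfunc_spec = ("max", 1, 1)
jax_op = getattr(jnp, nfunc_spec[0])
print(jax_op(jnp.arange(6.0).reshape(2, 3), axis=(1,)))  # [2. 5.]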
@@ -518,9 +518,15 @@ def test_tensor_basics():
     # optimizations are turned on; however, when using JAX mode, it should
     # leave the expression alone.
     out = y.dot(alpha * A).dot(x) + beta * y
     fgraph = theano.gof.FunctionGraph([y, x, A, alpha, beta], [out])
     _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
+    out = tt.maximum(y, x)
+    fgraph = theano.gof.FunctionGraph([y, x], [out])
+    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    out = tt.max(y)
+    fgraph = theano.gof.FunctionGraph([y], [out])
+    _ = compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
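Editor's note: the new tests cover both flavors of `max`: `tt.maximum` is an elementwise `Elemwise` op, while `tt.max` builds a reduction handled by the `CAReduce` converter in the next hunk. A standalone sketch of the behavior being checked, using NumPy as the Python-side reference in the spirit of `compare_jax_and_py`:

import numpy as np
import jax.numpy as jnp

y = np.array([1.0, 5.0, 3.0])
x = np.array([4.0, 2.0, 6.0])

# Elementwise maximum vs. max reduction, JAX checked against NumPy.
assert np.allclose(jnp.maximum(y, x), np.maximum(y, x))
assert np.allclose(jnp.max(y), np.max(y))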
@@ -528,28 +528,40 @@ def jax_funcify_FunctionGraph(fgraph):
 @jax_funcify.register(CAReduce)
 def jax_funcify_CAReduce(op):
     axis = op.axis
 
+    op_nfunc_spec = getattr(op, "nfunc_spec", None)
+    scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
+    scalar_op_name = getattr(op.scalar_op, "name", None)
+    scalar_op_identity = getattr(op.scalar_op, "identity", None)
+    acc_dtype = getattr(op, "acc_dtype", None)
+
     def careduce(x):
-        axis = op.axis
+        nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype
 
         if axis is None:
             axis = list(range(x.ndim))
 
-        to_reduce = reversed(sorted(axis))
-
-        if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
-            acc_dtype = op.acc_dtype
-        else:
+        if acc_dtype is None:
             acc_dtype = x.dtype.type
 
+        if op_nfunc_spec:
+            jax_op = getattr(jnp, op_nfunc_spec[0])
+            return jax_op(x, axis=axis).astype(acc_dtype)
+
+        # The Theano `Op` didn't tell us which NumPy equivalent to use (or
+        # there isn't one), so we use this fallback approach
+        if scalar_nfunc_spec:
+            scalar_fn_name = scalar_nfunc_spec[0]
+        elif scalar_op_name:
+            scalar_fn_name = scalar_op_name
+
+        to_reduce = reversed(sorted(axis))
+
         if to_reduce:
-            if getattr(op.scalar_op, "name", None):
-                jax_op = getattr(jax.lax, op.scalar_op.name)
-            elif getattr(op.scalar_op, "nfunc_spec", None):
-                # In this case, we need to use the `jax.lax` function (if there
-                # is one), and not the `jnp` version.
-                jax_op = getattr(jax.lax, op.scalar_op.nfunc_spec[0])
-            init_value = jnp.array(op.scalar_op.identity, dtype=acc_dtype)
+            # In this case, we need to use the `jax.lax` function (if there
+            # is one), and not the `jnp` version.
+            jax_op = getattr(jax.lax, scalar_fn_name)
+            init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
             return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)
         else:
             return x
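Editor's note: when no `nfunc_spec` is available, the fallback path builds the reduction from the scalar op's binary function (looked up on `jax.lax`) and its `identity`. A self-contained sketch showing that the two paths agree for a max over axis 0; the identity value `-inf` is assumed here for illustration:

import numpy as np
import jax
import jax.numpy as jnp

x = jnp.array(np.arange(6.0).reshape(2, 3))

# Fast path: call the `jnp` function named by the Op's `nfunc_spec`.
fast = jnp.max(x, axis=(0,))

# Fallback path: fold the scalar op's binary function over the axis,
# starting from its identity (`-inf` for max).
init_value = jnp.array(-jnp.inf, dtype=x.dtype)
slow = jax.lax.reduce(x, init_value, jax.lax.max, (0,))

assert np.allclose(fast, slow)  # both give [3. 4. 5.]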