Unverified 提交 b9792d8a authored 作者: jessegrabowski's avatar jessegrabowski 提交者: GitHub

Add JAX support for `pt.tri` (#302)

上级 6b43b433
...@@ -18,6 +18,7 @@ from pytensor.tensor.basic import ( ...@@ -18,6 +18,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -26,7 +27,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError ...@@ -26,7 +27,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
to be constants. The graph that you defined thus cannot be JIT-compiled to be constants. The graph that you defined thus cannot be JIT-compiled
by JAX. An example of a graph that can be compiled to JAX: by JAX. An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor basic >>> import pytensor.tensor basic
>>> at.arange(1, 10, 2) >>> at.arange(1, 10, 2)
""" """
...@@ -193,3 +193,18 @@ def jax_funcify_ScalarFromTensor(op, **kwargs): ...@@ -193,3 +193,18 @@ def jax_funcify_ScalarFromTensor(op, **kwargs):
return jnp.array(x).flatten()[0] return jnp.array(x).flatten()[0]
return scalar_from_tensor return scalar_from_tensor
@jax_funcify.register(Tri)
def jax_funcify_Tri(op, node, **kwargs):
# node.inputs is N, M, k
const_args = [getattr(x, "data", None) for x in node.inputs]
def tri(*args):
# args is N, M, k
args = [
x if const_x is None else const_x for x, const_x in zip(args, const_args)
]
return jnp.tri(*args, dtype=op.dtype)
return tri
...@@ -191,3 +191,30 @@ def test_jax_eye(): ...@@ -191,3 +191,30 @@ def test_jax_eye():
out_fg = FunctionGraph([], [out]) out_fg = FunctionGraph([], [out])
compare_jax_and_py(out_fg, []) compare_jax_and_py(out_fg, [])
def test_tri():
out = at.tri(10, 10, 0)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
def test_tri_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
m, n, k = (
scalar("a", dtype="int64"),
scalar("n", dtype="int64"),
scalar("k", dtype="int64"),
)
m.tag.test_value = 10
n.tag.test_value = 10
k.tag.test_value = 0
out = at.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass
with pytest.raises(AttributeError):
fgraph = FunctionGraph([m, n, k], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论