提交 f9f2080e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement `tri` symbolically

上级 46f89676
......@@ -17,16 +17,15 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
Tri,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
to be constants. The graph that you defined thus cannot be JIT-compiled
by JAX. An example of a graph that can be compiled to JAX:
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` to be constants.
The graph that you defined thus cannot be JIT-compiled by JAX.
An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor as pt
>>> pt.arange(1, 10, 2)
"""
......@@ -185,19 +184,3 @@ def jax_funcify_ScalarFromTensor(op, **kwargs):
return jnp.array(x).flatten()[0]
return scalar_from_tensor
@jax_funcify.register(Tri)
def jax_funcify_Tri(op, node, **kwargs):
# node.inputs is N, M, k
const_args = [getattr(x, "data", None) for x in node.inputs]
def tri(*args):
# args is N, M, k
args = [
x if const_x is None else const_x
for x, const_x in zip(args, const_args, strict=True)
]
return jnp.tri(*args, dtype=op.dtype)
return tri
......@@ -13,7 +13,6 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
Tri,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
......@@ -219,23 +218,6 @@ def mlx_funcify_ScalarFromTensor(op, **kwargs):
return scalar_from_tensor
@mlx_funcify.register(Tri)
def mlx_funcify_Tri(op, node, **kwargs):
# node.inputs -> N, M, k
const_args = [getattr(inp, "data", None) for inp in node.inputs]
dtype = convert_dtype_to_mlx(op.dtype)
def tri(*args):
# Replace args with compile-time constants when available
args = [
arg if const_a is None else const_a
for arg, const_a in zip(args, const_args, strict=True)
]
return mx.tri(*args, dtype=dtype)
return tri
@mlx_funcify.register(AllocEmpty)
def mlx_funcify_AllocEmpty(op, node, **kwargs):
dtype = convert_dtype_to_mlx(op.dtype)
......
......@@ -1088,39 +1088,6 @@ def nonzero_values(a):
return _a.flatten()[flatnonzero(_a)]
class Tri(Op):
__props__ = ("dtype",)
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
else:
dtype = np.dtype(dtype).name
self.dtype = dtype
def make_node(self, N, M, k):
N = as_tensor_variable(N)
M = as_tensor_variable(M)
k = as_tensor_variable(k)
return Apply(
self,
[N, M, k],
[TensorType(dtype=self.dtype, shape=(None, None))()],
)
def perform(self, node, inp, out_):
N, M, k = inp
(out,) = out_
out[0] = np.tri(N, M, k, dtype=self.dtype)
def infer_shape(self, fgraph, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
return [out_shape]
def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in range(3)]
def tri(N, M=None, k=0, dtype=None):
"""
An array with ones at and below the given diagonal and zeros elsewhere.
......@@ -1148,10 +1115,12 @@ def tri(N, M=None, k=0, dtype=None):
"""
if dtype is None:
dtype = config.floatX
if M is None:
M = N
op = Tri(dtype)
return op(N, M, k)
# Implementation adapted from https://github.com/numpy/numpy/blob/2f7fe64b8b6d7591dd208942f1cc74473d5db4cb/numpy/lib/_twodim_base_impl.py#L421-L433
m = arange(N)[:, None] >= arange(-k, M - k)[None, :]
return m.astype(dtype)
def tril(m, k=0):
......
import re
import numpy as np
import pytest
......@@ -210,10 +212,6 @@ def test_tri():
compare_jax_and_py([], [out], [])
@pytest.mark.skipif(
jax.__version__ == "0.4.31",
reason="https://github.com/google/jax/issues/22751",
)
def test_tri_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
......@@ -228,7 +226,10 @@ def test_tri_nonconcrete():
out = ptb.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass
with pytest.raises((AttributeError, TypeError)):
with pytest.raises(
NotImplementedError,
match=re.escape(
"JAX requires the arguments of `jax.numpy.arange` to be constants"
),
):
compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value])
......@@ -35,7 +35,6 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
Tri,
alloc,
alloc_diag,
arange,
......@@ -972,22 +971,17 @@ class TestEye:
class TestTriangle:
def test_tri(self):
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
def check(dtype, N, M=None, k=0):
if M is None:
M = N
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
N_symb = iscalar("N")
M_symb = iscalar("M")
k_symb = iscalar("k")
f = function(
[N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype)
)
result = f(N, M, k)
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype))
assert np.allclose(result, np.tri(N, M, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
......@@ -3889,22 +3883,6 @@ class TestInferShape(utt.InferShapeTester):
[aiscal, biscal, ciscal], [Eye()(aiscal, biscal, ciscal)], [3, 5, 0], Eye
)
def test_Tri(self):
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 4, 0], Tri
)
self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [4, 5, 0], Tri
)
self._compile_and_check(
[aiscal, biscal, ciscal], [Tri()(aiscal, biscal, ciscal)], [3, 5, 0], Tri
)
def test_ExtractDiag(self):
atens3 = tensor3()
atens3_val = random(4, 5, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论