提交 f9f2080e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement `tri` symbolically

上级 46f89676
...@@ -17,16 +17,15 @@ from pytensor.tensor.basic import ( ...@@ -17,16 +17,15 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri,
get_scalar_constant_value, get_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` to be constants.
to be constants. The graph that you defined thus cannot be JIT-compiled The graph that you defined thus cannot be JIT-compiled by JAX.
by JAX. An example of a graph that can be compiled to JAX: An example of a graph that can be compiled to JAX:
>>> import pytensor.tensor as pt >>> import pytensor.tensor as pt
>>> pt.arange(1, 10, 2) >>> pt.arange(1, 10, 2)
""" """
...@@ -185,19 +184,3 @@ def jax_funcify_ScalarFromTensor(op, **kwargs): ...@@ -185,19 +184,3 @@ def jax_funcify_ScalarFromTensor(op, **kwargs):
return jnp.array(x).flatten()[0] return jnp.array(x).flatten()[0]
return scalar_from_tensor return scalar_from_tensor
@jax_funcify.register(Tri)
def jax_funcify_Tri(op, node, **kwargs):
    """Build the JAX implementation of the ``Tri`` op.

    ``jax.numpy.tri`` needs concrete (non-traced) shape arguments, so any
    op input that is a compile-time constant (exposes a ``.data`` attribute)
    is baked into the returned callable instead of being read at call time.
    """
    # One entry per op input (N, M, k): the constant value, or None if symbolic.
    baked_constants = [getattr(inp, "data", None) for inp in node.inputs]
    out_dtype = op.dtype

    def tri(*runtime_args):
        # Prefer a baked-in constant over the traced runtime value.
        resolved = []
        for runtime_val, const_val in zip(
            runtime_args, baked_constants, strict=True
        ):
            resolved.append(runtime_val if const_val is None else const_val)
        return jnp.tri(*resolved, dtype=out_dtype)

    return tri
...@@ -13,7 +13,6 @@ from pytensor.tensor.basic import ( ...@@ -13,7 +13,6 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri,
get_scalar_constant_value, get_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -219,23 +218,6 @@ def mlx_funcify_ScalarFromTensor(op, **kwargs): ...@@ -219,23 +218,6 @@ def mlx_funcify_ScalarFromTensor(op, **kwargs):
return scalar_from_tensor return scalar_from_tensor
@mlx_funcify.register(Tri)
def mlx_funcify_Tri(op, node, **kwargs):
    """Build the MLX implementation of the ``Tri`` op.

    Inputs that are compile-time constants (those exposing a ``.data``
    attribute) are captured up front and substituted for the corresponding
    runtime arguments when the returned callable runs.
    """
    # node.inputs -> N, M, k; None marks an input that stays symbolic.
    baked_constants = [getattr(inp, "data", None) for inp in node.inputs]
    out_dtype = convert_dtype_to_mlx(op.dtype)

    def tri(*runtime_args):
        # Substitute compile-time constants where we have them.
        resolved = []
        for runtime_val, const_val in zip(
            runtime_args, baked_constants, strict=True
        ):
            resolved.append(runtime_val if const_val is None else const_val)
        return mx.tri(*resolved, dtype=out_dtype)

    return tri
@mlx_funcify.register(AllocEmpty) @mlx_funcify.register(AllocEmpty)
def mlx_funcify_AllocEmpty(op, node, **kwargs): def mlx_funcify_AllocEmpty(op, node, **kwargs):
dtype = convert_dtype_to_mlx(op.dtype) dtype = convert_dtype_to_mlx(op.dtype)
......
...@@ -1088,39 +1088,6 @@ def nonzero_values(a): ...@@ -1088,39 +1088,6 @@ def nonzero_values(a):
return _a.flatten()[flatnonzero(_a)] return _a.flatten()[flatnonzero(_a)]
class Tri(Op):
    """Op producing an (N, M) array with ones at and below diagonal ``k``.

    Thin symbolic wrapper around :func:`numpy.tri`; the output dtype is
    fixed at construction time.
    """

    __props__ = ("dtype",)

    def __init__(self, dtype=None):
        # Default to the configured float type; otherwise normalize the
        # requested dtype to its canonical numpy name.
        self.dtype = (
            config.floatX if dtype is None else np.dtype(dtype).name
        )

    def make_node(self, N, M, k):
        """Create an Apply node from scalar inputs N (rows), M (cols), k (diagonal)."""
        tensor_inputs = [as_tensor_variable(v) for v in (N, M, k)]
        # Output is a matrix whose static shape is unknown (depends on N, M).
        output = TensorType(dtype=self.dtype, shape=(None, None))()
        return Apply(self, tensor_inputs, [output])

    def perform(self, node, inp, out_):
        """Compute the concrete result with numpy."""
        N, M, k = inp
        (out,) = out_
        out[0] = np.tri(N, M, k, dtype=self.dtype)

    def infer_shape(self, fgraph, node, in_shapes):
        # The output shape is exactly (N, M) — the first two inputs.
        return [list(node.inputs[:2])]

    def grad(self, inp, grads):
        # All three inputs are integer shape/diagonal parameters, so the
        # gradient is undefined with respect to each of them.
        return [grad_undefined(self, pos, inp[pos]) for pos in range(3)]
def tri(N, M=None, k=0, dtype=None):
    """
    An array with ones at and below the given diagonal and zeros elsewhere.

    Parameters
    ----------
    N : int or scalar tensor
        Number of rows in the array.
    M : int or scalar tensor, optional
        Number of columns in the array. Defaults to ``N``.
    k : int or scalar tensor, optional
        The sub-diagonal at and below which the array is filled.
        ``k = 0`` is the main diagonal, a negative value is below it,
        and a positive value is above. Defaults to 0.
    dtype : dtype, optional
        Data type of the returned array. Defaults to ``config.floatX``.

    Returns
    -------
    TensorVariable
        Array of shape (N, M) whose lower triangle (relative to ``k``)
        is filled with ones, built symbolically (no dedicated Op needed).
    """
    if dtype is None:
        dtype = config.floatX
    if M is None:
        M = N
    # Implementation adapted from
    # https://github.com/numpy/numpy/blob/2f7fe64b8b6d7591dd208942f1cc74473d5db4cb/numpy/lib/_twodim_base_impl.py#L421-L433
    # Row index i contributes a 1 in column j whenever i >= j - k.
    m = arange(N)[:, None] >= arange(-k, M - k)[None, :]
    return m.astype(dtype)
def tril(m, k=0): def tril(m, k=0):
......
import re
import numpy as np import numpy as np
import pytest import pytest
...@@ -210,10 +212,6 @@ def test_tri(): ...@@ -210,10 +212,6 @@ def test_tri():
compare_jax_and_py([], [out], []) compare_jax_and_py([], [out], [])
@pytest.mark.skipif(
jax.__version__ == "0.4.31",
reason="https://github.com/google/jax/issues/22751",
)
def test_tri_nonconcrete(): def test_tri_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values.""" """JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
...@@ -228,7 +226,10 @@ def test_tri_nonconcrete(): ...@@ -228,7 +226,10 @@ def test_tri_nonconcrete():
out = ptb.tri(m, n, k) out = ptb.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but with pytest.raises(
# the error handler raises an Attribute error first, so that's what this test needs to pass NotImplementedError,
with pytest.raises((AttributeError, TypeError)): match=re.escape(
"JAX requires the arguments of `jax.numpy.arange` to be constants"
),
):
compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value]) compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value])
...@@ -35,7 +35,6 @@ from pytensor.tensor.basic import ( ...@@ -35,7 +35,6 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri,
alloc, alloc,
alloc_diag, alloc_diag,
arange, arange,
...@@ -972,22 +971,17 @@ class TestEye: ...@@ -972,22 +971,17 @@ class TestEye:
class TestTriangle: class TestTriangle:
def test_tri(self): def test_tri(self):
def check(dtype, N, M_=None, k=0): def check(dtype, N, M=None, k=0):
# PyTensor does not accept None as a tensor. if M is None:
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N M = N
N_symb = iscalar() N_symb = iscalar("N")
M_symb = iscalar() M_symb = iscalar("M")
k_symb = iscalar() k_symb = iscalar("k")
f = function( f = function(
[N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype) [N_symb, M_symb, k_symb], tri(N_symb, M_symb, k_symb, dtype=dtype)
) )
result = f(N, M, k) result = f(N, M, k)
assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert np.allclose(result, np.tri(N, M, k, dtype=dtype))
assert result.dtype == np.dtype(dtype) assert result.dtype == np.dtype(dtype)
for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
...@@ -3889,22 +3883,6 @@ class TestInferShape(utt.InferShapeTester): ...@@ -3889,22 +3883,6 @@ class TestInferShape(utt.InferShapeTester):
[aiscal, biscal, ciscal], [Eye()(aiscal, biscal, ciscal)], [3, 5, 0], Eye [aiscal, biscal, ciscal], [Eye()(aiscal, biscal, ciscal)], [3, 5, 0], Eye
) )
def test_Tri(self):
    """Check shape inference of the ``Tri`` op for square and rectangular cases."""
    rows_sym = iscalar()
    cols_sym = iscalar()
    diag_sym = iscalar()
    symbolic_inputs = [rows_sym, cols_sym, diag_sym]
    # Square, wide, and fewer-rows-than-columns configurations.
    for concrete_values in ([4, 4, 0], [4, 5, 0], [3, 5, 0]):
        self._compile_and_check(
            symbolic_inputs,
            [Tri()(rows_sym, cols_sym, diag_sym)],
            concrete_values,
            Tri,
        )
def test_ExtractDiag(self): def test_ExtractDiag(self):
atens3 = tensor3() atens3 = tensor3()
atens3_val = random(4, 5, 3) atens3_val = random(4, 5, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论