提交 cc8c4992 authored 作者: Adv's avatar Adv 提交者: Ricardo Vieira

Stop using FunctionGraph and tag.test_value in linker tests

上级 51ea1a0b
...@@ -7,12 +7,12 @@ import pytest ...@@ -7,12 +7,12 @@ import pytest
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value from pytensor.graph.op import Op
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
...@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None) ...@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
def compare_jax_and_py( def compare_jax_and_py(
fgraph: FunctionGraph, graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable, test_inputs: Iterable,
*,
assert_fn: Callable | None = None, assert_fn: Callable | None = None,
must_be_device_array: bool = True, must_be_device_array: bool = True,
jax_mode=jax_mode, jax_mode=jax_mode,
py_mode=py_mode, py_mode=py_mode,
): ):
"""Function to compare python graph output and jax compiled output for testing equality """Function to compare python function output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both jax and python, runs the calculation jax and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters Parameters
---------- ----------
fgraph: FunctionGraph graph_inputs:
PyTensor function Graph object Symbolic inputs to the graph
outputs:
Symbolic outputs of the graph
test_inputs: iter test_inputs: iter
Numerical inputs for testing the function graph Numerical inputs for testing the function.
assert_fn: func, opt assert_fn: func, opt
Assert function used to check for equality between python and jax. If not Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose provided uses np.testing.assert_allclose
...@@ -68,8 +71,10 @@ def compare_jax_and_py( ...@@ -68,8 +71,10 @@ def compare_jax_and_py(
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] if any(inp.owner is not None for inp in graph_inputs):
pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) raise ValueError("Inputs must be root variables")
pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
jax_res = pytensor_jax_fn(*test_inputs) jax_res = pytensor_jax_fn(*test_inputs)
if must_be_device_array: if must_be_device_array:
...@@ -78,10 +83,10 @@ def compare_jax_and_py( ...@@ -78,10 +83,10 @@ def compare_jax_and_py(
else: else:
assert isinstance(jax_res, jax.Array) assert isinstance(jax_res, jax.Array)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs) py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1: if isinstance(graph_outputs, list | tuple):
for j, p in zip(jax_res, py_res, strict=True): for j, p in zip(jax_res, py_res, strict=True):
assert_fn(j, p) assert_fn(j, p)
else: else:
...@@ -187,16 +192,14 @@ def test_jax_ifelse(): ...@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals = np.r_[-1, -2, -3] false_vals = np.r_[-1, -2, -3]
x = ifelse(np.array(True), true_vals, false_vals) x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
a = dscalar("a") a = dscalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX) a_test = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals) x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) compare_jax_and_py([a], [x], [a_test])
def test_jax_checkandraise(): def test_jax_checkandraise():
...@@ -209,11 +212,6 @@ def test_jax_checkandraise(): ...@@ -209,11 +212,6 @@ def test_jax_checkandraise():
function((p,), res, mode=jax_mode) function((p,), res, mode=jax_mode)
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_OpFromGraph(): def test_OpFromGraph():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False) ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
...@@ -221,10 +219,9 @@ def test_OpFromGraph(): ...@@ -221,10 +219,9 @@ def test_OpFromGraph():
o1, o2 = ofg_2(y, z) o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2 out = ofg_1(x, o1) + o2
out_fg = FunctionGraph([x, y, z], [out])
xv = np.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_jax_and_py(out_fg, [xv, yv, zv]) compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
...@@ -4,8 +4,6 @@ import pytest ...@@ -4,8 +4,6 @@ import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas from pytensor.tensor import blas as pt_blas
...@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_BatchedDot(): def test_jax_BatchedDot():
# tensor3 . tensor3 # tensor3 . tensor3
a = tensor3("a") a = tensor3("a")
a.tag.test_value = ( a_test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
) )
b = tensor3("b") b = tensor3("b")
b.tag.test_value = ( b_test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
) )
out = pt_blas.BatchedDot()(a, b) out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out]) compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)] inputs = [a_test_value[:-1], b_test_value]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
with pytest.raises(TypeError): with pytest.raises(TypeError):
pytensor_jax_fn(*inputs) pytensor_jax_fn(*inputs)
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul from pytensor.tensor.math import Dot, matmul
...@@ -32,8 +31,7 @@ def test_matmul(matmul_op): ...@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
out = matmul_op(a, b) out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise) assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out]) fn, _ = compare_jax_and_py([a, b], [out], test_values)
fn, _ = compare_jax_and_py(fg, test_values)
# Check we are not adding any unnecessary stuff # Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
......
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -22,8 +21,7 @@ def test_jax_einsum(): ...@@ -22,8 +21,7 @@ def test_jax_einsum():
} }
x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum(subscripts, x_pt, y_pt, z_pt) out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
fg = FunctionGraph([x_pt, y_pt, z_pt], [out]) compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
compare_jax_and_py(fg, [x, y, z])
def test_ellipsis_einsum(): def test_ellipsis_einsum():
...@@ -34,5 +32,4 @@ def test_ellipsis_einsum(): ...@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape) x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape) y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt) out = pt.einsum(subscripts, x_pt, y_pt)
fg = FunctionGraph([x_pt, y_pt], [out]) compare_jax_and_py([x_pt, y_pt], [out], [x, y])
compare_jax_and_py(fg, [x, y])
...@@ -6,8 +6,6 @@ import pytensor ...@@ -6,8 +6,6 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import elemwise as pt_elemwise from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import prod from pytensor.tensor.math import prod
...@@ -26,22 +24,22 @@ def test_jax_Dimshuffle(): ...@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
a_pt = matrix("a") a_pt = matrix("a")
x = a_pt.T x = a_pt.T
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
x = a_pt.dimshuffle([0, 1, "x"]) x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,)) x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt) x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_CAReduce(): def test_jax_CAReduce():
...@@ -49,64 +47,58 @@ def test_jax_CAReduce(): ...@@ -49,64 +47,58 @@ def test_jax_CAReduce():
a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
x = pt_sum(a_pt, axis=None) x = pt_sum(a_pt, axis=None)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
a_pt = matrix("a") a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = pt_sum(a_pt, axis=0) x = pt_sum(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_sum(a_pt, axis=1) x = pt_sum(a_pt, axis=1)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
a_pt = matrix("a") a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = prod(a_pt, axis=0) x = prod(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_all(a_pt) x = pt_all(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis): def test_softmax(axis):
x = matrix("x") x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = softmax(x, axis=axis) out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis): def test_logsoftmax(axis):
x = matrix("x") x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis) out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py([x], [out], [x_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis): def test_softmax_grad(axis):
dy = matrix("dy") dy = matrix("dy")
dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm") sm = matrix("sm")
sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm) out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)]) @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
...@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark): ...@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def test_multiple_input_multiply(): def test_multiple_input_multiply():
x, y, z = vectors("xyz") x, y, z = vectors("xyz")
out = pt.mul(x, y, z) out = pt.mul(x, y, z)
compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
fg = FunctionGraph(outputs=[out], clone=False)
compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
...@@ -3,8 +3,6 @@ import pytest ...@@ -3,8 +3,6 @@ import pytest
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor from pytensor.tensor.type import matrix, tensor
...@@ -19,57 +17,45 @@ def test_extra_ops(): ...@@ -19,57 +17,45 @@ def test_extra_ops():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.cumsum(a, axis=0) out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.cumprod(a, axis=1) out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.diff(a, n=2, axis=1) out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.repeat(a, (3, 3), axis=1) out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
c = ptb.as_tensor(5) c = ptb.as_tensor(5)
out = pt_extra_ops.fill_diagonal(a, c) out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c) out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a) out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
indices = np.arange(np.prod((3, 4))) indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C") out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out) compare_jax_and_py([], out, [], must_be_device_array=False)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
v = ptb.as_tensor_variable(6.0) v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel()) sorted_idx = argsort(a.ravel())
out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape(): def test_bartlett_dynamic_shape():
c = tensor(shape=(), dtype=int) c = tensor(shape=(), dtype=int)
out = pt_extra_ops.bartlett(c) out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out]) compare_jax_and_py([], [out], [np.array(5)])
compare_jax_and_py(fgraph, [np.array(5)])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
...@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape(): ...@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
x = tensor(shape=(None,), dtype=int) x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int) y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4)) out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out]) compare_jax_and_py([], [out], [x_test, y_test])
compare_jax_and_py(fgraph, [x_test, y_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
...@@ -89,5 +74,4 @@ def test_unique_dynamic_shape(): ...@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.Unique()(a) out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
...@@ -2,8 +2,6 @@ import numpy as np ...@@ -2,8 +2,6 @@ import numpy as np
import pytest import pytest
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, vector from pytensor.tensor.type import dvector, matrix, scalar, vector
...@@ -20,33 +18,39 @@ def test_jax_max_and_argmax(): ...@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
mx = Max([0])(x) mx = Max([0])(x)
amx = Argmax([0])(x) amx = Argmax([0])(x)
out = mx * amx out = mx * amx
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.r_[1, 2]])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
def test_dot(): def test_dot():
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) y_test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) x_test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A") A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX) A_test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha") alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX) alpha_test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta") beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX) beta_test_value = np.array(5.0, dtype=config.floatX)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible # This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should # optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone. # leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) compare_jax_and_py(
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) [y, x, A, alpha, beta],
out,
[
y_test_value,
x_test_value,
A_test_value,
alpha_test_value,
beta_test_value,
],
)
out = maximum(y, x) out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out]) compare_jax_and_py([y, x], [out], [y_test_value, x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = pt_max(y) out = pt_max(y)
fgraph = FunctionGraph([y], [out]) compare_jax_and_py([y], [out], [y_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.type import matrix from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -21,41 +20,34 @@ def test_jax_basic_multiout(): ...@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
x = matrix("x") x = matrix("x")
outs = pt_nlinalg.eig(x) outs = pt_nlinalg.eig(x)
out_fg = FunctionGraph([x], outs)
def assert_fn(x, y): def assert_fn(x, y):
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.eigh(x) outs = pt_nlinalg.eigh(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="full") outs = pt_nlinalg.qr(x, mode="full")
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="reduced") outs = pt_nlinalg.qr(x, mode="reduced")
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.svd(x) outs = pt_nlinalg.svd(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.slogdet(x) outs = pt_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
def test_pinv(): def test_pinv():
x = matrix("x") x = matrix("x")
x_inv = pt_nlinalg.pinv(x) x_inv = pt_nlinalg.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np]) compare_jax_and_py([x], [x_inv], [x_np])
def test_pinv_hermitian(): def test_pinv_hermitian():
...@@ -94,8 +86,7 @@ def test_kron(): ...@@ -94,8 +86,7 @@ def test_kron():
y = matrix("y") y = matrix("y")
z = pt_nlinalg.kron(x, y) z = pt_nlinalg.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np, y_np]) compare_jax_and_py([x, y], [z], [x_np, y_np])
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode from pytensor.tensor.pad import PadMode
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs): ...@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3)) x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_jax_and_py( compare_jax_and_py(
res_fg, [x_pt],
[res],
[x], [x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN", py_mode="FAST_RUN",
......
...@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps ...@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -51,20 +50,19 @@ def test_second(): ...@@ -51,20 +50,19 @@ def test_second():
b = scalar("b") b = scalar("b")
out = ps.second(a0, b) out = ps.second(a0, b)
fgraph = FunctionGraph([a0, b], [out]) compare_jax_and_py([a0, b], [out], [10.0, 5.0])
compare_jax_and_py(fgraph, [10.0, 5.0])
a1 = vector("a1") a1 = vector("a1")
out = pt.second(a1, b) out = pt.second(a1, b)
fgraph = FunctionGraph([a1, b], [out]) compare_jax_and_py([a1, b], [out], [np.zeros([5], dtype=config.floatX), 5.0])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
a2 = matrix("a2", shape=(1, None), dtype="float64") a2 = matrix("a2", shape=(1, None), dtype="float64")
b2 = matrix("b2", shape=(None, 1), dtype="int32") b2 = matrix("b2", shape=(None, 1), dtype="int32")
out = pt.second(a2, b2) out = pt.second(a2, b2)
fgraph = FunctionGraph([a2, b2], [out])
compare_jax_and_py( compare_jax_and_py(
fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")] [a2, b2],
[out],
[np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")],
) )
...@@ -81,11 +79,10 @@ def test_second_constant_scalar(): ...@@ -81,11 +79,10 @@ def test_second_constant_scalar():
def test_identity(): def test_identity():
a = scalar("a") a = scalar("a")
a.tag.test_value = 10 a_test_value = 10
out = ps.identity(a) out = ps.identity(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val): ...@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
out = comp_op(x, y) out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
x_val.astype(config.floatX), x_val.astype(config.floatX),
y_val.astype(config.floatX), y_val.astype(config.floatX),
] ]
_ = compare_jax_and_py(out_fg, test_input_vals) _ = compare_jax_and_py([x, y], [out], test_input_vals)
def test_jax_Composite_multi_output(): def test_jax_Composite_multi_output():
...@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output(): ...@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
x_s = ps.float64("xs") x_s = ps.float64("xs")
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x) outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
fgraph = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [np.arange(10, dtype=config.floatX)])
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
def test_erf(): def test_erf():
x = scalar("x") x = scalar("x")
out = erf(x) out = erf(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0]) compare_jax_and_py([x], [out], [1.0])
def test_erfc(): def test_erfc():
x = scalar("x") x = scalar("x")
out = erfc(x) out = erfc(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0]) compare_jax_and_py([x], [out], [1.0])
def test_erfinv(): def test_erfinv():
x = scalar("x") x = scalar("x")
out = erfinv(x) out = erfinv(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [0.95]) compare_jax_and_py([x], [out], [0.95])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values): ...@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values] inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs) output = op(*inputs)
fg = FunctionGraph(inputs, [output]) compare_jax_and_py(inputs, [output], test_values)
compare_jax_and_py(fg, test_values)
def test_betaincinv(): def test_betaincinv():
...@@ -175,9 +165,10 @@ def test_betaincinv(): ...@@ -175,9 +165,10 @@ def test_betaincinv():
b = vector("b", dtype="float64") b = vector("b", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = betaincinv(a, b, x) out = betaincinv(a, b, x)
fg = FunctionGraph([a, b, x], [out])
compare_jax_and_py( compare_jax_and_py(
fg, [a, b, x],
[out],
[ [
np.array([5.5, 7.0]), np.array([5.5, 7.0]),
np.array([5.5, 7.0]), np.array([5.5, 7.0]),
...@@ -190,39 +181,40 @@ def test_gammaincinv(): ...@@ -190,39 +181,40 @@ def test_gammaincinv():
k = vector("k", dtype="float64") k = vector("k", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = gammaincinv(k, x) out = gammaincinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_gammainccinv(): def test_gammainccinv():
k = vector("k", dtype="float64") k = vector("k", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = gammainccinv(k, x) out = gammainccinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_psi(): def test_psi():
x = scalar("x") x = scalar("x")
out = psi(x) out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0]) compare_jax_and_py([x], [out], [3.0])
def test_tri_gamma(): def test_tri_gamma():
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = tri_gamma(x) out = tri_gamma(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [np.array([3.0, 5.0])]) compare_jax_and_py([x], [out], [np.array([3.0, 5.0])])
def test_polygamma(): def test_polygamma():
n = vector("n", dtype="int32") n = vector("n", dtype="int32")
x = vector("x", dtype="float32") x = vector("x", dtype="float32")
out = polygamma(n, x) out = polygamma(n, x)
fg = FunctionGraph([n, x], [out])
compare_jax_and_py( compare_jax_and_py(
fg, [n, x],
[out],
[ [
np.array([0, 1, 2]).astype("int32"), np.array([0, 1, 2]).astype("int32"),
np.array([0.5, 0.9, 2.5]).astype("float32"), np.array([0.5, 0.9, 2.5]).astype("float32"),
...@@ -233,41 +225,34 @@ def test_polygamma(): ...@@ -233,41 +225,34 @@ def test_polygamma():
def test_log1mexp(): def test_log1mexp():
x = vector("x") x = vector("x")
out = log1mexp(x) out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) compare_jax_and_py([x], [out], [[-1.0, -0.75, -0.5, -0.25]])
def test_nnet(): def test_nnet():
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = sigmoid(x) out = sigmoid(x)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = softplus(x) out = softplus(x)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_variadic_Scalar(): def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX) mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) mu_test_value = np.r_[0.1, 1.1].astype(config.floatX)
tau = vector("tau", dtype=config.floatX) tau = vector("tau", dtype=config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) tau_test_value = np.r_[1.0, 2.0].astype(config.floatX)
res = -tau * mu res = -tau * mu
fgraph = FunctionGraph([mu, tau], [res]) compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
res = -tau * (tau - mu) ** 2 res = -tau * (tau - mu) ** 2
fgraph = FunctionGraph([mu, tau], [res]) compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_add_scalars(): def test_add_scalars():
...@@ -275,8 +260,7 @@ def test_add_scalars(): ...@@ -275,8 +260,7 @@ def test_add_scalars():
size = x.shape[0] + x.shape[0] + x.shape[1] size = x.shape[0] + x.shape[0] + x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
def test_mul_scalars(): def test_mul_scalars():
...@@ -284,8 +268,7 @@ def test_mul_scalars(): ...@@ -284,8 +268,7 @@ def test_mul_scalars():
size = x.shape[0] * x.shape[0] * x.shape[1] size = x.shape[0] * x.shape[0] * x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
def test_div_scalars(): def test_div_scalars():
...@@ -293,8 +276,7 @@ def test_div_scalars(): ...@@ -293,8 +276,7 @@ def test_div_scalars():
size = x.shape[0] // x.shape[1] size = x.shape[0] // x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
def test_mod_scalars(): def test_mod_scalars():
...@@ -302,39 +284,43 @@ def test_mod_scalars(): ...@@ -302,39 +284,43 @@ def test_mod_scalars():
size = x.shape[0] % x.shape[1] size = x.shape[0] % x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
def test_jax_multioutput(): def test_jax_multioutput():
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) y_test_value = np.r_[3.0, 4.0].astype(config.floatX)
w = cosh(x**2 + y / 3.0) w = cosh(x**2 + y / 3.0)
v = cosh(x / 3.0 + y**2) v = cosh(x / 3.0 + y**2)
fgraph = FunctionGraph([x, y], [w, v]) compare_jax_and_py([x, y], [w, v], [x_test_value, y_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp(): def test_jax_logp():
mu = vector("mu") mu = vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) mu_test_value = np.r_[0.0, 0.0].astype(config.floatX)
tau = vector("tau") tau = vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) tau_test_value = np.r_[1.0, 1.0].astype(config.floatX)
sigma = vector("sigma") sigma = vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) sigma_test_value = (1.0 / tau_test_value).astype(config.floatX)
value = vector("value") value = vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) value_test_value = np.r_[0.1, -10].astype(config.floatX)
logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0] conditions = [sigma > 0]
alltrue = pt_all([pt_all(1 * val) for val in conditions]) alltrue = pt_all([pt_all(1 * val) for val in conditions])
normal_logp = pt.switch(alltrue, logp, -np.inf) normal_logp = pt.switch(alltrue, logp, -np.inf)
fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) compare_jax_and_py(
[mu, tau, sigma, value],
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) [normal_logp],
[
mu_test_value,
tau_test_value,
sigma_test_value,
value_test_value,
],
)
...@@ -7,7 +7,6 @@ import pytensor.tensor as pt ...@@ -7,7 +7,6 @@ import pytensor.tensor as pt
from pytensor import function, shared from pytensor import function, shared
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import until from pytensor.scan import until
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -30,9 +29,8 @@ def test_scan_sit_sot(view): ...@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
) )
if view: if view:
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e] test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
...@@ -45,9 +43,8 @@ def test_scan_mit_sot(view): ...@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
) )
if view: if view:
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)] test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
...@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y): ...@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
if view_y: if view_y:
ys = ys[view_y] ys = ys[view_y]
fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, y0], [xs, ys], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
...@@ -90,12 +86,11 @@ def test_scan_nit_sot(view): ...@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
) )
if view: if view:
ys = ys[view] ys = ys[view]
fg = FunctionGraph([xs], [ys])
test_input_vals = [rng.normal(size=10)] test_input_vals = [rng.normal(size=10)]
# We need to remove pushout rewrites, or the whole scan would just be # We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs # converted to an Elemwise on xs
jax_fn, _ = compare_jax_and_py( jax_fn, _ = compare_jax_and_py(
fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout") [xs], [ys], test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
) )
scan_nodes = [ scan_nodes = [
node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
...@@ -112,8 +107,7 @@ def test_scan_mit_mot(): ...@@ -112,8 +107,7 @@ def test_scan_mit_mot():
n_steps=10, n_steps=10,
) )
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
fg = FunctionGraph([xs], [grads_wrt_xs]) compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
compare_jax_and_py(fg, [np.arange(10)])
def test_scan_update(): def test_scan_update():
...@@ -192,8 +186,7 @@ def test_scan_while(): ...@@ -192,8 +186,7 @@ def test_scan_while():
n_steps=100, n_steps=100,
) )
fg = FunctionGraph([], [xs]) compare_jax_and_py([], [xs], [])
compare_jax_and_py(fg, [])
def test_scan_SEIR(): def test_scan_SEIR():
...@@ -257,11 +250,6 @@ def test_scan_SEIR(): ...@@ -257,11 +250,6 @@ def test_scan_SEIR():
logp_c_all.name = "C_t_logp" logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp" logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25 s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX) logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX) logp_d0 = np.array(0.0, dtype=config.floatX)
...@@ -283,7 +271,12 @@ def test_scan_SEIR(): ...@@ -283,7 +271,12 @@ def test_scan_SEIR():
gamma_val, gamma_val,
delta_val, delta_val,
] ]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") compare_jax_and_py(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
test_input_vals,
jax_mode="JAX",
)
def test_scan_mitsot_with_nonseq(): def test_scan_mitsot_with_nonseq():
...@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq(): ...@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
y_scan_pt.name = "y" y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all" y_scan_pt.owner.inputs[0].name = "y_all"
out_fg = FunctionGraph([a_pt], [y_scan_pt])
test_input_vals = [np.array(10.0).astype(config.floatX)] test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([a_pt], [y_scan_pt], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("x0_func", [dvector, dmatrix]) @pytest.mark.parametrize("x0_func", [dvector, dmatrix])
...@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
) )
A_val = np.eye(k, dtype=config.floatX) A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_seq(): def test_nd_scan_sit_sot_with_seq():
...@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
A_val = np.eye(k, dtype=config.floatX) A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val] test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_mit_sot(): def test_nd_scan_mit_sot():
...@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot(): ...@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
n_steps=10, n_steps=10,
) )
fg = FunctionGraph([x0, A, B], [xs])
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
A_val = np.eye(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX)
B_val = np.eye(3, dtype=config.floatX) B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val] test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A, B], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_carry(): def test_nd_scan_sit_sot_with_carry():
...@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
mode=get_mode("JAX"), mode=get_mode("JAX"),
) )
fg = FunctionGraph([x0, A], xs)
x0_val = np.arange(3, dtype=config.floatX) x0_val = np.arange(3, dtype=config.floatX)
A_val = np.eye(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A], xs, test_input_vals, jax_mode="JAX")
def test_default_mode_excludes_incompatible_rewrites(): def test_default_mode_excludes_incompatible_rewrites():
...@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
A = matrix("A") A = matrix("A")
B = matrix("B") B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out]) compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
def test_dynamic_sequence_length(): def test_dynamic_sequence_length():
......
...@@ -4,7 +4,6 @@ import pytest ...@@ -4,7 +4,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_shape_ops(): def test_jax_shape_ops():
x_np = np.zeros((20, 3)) x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np)) x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py([], [x], [], must_be_device_array=False)
x = Shape_i(1)(pt.as_tensor_variable(x_np)) x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py([], [x], [], must_be_device_array=False)
def test_jax_specify_shape(): def test_jax_specify_shape():
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None)) x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x]) compare_jax_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes # When used to assert two arrays have similar shapes
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape") shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape) x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [in_pt, shape_pt],
[x],
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
) )
...@@ -43,20 +40,17 @@ def test_jax_specify_shape(): ...@@ -43,20 +40,17 @@ def test_jax_specify_shape():
def test_jax_Reshape_constant(): def test_jax_Reshape_constant():
a = vector("a") a = vector("a")
x = reshape(a, (2, 2)) x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_jax_Reshape_concrete_shape(): def test_jax_Reshape_concrete_shape():
"""JAX should compile when a concrete value is passed for the `shape` parameter.""" """JAX should compile when a concrete value is passed for the `shape` parameter."""
a = vector("a") a = vector("a")
x = reshape(a, a.shape) x = reshape(a, a.shape)
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail( @pytest.mark.xfail(
...@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input(): ...@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
a = vector("a") a = vector("a")
shape_pt = iscalar("b") shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt)) x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_jax_compile_ops(): def test_jax_compile_ops():
x = DeepCopyOp()(pt.as_tensor_variable(1.1)) x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x]) compare_jax_and_py([], [x], [])
compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
x = ViewOp()(pt.as_tensor_variable(x_np)) x = ViewOp()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
...@@ -6,7 +6,6 @@ import pytest ...@@ -6,7 +6,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import subtensor as pt_subtensor
...@@ -30,13 +29,11 @@ def test_jax_basic(): ...@@ -30,13 +29,11 @@ def test_jax_basic():
out = pt_subtensor.inc_subtensor(out[0, 1], 2.0) out = pt_subtensor.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3] out = out[:5, :3]
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(config.floatX), np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
] ]
_, [jax_res] = compare_jax_and_py(out_fg, test_input_vals) _, [jax_res] = compare_jax_and_py([x, y], [out], test_input_vals)
# Confirm that the `Subtensor` slice operations are correct # Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3) assert jax_res.shape == (5, 3)
...@@ -46,19 +43,17 @@ def test_jax_basic(): ...@@ -46,19 +43,17 @@ def test_jax_basic():
assert jax_res[0, 1] == -8.0 assert jax_res[0, 1] == -8.0
out = clip(x, y, 5) out = clip(x, y, 5)
out_fg = FunctionGraph([x, y], [out]) compare_jax_and_py([x, y], [out], test_input_vals)
compare_jax_and_py(out_fg, test_input_vals)
out = pt.diagonal(x, 0) out = pt.diagonal(x, 0)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
) )
out = pt_slinalg.cholesky(x) out = pt_slinalg.cholesky(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -68,9 +63,9 @@ def test_jax_basic(): ...@@ -68,9 +63,9 @@ def test_jax_basic():
# not sure why this isn't working yet with lower=False # not sure why this isn't working yet with lower=False
out = pt_slinalg.Cholesky(lower=False)(x) out = pt_slinalg.Cholesky(lower=False)(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -79,9 +74,9 @@ def test_jax_basic(): ...@@ -79,9 +74,9 @@ def test_jax_basic():
) )
out = pt_slinalg.solve(x, b) out = pt_slinalg.solve(x, b)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x, b],
[out],
[ [
np.eye(10).astype(config.floatX), np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX), np.arange(10).astype(config.floatX),
...@@ -89,19 +84,17 @@ def test_jax_basic(): ...@@ -89,19 +84,17 @@ def test_jax_basic():
) )
out = pt.diag(b) out = pt.diag(b)
out_fg = FunctionGraph([b], [out]) compare_jax_and_py([b], [out], [np.arange(10).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
out = pt_nlinalg.det(x) out = pt_nlinalg.det(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
) )
out = pt_nlinalg.matrix_inverse(x) out = pt_nlinalg.matrix_inverse(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite): ...@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
lower=lower, lower=lower,
check_finite=check_finite, check_finite=check_finite,
) )
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x, b],
[out],
[ [
np.eye(10).astype(config.floatX), np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX), np.arange(10).astype(config.floatX),
...@@ -141,10 +134,10 @@ def test_jax_block_diag(): ...@@ -141,10 +134,10 @@ def test_jax_block_diag():
D = matrix("D") D = matrix("D")
out = pt_slinalg.block_diag(A, B, C, D) out = pt_slinalg.block_diag(A, B, C, D)
out_fg = FunctionGraph([A, B, C, D], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B, C, D],
[out],
[ [
np.random.normal(size=(5, 5)).astype(config.floatX), np.random.normal(size=(5, 5)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX), np.random.normal(size=(3, 3)).astype(config.floatX),
...@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise(): ...@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
A = pt.tensor3("A") A = pt.tensor3("A")
B = pt.tensor3("B") B = pt.tensor3("B")
out = pt_slinalg.block_diag(A, B) out = pt_slinalg.block_diag(A, B)
out_fg = FunctionGraph([A, B], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.random.normal(size=(5, 5, 5)).astype(config.floatX), np.random.normal(size=(5, 5, 5)).astype(config.floatX),
np.random.normal(size=(5, 3, 3)).astype(config.floatX), np.random.normal(size=(5, 3, 3)).astype(config.floatX),
...@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower): ...@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
B = matrix("B") B = matrix("B")
out = pt_slinalg.eigvalsh(A, B, lower=lower) out = pt_slinalg.eigvalsh(A, B, lower=lower)
out_fg = FunctionGraph([A, B], [out])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.array( np.array(
[[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]] [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
...@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower): ...@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
], ],
) )
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype( np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
config.floatX config.floatX
...@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov( ...@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
A = pt.tensor(name="A", shape=shape) A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape) B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method) out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
out_fg = FunctionGraph([A, B], [out])
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3 atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.random.normal(size=shape).astype(config.floatX), np.random.normal(size=shape).astype(config.floatX),
np.random.normal(size=shape).astype(config.floatX), np.random.normal(size=shape).astype(config.floatX),
......
import numpy as np import numpy as np
import pytest import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_sort(func, axis): def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64") x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis) out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr]) compare_jax_and_py([x], [out], [arr])
...@@ -5,7 +5,6 @@ import scipy.sparse ...@@ -5,7 +5,6 @@ import scipy.sparse
import pytensor.sparse as ps import pytensor.sparse as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import function from pytensor import function
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op): ...@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
test_values.append(y_test) test_values.append(y_test)
dot_pt = op(x_pt, y_pt) dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt]) compare_jax_and_py(inputs, [dot_pt], test_values, jax_mode="JAX")
compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
def test_sparse_dot_non_const_raises(): def test_sparse_dot_non_const_raises():
......
...@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant(): ...@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
# Basic indices # Basic indices
out_pt = x_pt[1, 2, 0] out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:, 1, :] out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[:2, 1, :] out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:2, 1, :] out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced indexing # Advanced indexing
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], [2, 3]] out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced and basic indexing # Advanced and basic indexing
out_pt = x_pt[[1, 2], :] out_pt = x_pt[[1, 2], :]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]] out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Flipping # Flipping
out_pt = x_pt[::-1] out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Boolean indexing should work if indexes are constant # Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)] out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling") @pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
...@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic(): ...@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
x = pt.arange(3) x = pt.arange(3)
out_pt = x[:a] out_pt = x[:a]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([a], [out_pt])
compare_jax_and_py(out_fg, [1]) compare_jax_and_py([a], [out_pt], [1])
def test_jax_Subtensor_dynamic_boolean_mask(): def test_jax_Subtensor_dynamic_boolean_mask():
...@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask(): ...@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
out_pt = x_pt[x_pt < 0] out_pt = x_pt[x_pt < 0]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
x_pt_test = np.arange(-5, 5) x_pt_test = np.arange(-5, 5)
with pytest.raises(NonConcreteBooleanIndexError): with pytest.raises(NonConcreteBooleanIndexError):
compare_jax_and_py(out_fg, [x_pt_test]) compare_jax_and_py([x_pt], [out_pt], [x_pt_test])
def test_jax_Subtensor_boolean_mask_reexpressible(): def test_jax_Subtensor_boolean_mask_reexpressible():
...@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible(): ...@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
""" """
x_pt = pt.matrix("x") x_pt = pt.matrix("x")
out_pt = x_pt[x_pt < 0].sum() out_pt = x_pt[x_pt < 0].sum()
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)]) compare_jax_and_py(
[x_pt], [out_pt], [np.arange(25).reshape(5, 5).astype(config.floatX)]
)
def test_boolean_indexing_sum_not_applicable(): def test_boolean_indexing_sum_not_applicable():
...@@ -136,19 +136,19 @@ def test_jax_IncSubtensor(): ...@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Set" advanced indices # "Set" advanced indices
st_pt = pt.as_tensor_variable( st_pt = pt.as_tensor_variable(
...@@ -156,39 +156,39 @@ def test_jax_IncSubtensor(): ...@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
) )
out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Set" boolean indices # "Set" boolean indices
mask_pt = pt.constant(x_np > 0) mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" basic indices # "Increment" basic indices
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" advanced indices # "Increment" advanced indices
st_pt = pt.as_tensor_variable( st_pt = pt.as_tensor_variable(
...@@ -196,33 +196,33 @@ def test_jax_IncSubtensor(): ...@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
) )
out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" boolean indices # "Increment" boolean indices
mask_pt = pt.constant(x_np > 0) mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
def test_jax_IncSubtensor_boolean_indexing_reexpressible(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
...@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible(): ...@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
mask_pt = pt.as_tensor(x_pt) > 0 mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
mask_pt = pt.as_tensor(x_pt) > 0 mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0) out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
def test_boolean_indexing_set_or_inc_not_applicable(): def test_boolean_indexing_set_or_inc_not_applicable():
......
...@@ -10,8 +10,6 @@ from jax import errors ...@@ -10,8 +10,6 @@ from jax import errors
import pytensor import pytensor
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import check_alloc_runtime_broadcast from tests.tensor.test_basic import check_alloc_runtime_broadcast
...@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast ...@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
def test_jax_Alloc(): def test_jax_Alloc():
x = ptb.alloc(0.0, 2, 3) x = ptb.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x])
_, [jax_res] = compare_jax_and_py(x_fg, []) _, [jax_res] = compare_jax_and_py([], [x], [])
assert jax_res.shape == (2, 3) assert jax_res.shape == (2, 3)
x = ptb.alloc(1.1, 2, 3) x = ptb.alloc(1.1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
x = ptb.AllocEmpty("float32")(2, 3) x = ptb.AllocEmpty("float32")(2, 3)
x_fg = FunctionGraph([], [x])
def compare_shape_dtype(x, y): def compare_shape_dtype(x, y):
(x,) = x np.testing.assert_array_equal(x, y, strict=True)
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype
compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype) compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
a = scalar("a") a = scalar("a")
x = ptb.alloc(a, 20) x = ptb.alloc(a, 20)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [10.0]) compare_jax_and_py([a], [x], [10.0])
a = vector("a") a = vector("a")
x = ptb.alloc(a, 20, 10) x = ptb.alloc(a, 20, 10)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) compare_jax_and_py([a], [x], [np.ones(10, dtype=config.floatX)])
def test_alloc_runtime_broadcast(): def test_alloc_runtime_broadcast():
...@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast(): ...@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
def test_jax_MakeVector(): def test_jax_MakeVector():
x = ptb.make_vector(1, 2, 3) x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
def test_arange(): def test_arange():
out = ptb.arange(1, 10, 2) out = ptb.arange(1, 10, 2)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
def test_arange_of_shape(): def test_arange_of_shape():
x = vector("x") x = vector("x")
out = ptb.arange(1, x.shape[-1], 2) out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX")
compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
def test_arange_nonconcrete(): def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values.""" """JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a = scalar("a") a = scalar("a")
a.tag.test_value = 10 a_test_value = 10
out = ptb.arange(a) out = ptb.arange(a)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_Join(): def test_jax_Join():
...@@ -94,16 +82,17 @@ def test_jax_Join(): ...@@ -94,16 +82,17 @@ def test_jax_Join():
b = matrix("b") b = matrix("b")
x = ptb.join(0, a, b) x = ptb.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX), np.c_[[4.0, 5.0]].astype(config.floatX),
...@@ -111,16 +100,17 @@ def test_jax_Join(): ...@@ -111,16 +100,17 @@ def test_jax_Join():
) )
x = ptb.join(1, a, b) x = ptb.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX), np.c_[[5.0, 6.0]].astype(config.floatX),
...@@ -132,9 +122,9 @@ class TestJaxSplit: ...@@ -132,9 +122,9 @@ class TestJaxSplit:
def test_basic(self): def test_basic(self):
a = matrix("a") a = matrix("a")
a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0) a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py( compare_jax_and_py(
fg, [a],
a_splits,
[ [
np.zeros((6, 4)).astype(config.floatX), np.zeros((6, 4)).astype(config.floatX),
], ],
...@@ -142,9 +132,9 @@ class TestJaxSplit: ...@@ -142,9 +132,9 @@ class TestJaxSplit:
a = matrix("a", shape=(6, None)) a = matrix("a", shape=(6, None))
a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0) a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py( compare_jax_and_py(
fg, [a],
a_splits,
[ [
np.zeros((6, 4)).astype(config.floatX), np.zeros((6, 4)).astype(config.floatX),
], ],
...@@ -207,15 +197,14 @@ class TestJaxSplit: ...@@ -207,15 +197,14 @@ class TestJaxSplit:
def test_jax_eye(): def test_jax_eye():
"""Tests jaxification of the Eye operator""" """Tests jaxification of the Eye operator"""
out = ptb.eye(3) out = ptb.eye(3)
out_fg = FunctionGraph([], [out])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out], [])
def test_tri(): def test_tri():
out = ptb.tri(10, 10, 0) out = ptb.tri(10, 10, 0)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -230,14 +219,13 @@ def test_tri_nonconcrete(): ...@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
scalar("n", dtype="int64"), scalar("n", dtype="int64"),
scalar("k", dtype="int64"), scalar("k", dtype="int64"),
) )
m.tag.test_value = 10 m_test_value = 10
n.tag.test_value = 10 n_test_value = 10
k.tag.test_value = 0 k_test_value = 0
out = ptb.tri(m, n, k) out = ptb.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but # The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass # the error handler raises an Attribute error first, so that's what this test needs to pass
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
fgraph = FunctionGraph([m, n, k], [out]) compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt): ...@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
) )
x_test = np.eye(3) * np.arange(1, 6)[:, None, None] x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
compare_numba_and_py( compare_numba_and_py(
([x], outs), [x],
outs,
[x_test], [x_test],
numba_mode=mode, numba_mode=mode,
eval_obj_mode=False, eval_obj_mode=False,
......
...@@ -4,11 +4,8 @@ import numpy as np ...@@ -4,11 +4,8 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849) ...@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
), ),
...@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849) ...@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
], ],
) )
def test_Det(x, exc): def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x) g = nlinalg.Det()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -53,14 +47,14 @@ def test_Det(x, exc): ...@@ -53,14 +47,14 @@ def test_Det(x, exc):
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
), ),
...@@ -69,18 +63,15 @@ def test_Det(x, exc): ...@@ -69,18 +63,15 @@ def test_Det(x, exc):
], ],
) )
def test_SLogDet(x, exc): def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x) g = nlinalg.SLogDet()(x)
g_fg = FunctionGraph(outputs=g)
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -112,21 +103,21 @@ y = np.array( ...@@ -112,21 +103,21 @@ y = np.array(
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(x), (lambda x: x.T.dot(x))(x),
), ),
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(y), (lambda x: x.T.dot(x))(y),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -137,22 +128,15 @@ y = np.array( ...@@ -137,22 +128,15 @@ y = np.array(
], ],
) )
def test_Eig(x, exc): def test_Eig(x, exc):
x, test_x = x
g = nlinalg.Eig()(x) g = nlinalg.Eig()(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -160,7 +144,7 @@ def test_Eig(x, exc): ...@@ -160,7 +144,7 @@ def test_Eig(x, exc):
"x, uplo, exc", "x, uplo, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -168,7 +152,7 @@ def test_Eig(x, exc): ...@@ -168,7 +152,7 @@ def test_Eig(x, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -180,22 +164,15 @@ def test_Eig(x, exc): ...@@ -180,22 +164,15 @@ def test_Eig(x, exc):
], ],
) )
def test_Eigh(x, uplo, exc): def test_Eigh(x, uplo, exc):
x, test_x = x
g = nlinalg.Eigh(uplo)(x) g = nlinalg.Eigh(uplo)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc): ...@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
[ [
( (
nlinalg.MatrixInverse, nlinalg.MatrixInverse,
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc): ...@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixInverse, nlinalg.MatrixInverse,
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc): ...@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixPinv, nlinalg.MatrixPinv,
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc): ...@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixPinv, nlinalg.MatrixPinv,
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc): ...@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
], ],
) )
def test_matrix_inverses(op, x, exc, op_args): def test_matrix_inverses(op, x, exc, op_args):
x, test_x = x
g = op(*op_args)(x) g = op(*op_args)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
"x, mode, exc", "x, mode, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
], ],
) )
def test_QRFull(x, mode, exc): def test_QRFull(x, mode, exc):
x, test_x = x
g = nlinalg.QRFull(mode)(x) g = nlinalg.QRFull(mode)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc): ...@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
"x, full_matrices, compute_uv, exc", "x, full_matrices, compute_uv, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc): ...@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc): ...@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc): ...@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc): ...@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
], ],
) )
def test_SVD(x, full_matrices, compute_uv, exc): def test_SVD(x, full_matrices, compute_uv, exc):
x, test_x = x
g = nlinalg.SVD(full_matrices, compute_uv)(x) g = nlinalg.SVD(full_matrices, compute_uv)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode from pytensor.tensor.pad import PadMode
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs): ...@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3)) x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_numba_and_py( compare_numba_and_py(
res_fg, [x_pt],
[res],
[x], [x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN", py_mode="FAST_RUN",
......
...@@ -5,13 +5,10 @@ import pytensor.scalar as ps ...@@ -5,13 +5,10 @@ import pytensor.scalar as ps
import pytensor.scalar.basic as psb import pytensor.scalar.basic as psb
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849) ...@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
"x, y", "x, y",
[ [
( (
set_test_value(pt.lvector(), np.arange(4, dtype="int64")), (pt.lvector(), np.arange(4, dtype="int64")),
set_test_value(pt.dvector(), np.arange(4, dtype="float64")), (pt.dvector(), np.arange(4, dtype="float64")),
), ),
( (
set_test_value(pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), (pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))),
set_test_value(pt.lscalar(), np.array(4, dtype="int64")), (pt.lscalar(), np.array(4, dtype="int64")),
), ),
], ],
) )
def test_Second(x, y):
    """Check numba/python agreement for the ``Elemwise``-wrapped ``Second`` op.

    Both ``x`` and ``y`` are ``(symbolic variable, test value)`` pairs from
    the parametrize decorator.
    """
    x_var, x_val = x
    y_var, y_val = y
    # `pt.second` is the `Elemwise` wrapper around the scalar `Second` op.
    out = pt.second(x_var, y_var)
    compare_numba_and_py([x_var, y_var], out, [x_val, y_val])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, min, max", "v, min, max",
[ [
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), ((pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), ((pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), ((pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0),
], ],
) )
def test_Clip(v, min, max):
    """Check numba/python agreement for the scalar ``clip`` op.

    NOTE: ``min``/``max`` shadow the builtins, but the parameter names are
    fixed by the parametrize decorator above and cannot be changed here.
    ``v`` is a ``(symbolic scalar, test value)`` pair.
    """
    v_var, v_val = v
    clipped = ps.clip(v_var, min, max)
    compare_numba_and_py([v_var], [clipped], [v_val])
...@@ -100,46 +92,39 @@ def test_Clip(v, min, max): ...@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
def test_Composite(inputs, input_values, scalar_fn):
    """Check numba/python agreement for an ``Elemwise``-wrapped ``Composite``.

    Parameters
    ----------
    inputs
        Symbolic tensor inputs to wrap.
    input_values
        Concrete test values, one per input.
    scalar_fn
        Callable building the scalar expression that becomes the composite.
    """
    # Mirror each tensor input with a same-named scalar of the default float
    # dtype so `scalar_fn` can be fused into a single `Composite` node.
    scalar_inputs = []
    for inp in inputs:
        scalar_inputs.append(ps.ScalarType(config.floatX)(name=inp.name))
    composite = Composite(scalar_inputs, [scalar_fn(*scalar_inputs)])
    out = Elemwise(composite)(*inputs)
    compare_numba_and_py(inputs, [out], input_values)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, dtype", "v, dtype",
[ [
(set_test_value(pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64), ((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
(set_test_value(pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32), ((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
], ],
) )
def test_Cast(v, dtype):
    """Check numba/python agreement for the scalar ``Cast`` op.

    ``v`` is a ``(symbolic scalar, test value)`` pair; ``dtype`` is the
    target scalar type handed to the ``Cast`` op constructor.
    """
    v_var, v_val = v
    cast_out = psb.Cast(dtype)(v_var)
    compare_numba_and_py([v_var], [cast_out], [v_val])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, dtype", "v, dtype",
[ [
(set_test_value(pt.iscalar(), np.array(10, dtype="int32")), psb.float64), ((pt.iscalar(), np.array(10, dtype="int32")), psb.float64),
], ],
) )
def test_reciprocal(v, dtype):
    """Check numba/python agreement for the scalar ``reciprocal`` op.

    ``v`` is a ``(symbolic scalar, test value)`` pair.
    NOTE(review): ``dtype`` is supplied by the parametrize decorator but is
    never used in the body — its name is fixed by the decorator, so it is
    kept for signature compatibility.
    """
    v_var, v_val = v
    recip = psb.reciprocal(v_var)
    compare_numba_and_py([v_var], [recip], [v_val])
...@@ -156,6 +141,7 @@ def test_isnan(composite): ...@@ -156,6 +141,7 @@ def test_isnan(composite):
out = pt.isnan(x) out = pt.isnan(x)
compare_numba_and_py( compare_numba_and_py(
([x], [out]), [x],
[out],
[np.array([1, 0], dtype="float64")], [np.array([1, 0], dtype="float64")],
) )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论