提交 cc8c4992 authored 作者: Adv's avatar Adv 提交者: Ricardo Vieira

Stop using FunctionGraph and tag.test_value in linker tests

上级 51ea1a0b
...@@ -7,12 +7,12 @@ import pytest ...@@ -7,12 +7,12 @@ import pytest
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value from pytensor.graph.op import Op
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
...@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None) ...@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
def compare_jax_and_py( def compare_jax_and_py(
fgraph: FunctionGraph, graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable, test_inputs: Iterable,
*,
assert_fn: Callable | None = None, assert_fn: Callable | None = None,
must_be_device_array: bool = True, must_be_device_array: bool = True,
jax_mode=jax_mode, jax_mode=jax_mode,
py_mode=py_mode, py_mode=py_mode,
): ):
"""Function to compare python graph output and jax compiled output for testing equality """Function to compare python function output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both jax and python, runs the calculation jax and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters Parameters
---------- ----------
fgraph: FunctionGraph graph_inputs:
PyTensor function Graph object Symbolic inputs to the graph
outputs:
Symbolic outputs of the graph
test_inputs: iter test_inputs: iter
Numerical inputs for testing the function graph Numerical inputs for testing the function.
assert_fn: func, opt assert_fn: func, opt
Assert function used to check for equality between python and jax. If not Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose provided uses np.testing.assert_allclose
...@@ -68,8 +71,10 @@ def compare_jax_and_py( ...@@ -68,8 +71,10 @@ def compare_jax_and_py(
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] if any(inp.owner is not None for inp in graph_inputs):
pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) raise ValueError("Inputs must be root variables")
pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
jax_res = pytensor_jax_fn(*test_inputs) jax_res = pytensor_jax_fn(*test_inputs)
if must_be_device_array: if must_be_device_array:
...@@ -78,10 +83,10 @@ def compare_jax_and_py( ...@@ -78,10 +83,10 @@ def compare_jax_and_py(
else: else:
assert isinstance(jax_res, jax.Array) assert isinstance(jax_res, jax.Array)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs) py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1: if isinstance(graph_outputs, list | tuple):
for j, p in zip(jax_res, py_res, strict=True): for j, p in zip(jax_res, py_res, strict=True):
assert_fn(j, p) assert_fn(j, p)
else: else:
...@@ -187,16 +192,14 @@ def test_jax_ifelse(): ...@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals = np.r_[-1, -2, -3] false_vals = np.r_[-1, -2, -3]
x = ifelse(np.array(True), true_vals, false_vals) x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
a = dscalar("a") a = dscalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX) a_test = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals) x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) compare_jax_and_py([a], [x], [a_test])
def test_jax_checkandraise(): def test_jax_checkandraise():
...@@ -209,11 +212,6 @@ def test_jax_checkandraise(): ...@@ -209,11 +212,6 @@ def test_jax_checkandraise():
function((p,), res, mode=jax_mode) function((p,), res, mode=jax_mode)
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_OpFromGraph(): def test_OpFromGraph():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False) ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
...@@ -221,10 +219,9 @@ def test_OpFromGraph(): ...@@ -221,10 +219,9 @@ def test_OpFromGraph():
o1, o2 = ofg_2(y, z) o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2 out = ofg_1(x, o1) + o2
out_fg = FunctionGraph([x, y, z], [out])
xv = np.ones((2, 2), dtype=config.floatX) xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_jax_and_py(out_fg, [xv, yv, zv]) compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
...@@ -4,8 +4,6 @@ import pytest ...@@ -4,8 +4,6 @@ import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas from pytensor.tensor import blas as pt_blas
...@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_BatchedDot(): def test_jax_BatchedDot():
# tensor3 . tensor3 # tensor3 . tensor3
a = tensor3("a") a = tensor3("a")
a.tag.test_value = ( a_test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
) )
b = tensor3("b") b = tensor3("b")
b.tag.test_value = ( b_test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
) )
out = pt_blas.BatchedDot()(a, b) out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out]) compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)] inputs = [a_test_value[:-1], b_test_value]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts) jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
with pytest.raises(TypeError): with pytest.raises(TypeError):
pytensor_jax_fn(*inputs) pytensor_jax_fn(*inputs)
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul from pytensor.tensor.math import Dot, matmul
...@@ -32,8 +31,7 @@ def test_matmul(matmul_op): ...@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
out = matmul_op(a, b) out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise) assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out]) fn, _ = compare_jax_and_py([a, b], [out], test_values)
fn, _ = compare_jax_and_py(fg, test_values)
# Check we are not adding any unnecessary stuff # Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
......
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -22,8 +21,7 @@ def test_jax_einsum(): ...@@ -22,8 +21,7 @@ def test_jax_einsum():
} }
x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum(subscripts, x_pt, y_pt, z_pt) out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
fg = FunctionGraph([x_pt, y_pt, z_pt], [out]) compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
compare_jax_and_py(fg, [x, y, z])
def test_ellipsis_einsum(): def test_ellipsis_einsum():
...@@ -34,5 +32,4 @@ def test_ellipsis_einsum(): ...@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape) x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape) y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt) out = pt.einsum(subscripts, x_pt, y_pt)
fg = FunctionGraph([x_pt, y_pt], [out]) compare_jax_and_py([x_pt, y_pt], [out], [x, y])
compare_jax_and_py(fg, [x, y])
...@@ -6,8 +6,6 @@ import pytensor ...@@ -6,8 +6,6 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import elemwise as pt_elemwise from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import prod from pytensor.tensor.math import prod
...@@ -26,22 +24,22 @@ def test_jax_Dimshuffle(): ...@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
a_pt = matrix("a") a_pt = matrix("a")
x = a_pt.T x = a_pt.T
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
x = a_pt.dimshuffle([0, 1, "x"]) x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,)) x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt) x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x]) compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_CAReduce(): def test_jax_CAReduce():
...@@ -49,64 +47,58 @@ def test_jax_CAReduce(): ...@@ -49,64 +47,58 @@ def test_jax_CAReduce():
a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
x = pt_sum(a_pt, axis=None) x = pt_sum(a_pt, axis=None)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
a_pt = matrix("a") a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = pt_sum(a_pt, axis=0) x = pt_sum(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_sum(a_pt, axis=1) x = pt_sum(a_pt, axis=1)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
a_pt = matrix("a") a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = prod(a_pt, axis=0) x = prod(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_all(a_pt) x = pt_all(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis): def test_softmax(axis):
x = matrix("x") x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = softmax(x, axis=axis) out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis): def test_logsoftmax(axis):
x = matrix("x") x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis) out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py([x], [out], [x_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis): def test_softmax_grad(axis):
dy = matrix("dy") dy = matrix("dy")
dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm") sm = matrix("sm")
sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm) out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)]) @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
...@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark): ...@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def test_multiple_input_multiply(): def test_multiple_input_multiply():
x, y, z = vectors("xyz") x, y, z = vectors("xyz")
out = pt.mul(x, y, z) out = pt.mul(x, y, z)
compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
fg = FunctionGraph(outputs=[out], clone=False)
compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
...@@ -3,8 +3,6 @@ import pytest ...@@ -3,8 +3,6 @@ import pytest
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor from pytensor.tensor.type import matrix, tensor
...@@ -19,57 +17,45 @@ def test_extra_ops(): ...@@ -19,57 +17,45 @@ def test_extra_ops():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.cumsum(a, axis=0) out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.cumprod(a, axis=1) out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.diff(a, n=2, axis=1) out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.repeat(a, (3, 3), axis=1) out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
c = ptb.as_tensor(5) c = ptb.as_tensor(5)
out = pt_extra_ops.fill_diagonal(a, c) out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c) out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a) out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
indices = np.arange(np.prod((3, 4))) indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C") out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out) compare_jax_and_py([], out, [], must_be_device_array=False)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
v = ptb.as_tensor_variable(6.0) v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel()) sorted_idx = argsort(a.ravel())
out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape(): def test_bartlett_dynamic_shape():
c = tensor(shape=(), dtype=int) c = tensor(shape=(), dtype=int)
out = pt_extra_ops.bartlett(c) out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out]) compare_jax_and_py([], [out], [np.array(5)])
compare_jax_and_py(fgraph, [np.array(5)])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
...@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape(): ...@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
x = tensor(shape=(None,), dtype=int) x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int) y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4)) out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out]) compare_jax_and_py([], [out], [x_test, y_test])
compare_jax_and_py(fgraph, [x_test, y_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
...@@ -89,5 +74,4 @@ def test_unique_dynamic_shape(): ...@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.Unique()(a) out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test])
compare_jax_and_py(fgraph, [a_test])
...@@ -2,8 +2,6 @@ import numpy as np ...@@ -2,8 +2,6 @@ import numpy as np
import pytest import pytest
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.math import Argmax, Max, maximum from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, vector from pytensor.tensor.type import dvector, matrix, scalar, vector
...@@ -20,33 +18,39 @@ def test_jax_max_and_argmax(): ...@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
mx = Max([0])(x) mx = Max([0])(x)
amx = Argmax([0])(x) amx = Argmax([0])(x)
out = mx * amx out = mx * amx
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.r_[1, 2]])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
def test_dot(): def test_dot():
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) y_test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) x_test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A") A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX) A_test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha") alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX) alpha_test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta") beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX) beta_test_value = np.array(5.0, dtype=config.floatX)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible # This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should # optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone. # leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) compare_jax_and_py(
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) [y, x, A, alpha, beta],
out,
[
y_test_value,
x_test_value,
A_test_value,
alpha_test_value,
beta_test_value,
],
)
out = maximum(y, x) out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out]) compare_jax_and_py([y, x], [out], [y_test_value, x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = pt_max(y) out = pt_max(y)
fgraph = FunctionGraph([y], [out]) compare_jax_and_py([y], [out], [y_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.type import matrix from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -21,41 +20,34 @@ def test_jax_basic_multiout(): ...@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
x = matrix("x") x = matrix("x")
outs = pt_nlinalg.eig(x) outs = pt_nlinalg.eig(x)
out_fg = FunctionGraph([x], outs)
def assert_fn(x, y): def assert_fn(x, y):
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.eigh(x) outs = pt_nlinalg.eigh(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="full") outs = pt_nlinalg.qr(x, mode="full")
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="reduced") outs = pt_nlinalg.qr(x, mode="reduced")
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.svd(x) outs = pt_nlinalg.svd(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.slogdet(x) outs = pt_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
def test_pinv(): def test_pinv():
x = matrix("x") x = matrix("x")
x_inv = pt_nlinalg.pinv(x) x_inv = pt_nlinalg.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np]) compare_jax_and_py([x], [x_inv], [x_np])
def test_pinv_hermitian(): def test_pinv_hermitian():
...@@ -94,8 +86,7 @@ def test_kron(): ...@@ -94,8 +86,7 @@ def test_kron():
y = matrix("y") y = matrix("y")
z = pt_nlinalg.kron(x, y) z = pt_nlinalg.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np, y_np]) compare_jax_and_py([x, y], [z], [x_np, y_np])
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode from pytensor.tensor.pad import PadMode
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs): ...@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3)) x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_jax_and_py( compare_jax_and_py(
res_fg, [x_pt],
[res],
[x], [x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN", py_mode="FAST_RUN",
......
...@@ -7,13 +7,11 @@ import pytensor.tensor as pt ...@@ -7,13 +7,11 @@ import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr import pytensor.tensor.random.basic as ptr
from pytensor import clone_replace from pytensor import clone_replace
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.basic import RandomVariable from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.type import RandomType from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value from tests.link.jax.test_basic import compare_jax_and_py, jax_mode
from tests.tensor.random.test_basic import ( from tests.tensor.random.test_basic import (
batched_permutation_tester, batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester, batched_unweighted_choice_without_replacement_tester,
...@@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.beta, ptr.beta,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.cauchy, ptr.cauchy,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.exponential, ptr.exponential,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
...@@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr._gamma, ptr._gamma,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64), np.array([0.5, 3.0], dtype=np.float64),
), ),
...@@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.gumbel, ptr.gumbel,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.laplace, ptr.laplace,
[ [
set_test_value(pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)), (pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)),
set_test_value(pt.dscalar(), np.array(1.0, dtype=np.float64)), (pt.dscalar(), np.array(1.0, dtype=np.float64)),
], ],
(2,), (2,),
"laplace", "laplace",
...@@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.logistic, ptr.logistic,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.lognormal, ptr.lognormal,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([0, 0], dtype=np.int64), np.array([0, 0], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.normal, ptr.normal,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.pareto, ptr.pareto,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([2.0, 10.0], dtype=np.float64), np.array([2.0, 10.0], dtype=np.float64),
), ),
...@@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.poisson, ptr.poisson,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([100000.0, 200000.0], dtype=np.float64), np.array([100000.0, 200000.0], dtype=np.float64),
), ),
...@@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.integers, ptr.integers,
[ [
set_test_value( (
pt.lscalar(), pt.lscalar(),
np.array(0, dtype=np.int64), np.array(0, dtype=np.int64),
), ),
set_test_value( # high-value necessary since test on cdf ( # high-value necessary since test on cdf
pt.lscalar(), pt.lscalar(),
np.array(1000, dtype=np.int64), np.array(1000, dtype=np.int64),
), ),
...@@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.t, ptr.t,
[ [
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(2.0, dtype=np.float64), np.array(2.0, dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.uniform, ptr.uniform,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1000.0, dtype=np.float64), np.array(1000.0, dtype=np.float64),
), ),
...@@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.halfnormal, ptr.halfnormal,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([-1.0, 200.0], dtype=np.float64), np.array([-1.0, 200.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1000.0, dtype=np.float64), np.array(1000.0, dtype=np.float64),
), ),
...@@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.invgamma, ptr.invgamma,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64), np.array([10.4, 2.8], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([3.4, 7.3], dtype=np.float64), np.array([3.4, 7.3], dtype=np.float64),
), ),
...@@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.chisquare, ptr.chisquare,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([2.4, 4.9], dtype=np.float64), np.array([2.4, 4.9], dtype=np.float64),
), ),
...@@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.gengamma, ptr.gengamma,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64), np.array([10.4, 2.8], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([3.4, 7.3], dtype=np.float64), np.array([3.4, 7.3], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.9, 2.0], dtype=np.float64), np.array([0.9, 2.0], dtype=np.float64),
), ),
...@@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
( (
ptr.wald, ptr.wald,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64), np.array([10.4, 2.8], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([4.5, 2.0], dtype=np.float64), np.array([4.5, 2.0], dtype=np.float64),
), ),
...@@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
pytest.param( pytest.param(
ptr.vonmises, ptr.vonmises,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([-0.5, 1.3], dtype=np.float64), np.array([-0.5, 1.3], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([5.5, 13.0], dtype=np.float64), np.array([5.5, 13.0], dtype=np.float64),
), ),
...@@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The transpiled `RandomVariable` `Op`. The transpiled `RandomVariable` `Op`.
dist_params dist_params
The parameters passed to the op. The parameters passed to the op.
""" """
dist_params, test_values = (
zip(*dist_params, strict=True) if dist_params else ([], [])
)
rng = shared(np.random.default_rng(29403)) rng = shared(np.random.default_rng(29403))
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
g_fn = compile_random_function(dist_params, g, mode=jax_mode) g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn( samples = g_fn(*test_values)
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
if not isinstance(i, SharedVariable | Constant)
]
)
bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params]) bcast_dist_args = np.broadcast_arrays(*test_values)
for idx in np.ndindex(*base_size): for idx in np.ndindex(*base_size):
cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args)) cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args))
...@@ -775,13 +769,12 @@ def test_random_unimplemented(): ...@@ -775,13 +769,12 @@ def test_random_unimplemented():
nonexistentrv = NonExistentRV() nonexistentrv = NonExistentRV()
rng = shared(np.random.default_rng(123)) rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng) out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
with pytest.warns( with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
): ):
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
def test_random_custom_implementation(): def test_random_custom_implementation():
...@@ -810,11 +803,10 @@ def test_random_custom_implementation(): ...@@ -810,11 +803,10 @@ def test_random_custom_implementation():
nonexistentrv = CustomRV() nonexistentrv = CustomRV()
rng = shared(np.random.default_rng(123)) rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng) out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.warns( with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used" UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
): ):
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
def test_random_concrete_shape(): def test_random_concrete_shape():
......
...@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps ...@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -51,20 +50,19 @@ def test_second(): ...@@ -51,20 +50,19 @@ def test_second():
b = scalar("b") b = scalar("b")
out = ps.second(a0, b) out = ps.second(a0, b)
fgraph = FunctionGraph([a0, b], [out]) compare_jax_and_py([a0, b], [out], [10.0, 5.0])
compare_jax_and_py(fgraph, [10.0, 5.0])
a1 = vector("a1") a1 = vector("a1")
out = pt.second(a1, b) out = pt.second(a1, b)
fgraph = FunctionGraph([a1, b], [out]) compare_jax_and_py([a1, b], [out], [np.zeros([5], dtype=config.floatX), 5.0])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
a2 = matrix("a2", shape=(1, None), dtype="float64") a2 = matrix("a2", shape=(1, None), dtype="float64")
b2 = matrix("b2", shape=(None, 1), dtype="int32") b2 = matrix("b2", shape=(None, 1), dtype="int32")
out = pt.second(a2, b2) out = pt.second(a2, b2)
fgraph = FunctionGraph([a2, b2], [out])
compare_jax_and_py( compare_jax_and_py(
fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")] [a2, b2],
[out],
[np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")],
) )
...@@ -81,11 +79,10 @@ def test_second_constant_scalar(): ...@@ -81,11 +79,10 @@ def test_second_constant_scalar():
def test_identity(): def test_identity():
a = scalar("a") a = scalar("a")
a.tag.test_value = 10 a_test_value = 10
out = ps.identity(a) out = ps.identity(a)
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val): ...@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
out = comp_op(x, y) out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
x_val.astype(config.floatX), x_val.astype(config.floatX),
y_val.astype(config.floatX), y_val.astype(config.floatX),
] ]
_ = compare_jax_and_py(out_fg, test_input_vals) _ = compare_jax_and_py([x, y], [out], test_input_vals)
def test_jax_Composite_multi_output(): def test_jax_Composite_multi_output():
...@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output(): ...@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
x_s = ps.float64("xs") x_s = ps.float64("xs")
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x) outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
fgraph = FunctionGraph([x], outs) compare_jax_and_py([x], outs, [np.arange(10, dtype=config.floatX)])
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
def test_erf(): def test_erf():
x = scalar("x") x = scalar("x")
out = erf(x) out = erf(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0]) compare_jax_and_py([x], [out], [1.0])
def test_erfc(): def test_erfc():
x = scalar("x") x = scalar("x")
out = erfc(x) out = erfc(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0]) compare_jax_and_py([x], [out], [1.0])
def test_erfinv(): def test_erfinv():
x = scalar("x") x = scalar("x")
out = erfinv(x) out = erfinv(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [0.95]) compare_jax_and_py([x], [out], [0.95])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values): ...@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values] inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs) output = op(*inputs)
fg = FunctionGraph(inputs, [output]) compare_jax_and_py(inputs, [output], test_values)
compare_jax_and_py(fg, test_values)
def test_betaincinv(): def test_betaincinv():
...@@ -175,9 +165,10 @@ def test_betaincinv(): ...@@ -175,9 +165,10 @@ def test_betaincinv():
b = vector("b", dtype="float64") b = vector("b", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = betaincinv(a, b, x) out = betaincinv(a, b, x)
fg = FunctionGraph([a, b, x], [out])
compare_jax_and_py( compare_jax_and_py(
fg, [a, b, x],
[out],
[ [
np.array([5.5, 7.0]), np.array([5.5, 7.0]),
np.array([5.5, 7.0]), np.array([5.5, 7.0]),
...@@ -190,39 +181,40 @@ def test_gammaincinv(): ...@@ -190,39 +181,40 @@ def test_gammaincinv():
k = vector("k", dtype="float64") k = vector("k", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = gammaincinv(k, x) out = gammaincinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_gammainccinv(): def test_gammainccinv():
k = vector("k", dtype="float64") k = vector("k", dtype="float64")
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = gammainccinv(k, x) out = gammainccinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])]) compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_psi(): def test_psi():
x = scalar("x") x = scalar("x")
out = psi(x) out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0]) compare_jax_and_py([x], [out], [3.0])
def test_tri_gamma(): def test_tri_gamma():
x = vector("x", dtype="float64") x = vector("x", dtype="float64")
out = tri_gamma(x) out = tri_gamma(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [np.array([3.0, 5.0])]) compare_jax_and_py([x], [out], [np.array([3.0, 5.0])])
def test_polygamma(): def test_polygamma():
n = vector("n", dtype="int32") n = vector("n", dtype="int32")
x = vector("x", dtype="float32") x = vector("x", dtype="float32")
out = polygamma(n, x) out = polygamma(n, x)
fg = FunctionGraph([n, x], [out])
compare_jax_and_py( compare_jax_and_py(
fg, [n, x],
[out],
[ [
np.array([0, 1, 2]).astype("int32"), np.array([0, 1, 2]).astype("int32"),
np.array([0.5, 0.9, 2.5]).astype("float32"), np.array([0.5, 0.9, 2.5]).astype("float32"),
...@@ -233,41 +225,34 @@ def test_polygamma(): ...@@ -233,41 +225,34 @@ def test_polygamma():
def test_log1mexp(): def test_log1mexp():
x = vector("x") x = vector("x")
out = log1mexp(x) out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) compare_jax_and_py([x], [out], [[-1.0, -0.75, -0.5, -0.25]])
def test_nnet(): def test_nnet():
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = sigmoid(x) out = sigmoid(x)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = softplus(x) out = softplus(x)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [x_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_variadic_Scalar(): def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX) mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) mu_test_value = np.r_[0.1, 1.1].astype(config.floatX)
tau = vector("tau", dtype=config.floatX) tau = vector("tau", dtype=config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) tau_test_value = np.r_[1.0, 2.0].astype(config.floatX)
res = -tau * mu res = -tau * mu
fgraph = FunctionGraph([mu, tau], [res]) compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
res = -tau * (tau - mu) ** 2 res = -tau * (tau - mu) ** 2
fgraph = FunctionGraph([mu, tau], [res]) compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_add_scalars(): def test_add_scalars():
...@@ -275,8 +260,7 @@ def test_add_scalars(): ...@@ -275,8 +260,7 @@ def test_add_scalars():
size = x.shape[0] + x.shape[0] + x.shape[1] size = x.shape[0] + x.shape[0] + x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
def test_mul_scalars(): def test_mul_scalars():
...@@ -284,8 +268,7 @@ def test_mul_scalars(): ...@@ -284,8 +268,7 @@ def test_mul_scalars():
size = x.shape[0] * x.shape[0] * x.shape[1] size = x.shape[0] * x.shape[0] * x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
def test_div_scalars(): def test_div_scalars():
...@@ -293,8 +276,7 @@ def test_div_scalars(): ...@@ -293,8 +276,7 @@ def test_div_scalars():
size = x.shape[0] // x.shape[1] size = x.shape[0] // x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
def test_mod_scalars(): def test_mod_scalars():
...@@ -302,39 +284,43 @@ def test_mod_scalars(): ...@@ -302,39 +284,43 @@ def test_mod_scalars():
size = x.shape[0] % x.shape[1] size = x.shape[0] % x.shape[1]
out = pt.ones(size).astype(config.floatX) out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
def test_jax_multioutput(): def test_jax_multioutput():
x = vector("x") x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
y = vector("y") y = vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) y_test_value = np.r_[3.0, 4.0].astype(config.floatX)
w = cosh(x**2 + y / 3.0) w = cosh(x**2 + y / 3.0)
v = cosh(x / 3.0 + y**2) v = cosh(x / 3.0 + y**2)
fgraph = FunctionGraph([x, y], [w, v]) compare_jax_and_py([x, y], [w, v], [x_test_value, y_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp(): def test_jax_logp():
mu = vector("mu") mu = vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) mu_test_value = np.r_[0.0, 0.0].astype(config.floatX)
tau = vector("tau") tau = vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) tau_test_value = np.r_[1.0, 1.0].astype(config.floatX)
sigma = vector("sigma") sigma = vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) sigma_test_value = (1.0 / tau_test_value).astype(config.floatX)
value = vector("value") value = vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) value_test_value = np.r_[0.1, -10].astype(config.floatX)
logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0] conditions = [sigma > 0]
alltrue = pt_all([pt_all(1 * val) for val in conditions]) alltrue = pt_all([pt_all(1 * val) for val in conditions])
normal_logp = pt.switch(alltrue, logp, -np.inf) normal_logp = pt.switch(alltrue, logp, -np.inf)
fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) compare_jax_and_py(
[mu, tau, sigma, value],
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) [normal_logp],
[
mu_test_value,
tau_test_value,
sigma_test_value,
value_test_value,
],
)
...@@ -7,7 +7,6 @@ import pytensor.tensor as pt ...@@ -7,7 +7,6 @@ import pytensor.tensor as pt
from pytensor import function, shared from pytensor import function, shared
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import until from pytensor.scan import until
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -30,9 +29,8 @@ def test_scan_sit_sot(view): ...@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
) )
if view: if view:
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e] test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
...@@ -45,9 +43,8 @@ def test_scan_mit_sot(view): ...@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
) )
if view: if view:
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)] test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
...@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y): ...@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
if view_y: if view_y:
ys = ys[view_y] ys = ys[view_y]
fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, y0], [xs, ys], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
...@@ -90,12 +86,11 @@ def test_scan_nit_sot(view): ...@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
) )
if view: if view:
ys = ys[view] ys = ys[view]
fg = FunctionGraph([xs], [ys])
test_input_vals = [rng.normal(size=10)] test_input_vals = [rng.normal(size=10)]
# We need to remove pushout rewrites, or the whole scan would just be # We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs # converted to an Elemwise on xs
jax_fn, _ = compare_jax_and_py( jax_fn, _ = compare_jax_and_py(
fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout") [xs], [ys], test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
) )
scan_nodes = [ scan_nodes = [
node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan) node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
...@@ -112,8 +107,7 @@ def test_scan_mit_mot(): ...@@ -112,8 +107,7 @@ def test_scan_mit_mot():
n_steps=10, n_steps=10,
) )
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs) grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
fg = FunctionGraph([xs], [grads_wrt_xs]) compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
compare_jax_and_py(fg, [np.arange(10)])
def test_scan_update(): def test_scan_update():
...@@ -192,8 +186,7 @@ def test_scan_while(): ...@@ -192,8 +186,7 @@ def test_scan_while():
n_steps=100, n_steps=100,
) )
fg = FunctionGraph([], [xs]) compare_jax_and_py([], [xs], [])
compare_jax_and_py(fg, [])
def test_scan_SEIR(): def test_scan_SEIR():
...@@ -257,11 +250,6 @@ def test_scan_SEIR(): ...@@ -257,11 +250,6 @@ def test_scan_SEIR():
logp_c_all.name = "C_t_logp" logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp" logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25 s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX) logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX) logp_d0 = np.array(0.0, dtype=config.floatX)
...@@ -283,7 +271,12 @@ def test_scan_SEIR(): ...@@ -283,7 +271,12 @@ def test_scan_SEIR():
gamma_val, gamma_val,
delta_val, delta_val,
] ]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") compare_jax_and_py(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
test_input_vals,
jax_mode="JAX",
)
def test_scan_mitsot_with_nonseq(): def test_scan_mitsot_with_nonseq():
...@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq(): ...@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
y_scan_pt.name = "y" y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all" y_scan_pt.owner.inputs[0].name = "y_all"
out_fg = FunctionGraph([a_pt], [y_scan_pt])
test_input_vals = [np.array(10.0).astype(config.floatX)] test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([a_pt], [y_scan_pt], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("x0_func", [dvector, dmatrix]) @pytest.mark.parametrize("x0_func", [dvector, dmatrix])
...@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
) )
A_val = np.eye(k, dtype=config.floatX) A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_seq(): def test_nd_scan_sit_sot_with_seq():
...@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
A_val = np.eye(k, dtype=config.floatX) A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val] test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_mit_sot(): def test_nd_scan_mit_sot():
...@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot(): ...@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
n_steps=10, n_steps=10,
) )
fg = FunctionGraph([x0, A, B], [xs])
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
A_val = np.eye(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX)
B_val = np.eye(3, dtype=config.floatX) B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val] test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A, B], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_carry(): def test_nd_scan_sit_sot_with_carry():
...@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
mode=get_mode("JAX"), mode=get_mode("JAX"),
) )
fg = FunctionGraph([x0, A], xs)
x0_val = np.arange(3, dtype=config.floatX) x0_val = np.arange(3, dtype=config.floatX)
A_val = np.eye(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX") compare_jax_and_py([x0, A], xs, test_input_vals, jax_mode="JAX")
def test_default_mode_excludes_incompatible_rewrites(): def test_default_mode_excludes_incompatible_rewrites():
...@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
A = matrix("A") A = matrix("A")
B = matrix("B") B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out]) compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
def test_dynamic_sequence_length(): def test_dynamic_sequence_length():
......
...@@ -4,7 +4,6 @@ import pytest ...@@ -4,7 +4,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_shape_ops(): def test_jax_shape_ops():
x_np = np.zeros((20, 3)) x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np)) x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py([], [x], [], must_be_device_array=False)
x = Shape_i(1)(pt.as_tensor_variable(x_np)) x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False) compare_jax_and_py([], [x], [], must_be_device_array=False)
def test_jax_specify_shape(): def test_jax_specify_shape():
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None)) x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x]) compare_jax_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes # When used to assert two arrays have similar shapes
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape") shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape) x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [in_pt, shape_pt],
[x],
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
) )
...@@ -43,20 +40,17 @@ def test_jax_specify_shape(): ...@@ -43,20 +40,17 @@ def test_jax_specify_shape():
def test_jax_Reshape_constant(): def test_jax_Reshape_constant():
a = vector("a") a = vector("a")
x = reshape(a, (2, 2)) x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_jax_Reshape_concrete_shape(): def test_jax_Reshape_concrete_shape():
"""JAX should compile when a concrete value is passed for the `shape` parameter.""" """JAX should compile when a concrete value is passed for the `shape` parameter."""
a = vector("a") a = vector("a")
x = reshape(a, a.shape) x = reshape(a, a.shape)
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x]) compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail( @pytest.mark.xfail(
...@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input(): ...@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
a = vector("a") a = vector("a")
shape_pt = iscalar("b") shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt)) x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x]) compare_jax_and_py(
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_jax_compile_ops(): def test_jax_compile_ops():
x = DeepCopyOp()(pt.as_tensor_variable(1.1)) x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x]) compare_jax_and_py([], [x], [])
compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
x = ViewOp()(pt.as_tensor_variable(x_np)) x = ViewOp()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
...@@ -6,7 +6,6 @@ import pytest ...@@ -6,7 +6,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import subtensor as pt_subtensor
...@@ -30,13 +29,11 @@ def test_jax_basic(): ...@@ -30,13 +29,11 @@ def test_jax_basic():
out = pt_subtensor.inc_subtensor(out[0, 1], 2.0) out = pt_subtensor.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3] out = out[:5, :3]
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(config.floatX), np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
] ]
_, [jax_res] = compare_jax_and_py(out_fg, test_input_vals) _, [jax_res] = compare_jax_and_py([x, y], [out], test_input_vals)
# Confirm that the `Subtensor` slice operations are correct # Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3) assert jax_res.shape == (5, 3)
...@@ -46,19 +43,17 @@ def test_jax_basic(): ...@@ -46,19 +43,17 @@ def test_jax_basic():
assert jax_res[0, 1] == -8.0 assert jax_res[0, 1] == -8.0
out = clip(x, y, 5) out = clip(x, y, 5)
out_fg = FunctionGraph([x, y], [out]) compare_jax_and_py([x, y], [out], test_input_vals)
compare_jax_and_py(out_fg, test_input_vals)
out = pt.diagonal(x, 0) out = pt.diagonal(x, 0)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
) )
out = pt_slinalg.cholesky(x) out = pt_slinalg.cholesky(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -68,9 +63,9 @@ def test_jax_basic(): ...@@ -68,9 +63,9 @@ def test_jax_basic():
# not sure why this isn't working yet with lower=False # not sure why this isn't working yet with lower=False
out = pt_slinalg.Cholesky(lower=False)(x) out = pt_slinalg.Cholesky(lower=False)(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -79,9 +74,9 @@ def test_jax_basic(): ...@@ -79,9 +74,9 @@ def test_jax_basic():
) )
out = pt_slinalg.solve(x, b) out = pt_slinalg.solve(x, b)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x, b],
[out],
[ [
np.eye(10).astype(config.floatX), np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX), np.arange(10).astype(config.floatX),
...@@ -89,19 +84,17 @@ def test_jax_basic(): ...@@ -89,19 +84,17 @@ def test_jax_basic():
) )
out = pt.diag(b) out = pt.diag(b)
out_fg = FunctionGraph([b], [out]) compare_jax_and_py([b], [out], [np.arange(10).astype(config.floatX)])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
out = pt_nlinalg.det(x) out = pt_nlinalg.det(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] [x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
) )
out = pt_nlinalg.matrix_inverse(x) out = pt_nlinalg.matrix_inverse(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x],
[out],
[ [
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX config.floatX
...@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite): ...@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
lower=lower, lower=lower,
check_finite=check_finite, check_finite=check_finite,
) )
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [x, b],
[out],
[ [
np.eye(10).astype(config.floatX), np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX), np.arange(10).astype(config.floatX),
...@@ -141,10 +134,10 @@ def test_jax_block_diag(): ...@@ -141,10 +134,10 @@ def test_jax_block_diag():
D = matrix("D") D = matrix("D")
out = pt_slinalg.block_diag(A, B, C, D) out = pt_slinalg.block_diag(A, B, C, D)
out_fg = FunctionGraph([A, B, C, D], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B, C, D],
[out],
[ [
np.random.normal(size=(5, 5)).astype(config.floatX), np.random.normal(size=(5, 5)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX), np.random.normal(size=(3, 3)).astype(config.floatX),
...@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise(): ...@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
A = pt.tensor3("A") A = pt.tensor3("A")
B = pt.tensor3("B") B = pt.tensor3("B")
out = pt_slinalg.block_diag(A, B) out = pt_slinalg.block_diag(A, B)
out_fg = FunctionGraph([A, B], [out])
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.random.normal(size=(5, 5, 5)).astype(config.floatX), np.random.normal(size=(5, 5, 5)).astype(config.floatX),
np.random.normal(size=(5, 3, 3)).astype(config.floatX), np.random.normal(size=(5, 3, 3)).astype(config.floatX),
...@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower): ...@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
B = matrix("B") B = matrix("B")
out = pt_slinalg.eigvalsh(A, B, lower=lower) out = pt_slinalg.eigvalsh(A, B, lower=lower)
out_fg = FunctionGraph([A, B], [out])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.array( np.array(
[[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]] [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
...@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower): ...@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
], ],
) )
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype( np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
config.floatX config.floatX
...@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov( ...@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
A = pt.tensor(name="A", shape=shape) A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape) B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method) out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
out_fg = FunctionGraph([A, B], [out])
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3 atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py( compare_jax_and_py(
out_fg, [A, B],
[out],
[ [
np.random.normal(size=shape).astype(config.floatX), np.random.normal(size=shape).astype(config.floatX),
np.random.normal(size=shape).astype(config.floatX), np.random.normal(size=shape).astype(config.floatX),
......
import numpy as np import numpy as np
import pytest import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_sort(func, axis): def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64") x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis) out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr]) compare_jax_and_py([x], [out], [arr])
...@@ -5,7 +5,6 @@ import scipy.sparse ...@@ -5,7 +5,6 @@ import scipy.sparse
import pytensor.sparse as ps import pytensor.sparse as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import function from pytensor import function
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op): ...@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
test_values.append(y_test) test_values.append(y_test)
dot_pt = op(x_pt, y_pt) dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt]) compare_jax_and_py(inputs, [dot_pt], test_values, jax_mode="JAX")
compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
def test_sparse_dot_non_const_raises(): def test_sparse_dot_non_const_raises():
......
...@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant(): ...@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
# Basic indices # Basic indices
out_pt = x_pt[1, 2, 0] out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:, 1, :] out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[:2, 1, :] out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:2, 1, :] out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced indexing # Advanced indexing
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], [2, 3]] out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced and basic indexing # Advanced and basic indexing
out_pt = x_pt[[1, 2], :] out_pt = x_pt[[1, 2], :]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]] out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Flipping # Flipping
out_pt = x_pt[::-1] out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
# Boolean indexing should work if indexes are constant # Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)] out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling") @pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
...@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic(): ...@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
x = pt.arange(3) x = pt.arange(3)
out_pt = x[:a] out_pt = x[:a]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([a], [out_pt])
compare_jax_and_py(out_fg, [1]) compare_jax_and_py([a], [out_pt], [1])
def test_jax_Subtensor_dynamic_boolean_mask(): def test_jax_Subtensor_dynamic_boolean_mask():
...@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask(): ...@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
out_pt = x_pt[x_pt < 0] out_pt = x_pt[x_pt < 0]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
x_pt_test = np.arange(-5, 5) x_pt_test = np.arange(-5, 5)
with pytest.raises(NonConcreteBooleanIndexError): with pytest.raises(NonConcreteBooleanIndexError):
compare_jax_and_py(out_fg, [x_pt_test]) compare_jax_and_py([x_pt], [out_pt], [x_pt_test])
def test_jax_Subtensor_boolean_mask_reexpressible(): def test_jax_Subtensor_boolean_mask_reexpressible():
...@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible(): ...@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
""" """
x_pt = pt.matrix("x") x_pt = pt.matrix("x")
out_pt = x_pt[x_pt < 0].sum() out_pt = x_pt[x_pt < 0].sum()
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)]) compare_jax_and_py(
[x_pt], [out_pt], [np.arange(25).reshape(5, 5).astype(config.floatX)]
)
def test_boolean_indexing_sum_not_applicable(): def test_boolean_indexing_sum_not_applicable():
...@@ -136,19 +136,19 @@ def test_jax_IncSubtensor(): ...@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Set" advanced indices # "Set" advanced indices
st_pt = pt.as_tensor_variable( st_pt = pt.as_tensor_variable(
...@@ -156,39 +156,39 @@ def test_jax_IncSubtensor(): ...@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
) )
out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Set" boolean indices # "Set" boolean indices
mask_pt = pt.constant(x_np > 0) mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" basic indices # "Increment" basic indices
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" advanced indices # "Increment" advanced indices
st_pt = pt.as_tensor_variable( st_pt = pt.as_tensor_variable(
...@@ -196,33 +196,33 @@ def test_jax_IncSubtensor(): ...@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
) )
out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
# "Increment" boolean indices # "Increment" boolean indices
mask_pt = pt.constant(x_np > 0) mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt) out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3]) st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt) out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out_pt], [])
def test_jax_IncSubtensor_boolean_indexing_reexpressible(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
...@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible(): ...@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
mask_pt = pt.as_tensor(x_pt) > 0 mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0) out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
mask_pt = pt.as_tensor(x_pt) > 0 mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0) out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np]) compare_jax_and_py([x_pt], [out_pt], [x_np])
def test_boolean_indexing_set_or_inc_not_applicable(): def test_boolean_indexing_set_or_inc_not_applicable():
......
...@@ -10,8 +10,6 @@ from jax import errors ...@@ -10,8 +10,6 @@ from jax import errors
import pytensor import pytensor
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import check_alloc_runtime_broadcast from tests.tensor.test_basic import check_alloc_runtime_broadcast
...@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast ...@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
def test_jax_Alloc(): def test_jax_Alloc():
x = ptb.alloc(0.0, 2, 3) x = ptb.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x])
_, [jax_res] = compare_jax_and_py(x_fg, []) _, [jax_res] = compare_jax_and_py([], [x], [])
assert jax_res.shape == (2, 3) assert jax_res.shape == (2, 3)
x = ptb.alloc(1.1, 2, 3) x = ptb.alloc(1.1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
x = ptb.AllocEmpty("float32")(2, 3) x = ptb.AllocEmpty("float32")(2, 3)
x_fg = FunctionGraph([], [x])
def compare_shape_dtype(x, y): def compare_shape_dtype(x, y):
(x,) = x np.testing.assert_array_equal(x, y, strict=True)
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype
compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype) compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
a = scalar("a") a = scalar("a")
x = ptb.alloc(a, 20) x = ptb.alloc(a, 20)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [10.0]) compare_jax_and_py([a], [x], [10.0])
a = vector("a") a = vector("a")
x = ptb.alloc(a, 20, 10) x = ptb.alloc(a, 20, 10)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) compare_jax_and_py([a], [x], [np.ones(10, dtype=config.floatX)])
def test_alloc_runtime_broadcast(): def test_alloc_runtime_broadcast():
...@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast(): ...@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
def test_jax_MakeVector(): def test_jax_MakeVector():
x = ptb.make_vector(1, 2, 3) x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py([], [x], [])
def test_arange(): def test_arange():
out = ptb.arange(1, 10, 2) out = ptb.arange(1, 10, 2)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
def test_arange_of_shape(): def test_arange_of_shape():
x = vector("x") x = vector("x")
out = ptb.arange(1, x.shape[-1], 2) out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out]) compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX")
compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
def test_arange_nonconcrete(): def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values.""" """JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a = scalar("a") a = scalar("a")
a.tag.test_value = 10 a_test_value = 10
out = ptb.arange(a) out = ptb.arange(a)
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
fgraph = FunctionGraph([a], [out]) compare_jax_and_py([a], [out], [a_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_Join(): def test_jax_Join():
...@@ -94,16 +82,17 @@ def test_jax_Join(): ...@@ -94,16 +82,17 @@ def test_jax_Join():
b = matrix("b") b = matrix("b")
x = ptb.join(0, a, b) x = ptb.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX), np.c_[[4.0, 5.0]].astype(config.floatX),
...@@ -111,16 +100,17 @@ def test_jax_Join(): ...@@ -111,16 +100,17 @@ def test_jax_Join():
) )
x = ptb.join(1, a, b) x = ptb.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_jax_and_py( compare_jax_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX), np.c_[[5.0, 6.0]].astype(config.floatX),
...@@ -132,9 +122,9 @@ class TestJaxSplit: ...@@ -132,9 +122,9 @@ class TestJaxSplit:
def test_basic(self): def test_basic(self):
a = matrix("a") a = matrix("a")
a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0) a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py( compare_jax_and_py(
fg, [a],
a_splits,
[ [
np.zeros((6, 4)).astype(config.floatX), np.zeros((6, 4)).astype(config.floatX),
], ],
...@@ -142,9 +132,9 @@ class TestJaxSplit: ...@@ -142,9 +132,9 @@ class TestJaxSplit:
a = matrix("a", shape=(6, None)) a = matrix("a", shape=(6, None))
a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0) a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py( compare_jax_and_py(
fg, [a],
a_splits,
[ [
np.zeros((6, 4)).astype(config.floatX), np.zeros((6, 4)).astype(config.floatX),
], ],
...@@ -207,15 +197,14 @@ class TestJaxSplit: ...@@ -207,15 +197,14 @@ class TestJaxSplit:
def test_jax_eye(): def test_jax_eye():
"""Tests jaxification of the Eye operator""" """Tests jaxification of the Eye operator"""
out = ptb.eye(3) out = ptb.eye(3)
out_fg = FunctionGraph([], [out])
compare_jax_and_py(out_fg, []) compare_jax_and_py([], [out], [])
def test_tri(): def test_tri():
out = ptb.tri(10, 10, 0) out = ptb.tri(10, 10, 0)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, []) compare_jax_and_py([], [out], [])
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -230,14 +219,13 @@ def test_tri_nonconcrete(): ...@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
scalar("n", dtype="int64"), scalar("n", dtype="int64"),
scalar("k", dtype="int64"), scalar("k", dtype="int64"),
) )
m.tag.test_value = 10 m_test_value = 10
n.tag.test_value = 10 n_test_value = 10
k.tag.test_value = 0 k_test_value = 0
out = ptb.tri(m, n, k) out = ptb.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but # The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass # the error handler raises an Attribute error first, so that's what this test needs to pass
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
fgraph = FunctionGraph([m, n, k], [out]) compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
import contextlib import contextlib
import inspect import inspect
from collections.abc import Callable, Sequence from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from unittest import mock from unittest import mock
...@@ -21,10 +21,8 @@ from pytensor.compile.builders import OpFromGraph ...@@ -21,10 +21,8 @@ from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.basic import Apply, Constant from pytensor.graph.op import Op
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
...@@ -39,7 +37,6 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape ...@@ -39,7 +37,6 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
if TYPE_CHECKING: if TYPE_CHECKING:
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
from pytensor.tensor import TensorLike
class MyType(Type): class MyType(Type):
...@@ -128,11 +125,6 @@ py_mode = Mode("py", opts) ...@@ -128,11 +125,6 @@ py_mode = Mode("py", opts)
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
def set_test_value(x, v):
x.tag.test_value = v
return x
def compare_shape_dtype(x, y): def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype return x.shape == y.shape and x.dtype == y.dtype
...@@ -225,28 +217,30 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -225,28 +217,30 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
def compare_numba_and_py( def compare_numba_and_py(
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]], graph_inputs: Iterable[Variable],
inputs: Sequence["TensorLike"], graph_outputs: Variable | Iterable[Variable],
assert_fn: Callable | None = None, test_inputs: Iterable,
*, *,
assert_fn: Callable | None = None,
numba_mode=numba_mode, numba_mode=numba_mode,
py_mode=py_mode, py_mode=py_mode,
updates=None, updates=None,
inplace: bool = False, inplace: bool = False,
eval_obj_mode: bool = True, eval_obj_mode: bool = True,
) -> tuple[Callable, Any]: ) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python function output and Numba compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both Numba and python, runs the calculation numba and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters Parameters
---------- ----------
fgraph graph_inputs:
`FunctionGraph` or tuple(inputs, outputs) to compare. Symbolic inputs to the graph
inputs graph_outputs:
Numeric inputs to be passed to the compiled graphs. Symbolic outputs of the graph
test_inputs
Numerical inputs with which to evaluate the graph.
assert_fn assert_fn
Assert function used to check for equality between python and Numba. If not Assert function used to check for equality between python and Numba. If not
provided uses `np.testing.assert_allclose`. provided uses `np.testing.assert_allclose`.
...@@ -267,42 +261,38 @@ def compare_numba_and_py( ...@@ -267,42 +261,38 @@ def compare_numba_and_py(
x, y x, y
) )
if isinstance(fgraph, FunctionGraph): if any(inp.owner is not None for inp in graph_inputs):
fn_inputs = fgraph.inputs raise ValueError("Inputs must be root variables")
fn_outputs = fgraph.outputs
else:
fn_inputs, fn_outputs = fgraph
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
pytensor_py_fn = function( pytensor_py_fn = function(
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates
) )
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
py_res = pytensor_py_fn(*test_inputs) py_res = pytensor_py_fn(*test_inputs_copy)
# Get some coverage (and catch errors in python mode before unreadable numba ones) # Get some coverage (and catch errors in python mode before unreadable numba ones)
if eval_obj_mode: if eval_obj_mode:
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs test_inputs_copy = (
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode) (inp.copy() for inp in test_inputs) if inplace else test_inputs
)
eval_python_only(graph_inputs, graph_outputs, test_inputs_copy, mode=numba_mode)
pytensor_numba_fn = function( pytensor_numba_fn = function(
fn_inputs, graph_inputs,
fn_outputs, graph_outputs,
mode=numba_mode, mode=numba_mode,
accept_inplace=True, accept_inplace=True,
updates=updates, updates=updates,
) )
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs if isinstance(graph_outputs, tuple | list):
numba_res = pytensor_numba_fn(*test_inputs)
if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res, strict=True): for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p) assert_fn(j, p)
else: else:
assert_fn(numba_res[0], py_res[0]) assert_fn(numba_res, py_res)
return pytensor_numba_fn, numba_res return pytensor_numba_fn, numba_res
...@@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar):
) )
def test_Shape(x, i): def test_Shape(x, i):
g = Shape()(pt.as_tensor_variable(x)) g = Shape()(pt.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, []) compare_numba_and_py([], [g], [])
g = Shape_i(i)(pt.as_tensor_variable(x)) g = Shape_i(i)(pt.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, []) compare_numba_and_py([], [g], [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, shape, ndim", "v, shape, ndim",
[ [
(set_test_value(pt.vector(), np.array([4], dtype=config.floatX)), (), 0), ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
(set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2), ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
( (
set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), (pt.vector(), np.arange(4, dtype=config.floatX)),
set_test_value(pt.lvector(), np.array([2, 2], dtype="int64")), (pt.lvector(), np.array([2, 2], dtype="int64")),
2, 2,
), ),
], ],
) )
def test_Reshape(v, shape, ndim): def test_Reshape(v, shape, ndim):
v, v_test_value = v
shape, shape_test_value = shape
g = Reshape(ndim)(v, shape) g = Reshape(ndim)(v, shape)
g_fg = FunctionGraph(outputs=[g]) inputs = [v] if not isinstance(shape, Variable) else [v, shape]
test_values = (
[v_test_value]
if not isinstance(shape, Variable)
else [v_test_value, shape_test_value]
)
compare_numba_and_py( compare_numba_and_py(
g_fg, inputs,
[ [g],
i.tag.test_value test_values,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
def test_Reshape_scalar(): def test_Reshape_scalar():
v = pt.vector() v = pt.vector()
v.tag.test_value = np.array([1.0], dtype=config.floatX) v_test_value = np.array([1.0], dtype=config.floatX)
g = Reshape(1)(v[0], (1,)) g = Reshape(1)(v[0], (1,))
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ g,
i.tag.test_value [v_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -434,53 +424,44 @@ def test_Reshape_scalar(): ...@@ -434,53 +424,44 @@ def test_Reshape_scalar():
"v, shape, fails", "v, shape, fails",
[ [
( (
set_test_value(pt.matrix(), np.array([[1.0]], dtype=config.floatX)), (pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
(1, 1), (1, 1),
False, False,
), ),
( (
set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, 1), (1, 1),
True, True,
), ),
( (
set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, None), (1, None),
False, False,
), ),
], ],
) )
def test_SpecifyShape(v, shape, fails): def test_SpecifyShape(v, shape, fails):
v, v_test_value = v
g = SpecifyShape()(v, *shape) g = SpecifyShape()(v, *shape)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( def test_ViewOp():
"v", v = pt.vector()
[ v_test_value = np.arange(4, dtype=config.floatX)
set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)),
],
)
def test_ViewOp(v):
g = ViewOp()(v) g = ViewOp()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -489,20 +470,16 @@ def test_ViewOp(v): ...@@ -489,20 +470,16 @@ def test_ViewOp(v):
[ [
( (
[ [
set_test_value( (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX) (pt.lmatrix(), rng.poisson(size=(2, 3))),
),
set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))),
], ],
MySingleOut, MySingleOut,
UserWarning, UserWarning,
), ),
( (
[ [
set_test_value( (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX) (pt.lmatrix(), rng.poisson(size=(2, 3))),
),
set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))),
], ],
MyMultiOut, MyMultiOut,
UserWarning, UserWarning,
...@@ -510,38 +487,32 @@ def test_ViewOp(v): ...@@ -510,38 +487,32 @@ def test_ViewOp(v):
], ],
) )
def test_perform(inputs, op, exc): def test_perform(inputs, op, exc):
inputs, test_values = zip(*inputs, strict=True)
g = op()(*inputs) g = op()(*inputs)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) outputs = g
else: else:
g_fg = FunctionGraph(outputs=[g]) outputs = [g]
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, inputs,
[ outputs,
i.tag.test_value test_values,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
def test_perform_params(): def test_perform_params():
"""This tests for `Op.perform` implementations that require the `params` arguments.""" """This tests for `Op.perform` implementations that require the `params` arguments."""
x = pt.vector() x = pt.vector(shape=(2,))
x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x, np.array(True)) out = assert_op(x, np.array(True))
if not isinstance(out, list | tuple): compare_numba_and_py([x], out, [x_test_value])
out = [out]
out_fg = FunctionGraph([x], out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
def test_perform_type_convert(): def test_perform_type_convert():
...@@ -552,59 +523,50 @@ def test_perform_type_convert(): ...@@ -552,59 +523,50 @@ def test_perform_type_convert():
""" """
x = pt.vector() x = pt.vector()
x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True)) out = assert_op(x.sum(), np.array(True))
if not isinstance(out, list | tuple): compare_numba_and_py([x], out, [x_test_value])
out = [out]
out_fg = FunctionGraph([x], out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, y, exc", "x, y, exc",
[ [
( (
set_test_value(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
), ),
( (
set_test_value( (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64") (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
),
set_test_value(
pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")
),
None, None,
), ),
( (
set_test_value(pt.lmatrix(), rng.poisson(size=(3, 2))), (pt.lmatrix(), rng.poisson(size=(3, 2))),
set_test_value(pt.fvector(), rng.random(size=(2,)).astype("float32")), (pt.fvector(), rng.random(size=(2,)).astype("float32")),
None, None,
), ),
( (
set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)), (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)), (pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
None, None,
), ),
], ],
) )
def test_Dot(x, y, exc): def test_Dot(x, y, exc):
x, x_test_value = x
y, y_test_value = y
g = ptm.Dot()(x, y) g = ptm.Dot()(x, y)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x, y],
[ [g],
i.tag.test_value [x_test_value, y_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -612,44 +574,41 @@ def test_Dot(x, y, exc): ...@@ -612,44 +574,41 @@ def test_Dot(x, y, exc):
"x, exc", "x, exc",
[ [
( (
set_test_value(ps.float64(), np.array(0.0, dtype="float64")), (ps.float64(), np.array(0.0, dtype="float64")),
None, None,
), ),
( (
set_test_value(ps.float64(), np.array(-32.0, dtype="float64")), (ps.float64(), np.array(-32.0, dtype="float64")),
None, None,
), ),
( (
set_test_value(ps.float64(), np.array(-40.0, dtype="float64")), (ps.float64(), np.array(-40.0, dtype="float64")),
None, None,
), ),
( (
set_test_value(ps.float64(), np.array(32.0, dtype="float64")), (ps.float64(), np.array(32.0, dtype="float64")),
None, None,
), ),
( (
set_test_value(ps.float64(), np.array(40.0, dtype="float64")), (ps.float64(), np.array(40.0, dtype="float64")),
None, None,
), ),
( (
set_test_value(ps.int64(), np.array(32, dtype="int64")), (ps.int64(), np.array(32, dtype="int64")),
None, None,
), ),
], ],
) )
def test_Softplus(x, exc): def test_Softplus(x, exc):
x, x_test_value = x
g = psm.Softplus(ps.upgrade_to_float)(x) g = psm.Softplus(ps.upgrade_to_float)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ [g],
i.tag.test_value [x_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -657,22 +616,22 @@ def test_Softplus(x, exc): ...@@ -657,22 +616,22 @@ def test_Softplus(x, exc):
"x, y, exc", "x, y, exc",
[ [
( (
set_test_value( (
pt.dtensor3(), pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
set_test_value( (
pt.dtensor3(), pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
None, None,
), ),
( (
set_test_value( (
pt.dtensor3(), pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"), rng.random(size=(2, 3, 3)).astype("float64"),
), ),
set_test_value( (
pt.ltensor3(), pt.ltensor3(),
rng.poisson(size=(2, 3, 3)).astype("int64"), rng.poisson(size=(2, 3, 3)).astype("int64"),
), ),
...@@ -681,22 +640,17 @@ def test_Softplus(x, exc): ...@@ -681,22 +640,17 @@ def test_Softplus(x, exc):
], ],
) )
def test_BatchedDot(x, y, exc): def test_BatchedDot(x, y, exc):
g = blas.BatchedDot()(x, y) x, x_test_value = x
y, y_test_value = y
if isinstance(g, list): g = blas.BatchedDot()(x, y)
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x, y],
[ g,
i.tag.test_value [x_test_value, y_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -767,15 +721,15 @@ y = np.array( ...@@ -767,15 +721,15 @@ y = np.array(
[ [
([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]), ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
( (
[set_test_value(pt.dscalar(), np.array(0.2, dtype=np.float64))], [(pt.dscalar(), np.array(0.2, dtype=np.float64))],
lambda x: x < 0.5, lambda x: x < 0.5,
np.r_[1, 2, 3], np.r_[1, 2, 3],
np.r_[-1, -2, -3], np.r_[-1, -2, -3],
), ),
( (
[ [
set_test_value(pt.dscalar(), np.array(0.3, dtype=np.float64)), (pt.dscalar(), np.array(0.3, dtype=np.float64)),
set_test_value(pt.dscalar(), np.array(0.5, dtype=np.float64)), (pt.dscalar(), np.array(0.5, dtype=np.float64)),
], ],
lambda x, y: x > y, lambda x, y: x > y,
x, x,
...@@ -783,8 +737,8 @@ y = np.array( ...@@ -783,8 +737,8 @@ y = np.array(
), ),
( (
[ [
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
], ],
lambda x, y: pt.all(x > y), lambda x, y: pt.all(x > y),
x, x,
...@@ -792,8 +746,8 @@ y = np.array( ...@@ -792,8 +746,8 @@ y = np.array(
), ),
( (
[ [
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
], ],
lambda x, y: pt.all(x > y), lambda x, y: pt.all(x > y),
[x, 2 * x], [x, 2 * x],
...@@ -801,8 +755,8 @@ y = np.array( ...@@ -801,8 +755,8 @@ y = np.array(
), ),
( (
[ [
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)), (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)), (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
], ],
lambda x, y: pt.all(x > y), lambda x, y: pt.all(x > y),
[x, 2 * x], [x, 2 * x],
...@@ -811,14 +765,9 @@ y = np.array( ...@@ -811,14 +765,9 @@ y = np.array(
], ],
) )
def test_IfElse(inputs, cond_fn, true_vals, false_vals): def test_IfElse(inputs, cond_fn, true_vals, false_vals):
inputs, test_values = zip(*inputs, strict=True) if inputs else ([], [])
out = ifelse(cond_fn(*inputs), true_vals, false_vals) out = ifelse(cond_fn(*inputs), true_vals, false_vals)
compare_numba_and_py(inputs, out, test_values)
if not isinstance(out, list):
out = [out]
out_fg = FunctionGraph(inputs, out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") @pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409")
...@@ -883,7 +832,7 @@ def test_OpFromGraph(): ...@@ -883,7 +832,7 @@ def test_OpFromGraph():
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv]) compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
......
...@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt): ...@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
) )
x_test = np.eye(3) * np.arange(1, 6)[:, None, None] x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
compare_numba_and_py( compare_numba_and_py(
([x], outs), [x],
outs,
[x_test], [x_test],
numba_mode=mode, numba_mode=mode,
eval_obj_mode=False, eval_obj_mode=False,
......
...@@ -11,10 +11,7 @@ import pytensor.tensor.math as ptm ...@@ -11,10 +11,7 @@ import pytensor.tensor.math as ptm
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import float64 from pytensor.scalar import float64
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
...@@ -22,7 +19,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad ...@@ -22,7 +19,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
scalar_my_multi_out, scalar_my_multi_out,
set_test_value,
) )
from tests.tensor.test_elemwise import ( from tests.tensor.test_elemwise import (
careduce_benchmark_tester, careduce_benchmark_tester,
...@@ -116,13 +112,13 @@ rng = np.random.default_rng(42849) ...@@ -116,13 +112,13 @@ rng = np.random.default_rng(42849)
def test_Elemwise(inputs, input_vals, output_fn, exc): def test_Elemwise(inputs, input_vals, output_fn, exc):
outputs = output_fn(*inputs) outputs = output_fn(*inputs)
out_fg = FunctionGraph(
outputs=[outputs] if not isinstance(outputs, list) else outputs
)
cm = contextlib.suppress() if exc is None else pytest.raises(exc) cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm: with cm:
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(
inputs,
outputs,
input_vals,
)
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults") @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
...@@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast(): ...@@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast():
[ [
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}` # `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
( (
set_test_value( (
pt.lscalar(name="a"), pt.lscalar(name="a"),
np.array(1, dtype=np.int64), np.array(1, dtype=np.int64),
), ),
...@@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast(): ...@@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.T` # I.e. `a_pt.T`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}` # `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
( (
set_test_value( (pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)),
pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(1, 0), (1, 0),
), ),
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}` # `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
( (
set_test_value( (pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)),
pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(1, 0, "x"), (1, 0, "x"),
), ),
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}` # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
( (
set_test_value( (
pt.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"), pt.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX), np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
), ),
...@@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast(): ...@@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.dimshuffle((0,))` # I.e. `a_pt.dimshuffle((0,))`
# `{'drop': [1], 'shuffle': [0], 'augment': []}` # `{'drop': [1], 'shuffle': [0], 'augment': []}`
( (
set_test_value( (
pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"), pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
), ),
(0,), (0,),
), ),
( (
set_test_value( (
pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"), pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
), ),
(0,), (0,),
), ),
( (
set_test_value( (
pt.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"), pt.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
np.array([[[1.0]]], dtype=config.floatX), np.array([[[1.0]]], dtype=config.floatX),
), ),
...@@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast(): ...@@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast():
], ],
) )
def test_Dimshuffle(v, new_order): def test_Dimshuffle(v, new_order):
v, v_test_value = v
g = v.dimshuffle(new_order) g = v.dimshuffle(new_order)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous(): ...@@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous():
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
0, 0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), (pt.vector(), np.arange(3, dtype=config.floatX)),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0, 0,
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), (pt.vector(dtype="bool"), np.array([False, True, False])),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0, 0,
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])), (pt.vector(dtype="bool"), np.array([False, True, False])),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
0, 0,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
(0, 1), (0, 1),
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
(1, 0), (1, 0),
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
None, None,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
1, 1,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Prod( lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
(), # Empty axes would normally be rewritten away, but we want to test it still works (), # Empty axes would normally be rewritten away, but we want to test it still works
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Prod( lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
None, None,
set_test_value( (
pt.scalar(), np.array(99.0, dtype=config.floatX) pt.scalar(),
np.array(99.0, dtype=config.floatX),
), # Scalar input would normally be rewritten away, but we want to test it still works ), # Scalar input would normally be rewritten away, but we want to test it still works
), ),
( (
...@@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous(): ...@@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous():
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
0, 0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), (pt.vector(), np.arange(3, dtype=config.floatX)),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros( lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
0, 0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), (pt.vector(), np.arange(3, dtype=config.floatX)),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Prod( lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
0, 0,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Prod( lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x), )(x),
1, 1,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None, None,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None, None,
set_test_value( (pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))),
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None, None,
set_test_value( (pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None, None,
set_test_value( (pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))),
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
), ),
], ],
) )
def test_CAReduce(careduce_fn, axis, v): def test_CAReduce(careduce_fn, axis, v):
v, v_test_value = v
g = careduce_fn(v, axis=axis) g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
fn, _ = compare_numba_and_py( fn, _ = compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
# Confirm CAReduce is in the compiled function # Confirm CAReduce is in the compiled function
fn.dprint() # fn.dprint()
[node] = fn.maker.fgraph.apply_nodes [node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, CAReduce) assert isinstance(node.op, CAReduce)
...@@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v):
def test_scalar_Elemwise_Clip(): def test_scalar_Elemwise_Clip():
a = pt.scalar("a") a = pt.scalar("a")
b = pt.scalar("b") b = pt.scalar("b")
inputs = [a, b]
z = pt.switch(1, a, b) z = pt.switch(1, a, b)
c = pt.clip(z, 1, 3) c = pt.clip(z, 1, 3)
c_fg = FunctionGraph(outputs=[c])
compare_numba_and_py(c_fg, [1, 1]) compare_numba_and_py(inputs, [c], [1, 1])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dy, sm, axis, exc", "dy, sm, axis, exc",
[ [
( (
set_test_value( (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None, None,
None, None,
), ),
( (
set_test_value( (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0, 0,
None, None,
), ),
( (
set_test_value( (pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1, 1,
None, None,
), ),
], ],
) )
def test_SoftmaxGrad(dy, sm, axis, exc): def test_SoftmaxGrad(dy, sm, axis, exc):
dy, dy_test_value = dy
sm, sm_test_value = sm
g = SoftmaxGrad(axis=axis)(dy, sm) g = SoftmaxGrad(axis=axis)(dy, sm)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [dy, sm],
[ [g],
i.tag.test_value [dy_test_value, sm_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
def test_SoftMaxGrad_constant_dy(): def test_SoftMaxGrad_constant_dy():
dy = pt.constant(np.zeros((3,), dtype=config.floatX)) dy = pt.constant(np.zeros((3,), dtype=config.floatX))
sm = pt.vector(shape=(3,)) sm = pt.vector(shape=(3,))
inputs = [sm]
g = SoftmaxGrad(axis=None)(dy, sm) g = SoftmaxGrad(axis=None)(dy, sm)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(g_fg, [np.ones((3,), dtype=config.floatX)]) compare_numba_and_py(inputs, [g], [np.ones((3,), dtype=config.floatX)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, axis, exc", "x, axis, exc",
[ [
( (
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
None, None,
), ),
( (
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None, None,
None, None,
), ),
( (
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0, 0,
None, None,
), ),
], ],
) )
def test_Softmax(x, axis, exc): def test_Softmax(x, axis, exc):
x, x_test_value = x
g = Softmax(axis=axis)(x) g = Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ [g],
i.tag.test_value [x_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc): ...@@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc):
"x, axis, exc", "x, axis, exc",
[ [
( (
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)), (pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None, None,
None, None,
), ),
( (
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0, 0,
None, None,
), ),
( (
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), (pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1, 1,
None, None,
), ),
], ],
) )
def test_LogSoftmax(x, axis, exc): def test_LogSoftmax(x, axis, exc):
x, x_test_value = x
g = LogSoftmax(axis=axis)(x) g = LogSoftmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ [g],
i.tag.test_value [x_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc): ...@@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc):
"x, axes, exc", "x, axes, exc",
[ [
( (
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), (pt.dscalar(), np.array(0.0, dtype="float64")),
[], [],
None, None,
), ),
( (
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), (pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0], [0],
None, None,
), ),
( (
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0], [0],
None, None,
), ),
( (
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1], [0, 1],
None, None,
), ),
], ],
) )
def test_Max(x, axes, exc): def test_Max(x, axes, exc):
x, x_test_value = x
g = ptm.Max(axes)(x) g = ptm.Max(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ [g],
i.tag.test_value [x_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -571,44 +513,37 @@ def test_Max(x, axes, exc): ...@@ -571,44 +513,37 @@ def test_Max(x, axes, exc):
"x, axes, exc", "x, axes, exc",
[ [
( (
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")), (pt.dscalar(), np.array(0.0, dtype="float64")),
[], [],
None, None,
), ),
( (
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")), (pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0], [0],
None, None,
), ),
( (
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0], [0],
None, None,
), ),
( (
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")), (pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1], [0, 1],
None, None,
), ),
], ],
) )
def test_Argmax(x, axes, exc): def test_Argmax(x, axes, exc):
x, x_test_value = x
g = ptm.Argmax(axes)(x) g = ptm.Argmax(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ [g],
i.tag.test_value [x_test_value],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -636,7 +571,8 @@ def test_scalar_loop(): ...@@ -636,7 +571,8 @@ def test_scalar_loop():
with pytest.warns(UserWarning, match="object mode"): with pytest.warns(UserWarning, match="object mode"):
compare_numba_and_py( compare_numba_and_py(
([x], [elemwise_loop]), [x],
[elemwise_loop],
(np.array([1, 2, 3], dtype="float64"),), (np.array([1, 2, 3], dtype="float64"),),
) )
......
...@@ -5,11 +5,8 @@ import pytest ...@@ -5,11 +5,8 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import extra_ops from pytensor.tensor import extra_ops
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -18,20 +15,17 @@ rng = np.random.default_rng(42849) ...@@ -18,20 +15,17 @@ rng = np.random.default_rng(42849)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"val", "val",
[ [
set_test_value(pt.lscalar(), np.array(6, dtype="int64")), (pt.lscalar(), np.array(6, dtype="int64")),
], ],
) )
def test_Bartlett(val): def test_Bartlett(val):
val, test_val = val
g = extra_ops.bartlett(val) g = extra_ops.bartlett(val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [val],
[ g,
i.tag.test_value [test_val],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15), assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15),
) )
...@@ -40,97 +34,71 @@ def test_Bartlett(val): ...@@ -40,97 +34,71 @@ def test_Bartlett(val):
"val, axis, mode", "val, axis, mode",
[ [
( (
set_test_value( (pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))),
pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))
),
1, 1,
"add", "add",
), ),
( (
set_test_value( (pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))
),
-1, -1,
"add", "add",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
0, 0,
"add", "add",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
1, 1,
"add", "add",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None, None,
"add", "add",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
0, 0,
"mul", "mul",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
1, 1,
"mul", "mul",
), ),
( (
set_test_value( (pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
None, None,
"mul", "mul",
), ),
], ],
) )
def test_CumOp(val, axis, mode): def test_CumOp(val, axis, mode):
val, test_val = val
g = extra_ops.CumOp(axis=axis, mode=mode)(val) g = extra_ops.CumOp(axis=axis, mode=mode)(val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [val],
[ g,
i.tag.test_value [test_val],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( def test_FillDiagonal():
"a, val", a = pt.lmatrix("a")
[ test_a = np.zeros((10, 2), dtype="int64")
(
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), val = pt.lscalar("val")
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), test_val = np.array(1, dtype="int64")
)
],
)
def test_FillDiagonal(a, val):
g = extra_ops.FillDiagonal()(a, val) g = extra_ops.FillDiagonal()(a, val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [a, val],
[ g,
i.tag.test_value [test_a, test_val],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -138,33 +106,32 @@ def test_FillDiagonal(a, val): ...@@ -138,33 +106,32 @@ def test_FillDiagonal(a, val):
"a, val, offset", "a, val, offset",
[ [
( (
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), (pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(-1, dtype="int64")), (pt.lscalar(), np.array(-1, dtype="int64")),
), ),
( (
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")), (pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(0, dtype="int64")), (pt.lscalar(), np.array(0, dtype="int64")),
), ),
( (
set_test_value(pt.lmatrix(), np.zeros((10, 3), dtype="int64")), (pt.lmatrix(), np.zeros((10, 3), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
), ),
], ],
) )
def test_FillDiagonalOffset(a, val, offset): def test_FillDiagonalOffset(a, val, offset):
a, test_a = a
val, test_val = val
offset, test_offset = offset
g = extra_ops.FillDiagonalOffset()(a, val, offset) g = extra_ops.FillDiagonalOffset()(a, val, offset)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [a, val, offset],
[ g,
i.tag.test_value [test_a, test_val, test_offset],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset):
"arr, shape, mode, order, exc", "arr, shape, mode, order, exc",
[ [
( (
tuple(set_test_value(pt.lscalar(), v) for v in np.array([0])), tuple((pt.lscalar(), v) for v in np.array([0])),
set_test_value(pt.lvector(), np.array([2])), (pt.lvector(), np.array([2])),
"raise", "raise",
"C", "C",
None, None,
), ),
( (
tuple(set_test_value(pt.lscalar(), v) for v in np.array([0, 0, 3])), tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
set_test_value(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C", "C",
None, None,
), ),
( (
tuple( tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
set_test_value(pt.lvector(), v) (pt.lvector(), np.array([2, 3, 4])),
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C", "C",
None, None,
), ),
( (
tuple( tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
set_test_value(pt.lvector(), v) (pt.lvector(), np.array([2, 3, 4])),
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"F", "F",
NotImplementedError, NotImplementedError,
), ),
( (
tuple( tuple(
set_test_value(pt.lvector(), v) (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
), ),
set_test_value(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"raise", "raise",
"C", "C",
ValueError, ValueError,
), ),
( (
tuple( tuple(
set_test_value(pt.lvector(), v) (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
), ),
set_test_value(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"wrap", "wrap",
"C", "C",
None, None,
), ),
( (
tuple( tuple(
set_test_value(pt.lvector(), v) (pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
), ),
set_test_value(pt.lvector(), np.array([2, 3, 4])), (pt.lvector(), np.array([2, 3, 4])),
"clip", "clip",
"C", "C",
None, None,
...@@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset): ...@@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset):
], ],
) )
def test_RavelMultiIndex(arr, shape, mode, order, exc): def test_RavelMultiIndex(arr, shape, mode, order, exc):
g = extra_ops.RavelMultiIndex(mode, order)(*((*arr, shape))) arr, test_arr = zip(*arr, strict=True)
g_fg = FunctionGraph(outputs=[g]) shape, test_shape = shape
g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape)
cm = contextlib.suppress() if exc is None else pytest.raises(exc) cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [*arr, shape],
[ g,
i.tag.test_value [*test_arr, test_shape],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc): ...@@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
"x, repeats, axis, exc", "x, repeats, axis, exc",
[ [
( (
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(0, dtype="int64")), (pt.lscalar(), np.array(0, dtype="int64")),
None, None,
None, None,
), ),
( (
set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")), (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
None, None,
None, None,
), ),
( (
set_test_value(pt.lvector(), np.arange(2, dtype="int64")), (pt.lvector(), np.arange(2, dtype="int64")),
set_test_value(pt.lvector(), np.array([1, 1], dtype="int64")), (pt.lvector(), np.array([1, 1], dtype="int64")),
None, None,
None, None,
), ),
( (
set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")), (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
0, 0,
UserWarning, UserWarning,
), ),
], ],
) )
def test_Repeat(x, repeats, axis, exc): def test_Repeat(x, repeats, axis, exc):
x, test_x = x
repeats, test_repeats = repeats
g = extra_ops.Repeat(axis)(x, repeats) g = extra_ops.Repeat(axis)(x, repeats)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x, repeats],
[ g,
i.tag.test_value [test_x, test_repeats],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc):
"x, axis, return_index, return_inverse, return_counts, exc", "x, axis, return_index, return_inverse, return_counts, exc",
[ [
( (
set_test_value(pt.lscalar(), np.array(1, dtype="int64")), (pt.lscalar(), np.array(1, dtype="int64")),
None, None,
False, False,
False, False,
...@@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc):
None, None,
), ),
( (
set_test_value(pt.lvector(), np.array([1, 1, 2], dtype="int64")), (pt.lvector(), np.array([1, 1, 2], dtype="int64")),
None, None,
False, False,
False, False,
...@@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc):
None, None,
), ),
( (
set_test_value(pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")), (pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")),
None, None,
False, False,
False, False,
...@@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc):
None, None,
), ),
( (
set_test_value( (pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
),
0, 0,
False, False,
False, False,
...@@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc):
UserWarning, UserWarning,
), ),
( (
set_test_value( (pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
),
0, 0,
True, True,
True, True,
...@@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc): ...@@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc):
], ],
) )
def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
x, test_x = x
g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x) g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): ...@@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
"arr, shape, order, exc", "arr, shape, order, exc",
[ [
( (
set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")), (pt.lvector(), np.array([9, 15, 1], dtype="int64")),
pt.as_tensor([2, 3, 4]), pt.as_tensor([2, 3, 4]),
"C", "C",
None, None,
), ),
( (
set_test_value(pt.lvector(), np.array([1, 0], dtype="int64")), (pt.lvector(), np.array([1, 0], dtype="int64")),
pt.as_tensor([2]), pt.as_tensor([2]),
"C", "C",
None, None,
), ),
( (
set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")), (pt.lvector(), np.array([9, 15, 1], dtype="int64")),
pt.as_tensor([2, 3, 4]), pt.as_tensor([2, 3, 4]),
"F", "F",
NotImplementedError, NotImplementedError,
...@@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): ...@@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
], ],
) )
def test_UnravelIndex(arr, shape, order, exc): def test_UnravelIndex(arr, shape, order, exc):
arr, test_arr = arr
g = extra_ops.UnravelIndex(order)(arr, shape) g = extra_ops.UnravelIndex(order)(arr, shape)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.raises(exc) cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [arr],
[ g,
i.tag.test_value [test_arr],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc): ...@@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc):
"a, v, side, sorter, exc", "a, v, side, sorter, exc",
[ [
( (
set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), (pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)), (pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
"left", "left",
None, None,
None, None,
), ),
pytest.param( pytest.param(
set_test_value( (
pt.vector(), pt.vector(),
np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX), np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX),
), ),
set_test_value( (
pt.matrix(), pt.matrix(),
np.array( np.array(
[ [
...@@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc): ...@@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc):
None, None,
), ),
( (
set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), (pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)), (pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
"right", "right",
set_test_value(pt.lvector(), np.array([0, 2, 1])), (pt.lvector(), np.array([0, 2, 1])),
UserWarning, UserWarning,
), ),
], ],
) )
def test_Searchsorted(a, v, side, sorter, exc): def test_Searchsorted(a, v, side, sorter, exc):
a, test_a = a
v, test_v = v
if sorter is not None:
sorter, test_sorter = sorter
g = extra_ops.SearchsortedOp(side)(a, v, sorter) g = extra_ops.SearchsortedOp(side)(a, v, sorter)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [a, v] if sorter is None else [a, v, sorter],
[ g,
i.tag.test_value [test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -4,11 +4,8 @@ import numpy as np ...@@ -4,11 +4,8 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849) ...@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
), ),
...@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849) ...@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
], ],
) )
def test_Det(x, exc): def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x) g = nlinalg.Det()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -53,14 +47,14 @@ def test_Det(x, exc): ...@@ -53,14 +47,14 @@ def test_Det(x, exc):
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
), ),
...@@ -69,18 +63,15 @@ def test_Det(x, exc): ...@@ -69,18 +63,15 @@ def test_Det(x, exc):
], ],
) )
def test_SLogDet(x, exc): def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x) g = nlinalg.SLogDet()(x)
g_fg = FunctionGraph(outputs=g)
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -112,21 +103,21 @@ y = np.array( ...@@ -112,21 +103,21 @@ y = np.array(
"x, exc", "x, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(x), (lambda x: x.T.dot(x))(x),
), ),
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(y), (lambda x: x.T.dot(x))(y),
), ),
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -137,22 +128,15 @@ y = np.array( ...@@ -137,22 +128,15 @@ y = np.array(
], ],
) )
def test_Eig(x, exc): def test_Eig(x, exc):
x, test_x = x
g = nlinalg.Eig()(x) g = nlinalg.Eig()(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -160,7 +144,7 @@ def test_Eig(x, exc): ...@@ -160,7 +144,7 @@ def test_Eig(x, exc):
"x, uplo, exc", "x, uplo, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -168,7 +152,7 @@ def test_Eig(x, exc): ...@@ -168,7 +152,7 @@ def test_Eig(x, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -180,22 +164,15 @@ def test_Eig(x, exc): ...@@ -180,22 +164,15 @@ def test_Eig(x, exc):
], ],
) )
def test_Eigh(x, uplo, exc): def test_Eigh(x, uplo, exc):
x, test_x = x
g = nlinalg.Eigh(uplo)(x) g = nlinalg.Eigh(uplo)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc): ...@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
[ [
( (
nlinalg.MatrixInverse, nlinalg.MatrixInverse,
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc): ...@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixInverse, nlinalg.MatrixInverse,
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc): ...@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixPinv, nlinalg.MatrixPinv,
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc): ...@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
), ),
( (
nlinalg.MatrixPinv, nlinalg.MatrixPinv,
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc): ...@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
], ],
) )
def test_matrix_inverses(op, x, exc, op_args): def test_matrix_inverses(op, x, exc, op_args):
x, test_x = x
g = op(*op_args)(x) g = op(*op_args)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
"x, mode, exc", "x, mode, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args): ...@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
], ],
) )
def test_QRFull(x, mode, exc): def test_QRFull(x, mode, exc):
x, test_x = x
g = nlinalg.QRFull(mode)(x) g = nlinalg.QRFull(mode)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc): ...@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
"x, full_matrices, compute_uv, exc", "x, full_matrices, compute_uv, exc",
[ [
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc): ...@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
...@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc): ...@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc): ...@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
None, None,
), ),
( (
set_test_value( (
pt.lmatrix(), pt.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64") rng.integers(1, 10, size=(3, 3)).astype("int64")
...@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc): ...@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
], ],
) )
def test_SVD(x, full_matrices, compute_uv, exc): def test_SVD(x, full_matrices, compute_uv, exc):
x, test_x = x
g = nlinalg.SVD(full_matrices, compute_uv)(x) g = nlinalg.SVD(full_matrices, compute_uv)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm: with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, [x],
[ g,
i.tag.test_value [test_x],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -3,7 +3,6 @@ import pytest ...@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode from pytensor.tensor.pad import PadMode
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs): ...@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3)) x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs) res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_numba_and_py( compare_numba_and_py(
res_fg, [x_pt],
[res],
[x], [x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL), assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN", py_mode="FAST_RUN",
......
...@@ -10,13 +10,9 @@ import pytensor.tensor.random.basic as ptr ...@@ -10,13 +10,9 @@ import pytensor.tensor.random.basic as ptr
from pytensor import shared from pytensor import shared
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
numba_mode, numba_mode,
set_test_value,
) )
from tests.tensor.random.test_basic import ( from tests.tensor.random.test_basic import (
batched_permutation_tester, batched_permutation_tester,
...@@ -159,11 +155,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -159,11 +155,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.uniform, ptr.uniform,
[ [
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
...@@ -173,15 +169,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -173,15 +169,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.triangular, ptr.triangular,
[ [
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(-5.0, dtype=np.float64), np.array(-5.0, dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(5.0, dtype=np.float64), np.array(5.0, dtype=np.float64),
), ),
...@@ -191,11 +187,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -191,11 +187,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.lognormal, ptr.lognormal,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -205,11 +201,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -205,11 +201,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.pareto, ptr.pareto,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([2.0, 10.0], dtype=np.float64), np.array([2.0, 10.0], dtype=np.float64),
), ),
...@@ -219,7 +215,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -219,7 +215,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.exponential, ptr.exponential,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
...@@ -229,7 +225,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -229,7 +225,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.weibull, ptr.weibull,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
...@@ -239,11 +235,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -239,11 +235,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.logistic, ptr.logistic,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -253,7 +249,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -253,7 +249,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.geometric, ptr.geometric,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.3, 0.4], dtype=np.float64), np.array([0.3, 0.4], dtype=np.float64),
), ),
...@@ -263,15 +259,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -263,15 +259,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest.param( pytest.param(
ptr.hypergeometric, ptr.hypergeometric,
[ [
set_test_value( (
pt.lscalar(), pt.lscalar(),
np.array(7, dtype=np.int64), np.array(7, dtype=np.int64),
), ),
set_test_value( (
pt.lscalar(), pt.lscalar(),
np.array(8, dtype=np.int64), np.array(8, dtype=np.int64),
), ),
set_test_value( (
pt.lscalar(), pt.lscalar(),
np.array(15, dtype=np.int64), np.array(15, dtype=np.int64),
), ),
...@@ -282,11 +278,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -282,11 +278,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.wald, ptr.wald,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -296,11 +292,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -296,11 +292,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.laplace, ptr.laplace,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -310,11 +306,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -310,11 +306,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.binomial, ptr.binomial,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(0.9, dtype=np.float64), np.array(0.9, dtype=np.float64),
), ),
...@@ -324,21 +320,21 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -324,21 +320,21 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.normal, ptr.normal,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
], ],
pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [3, 2])), pt.as_tensor([3, 2]),
), ),
( (
ptr.poisson, ptr.poisson,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
...@@ -348,11 +344,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -348,11 +344,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.halfnormal, ptr.halfnormal,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([1, 2], dtype=np.int64), np.array([1, 2], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -362,7 +358,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -362,7 +358,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.bernoulli, ptr.bernoulli,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.1, 0.9], dtype=np.float64), np.array([0.1, 0.9], dtype=np.float64),
), ),
...@@ -372,11 +368,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -372,11 +368,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.beta, ptr.beta,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -386,11 +382,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -386,11 +382,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr._gamma, ptr._gamma,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64), np.array([0.5, 3.0], dtype=np.float64),
), ),
...@@ -400,7 +396,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -400,7 +396,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.chisquare, ptr.chisquare,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
) )
...@@ -410,11 +406,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -410,11 +406,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.negative_binomial, ptr.negative_binomial,
[ [
set_test_value( (
pt.lvector(), pt.lvector(),
np.array([100, 200], dtype=np.int64), np.array([100, 200], dtype=np.int64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(0.09, dtype=np.float64), np.array(0.09, dtype=np.float64),
), ),
...@@ -424,11 +420,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -424,11 +420,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.vonmises, ptr.vonmises,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64), np.array([-0.5, 0.5], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -438,14 +434,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -438,14 +434,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
( (
ptr.permutation, ptr.permutation,
[ [
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), (pt.dmatrix(), np.eye(5, dtype=np.float64)),
], ],
(), (),
), ),
( (
partial(ptr.choice, replace=True), partial(ptr.choice, replace=True),
[ [
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), (pt.dmatrix(), np.eye(5, dtype=np.float64)),
], ],
pt.as_tensor([2]), pt.as_tensor([2]),
), ),
...@@ -455,17 +451,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -455,17 +451,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=True, rng=rng a, p=p, size=size, replace=True, rng=rng
), ),
[ [
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), (pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value( (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
], ],
(pt.as_tensor([2, 3])), (pt.as_tensor([2, 3])),
), ),
pytest.param( pytest.param(
partial(ptr.choice, replace=False), partial(ptr.choice, replace=False),
[ [
set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)), (pt.dvector(), np.arange(5, dtype=np.float64)),
], ],
pt.as_tensor([2]), pt.as_tensor([2]),
marks=pytest.mark.xfail( marks=pytest.mark.xfail(
...@@ -476,7 +470,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -476,7 +470,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest.param( pytest.param(
partial(ptr.choice, replace=False), partial(ptr.choice, replace=False),
[ [
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)), (pt.dmatrix(), np.eye(5, dtype=np.float64)),
], ],
pt.as_tensor([2]), pt.as_tensor([2]),
marks=pytest.mark.xfail( marks=pytest.mark.xfail(
...@@ -490,8 +484,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -490,8 +484,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng a, p=p, size=size, replace=False, rng=rng
), ),
[ [
set_test_value(pt.vector(), np.arange(5, dtype=np.float64)), (pt.vector(), np.arange(5, dtype=np.float64)),
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64), np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64),
), ),
...@@ -504,10 +498,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -504,10 +498,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng a, p=p, size=size, replace=False, rng=rng
), ),
[ [
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), (pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value( (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
], ],
(), (),
), ),
...@@ -517,10 +509,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -517,10 +509,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng a, p=p, size=size, replace=False, rng=rng
), ),
[ [
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), (pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value( (pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
], ],
(pt.as_tensor([2, 1])), (pt.as_tensor([2, 1])),
), ),
...@@ -529,17 +519,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -529,17 +519,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
) )
def test_aligned_RandomVariable(rv_op, dist_args, size): def test_aligned_RandomVariable(rv_op, dist_args, size):
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
dist_args, test_dist_args = zip(*dist_args, strict=True)
rng = shared(np.random.default_rng(29402)) rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=size, rng=rng) g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, dist_args,
[ [g],
i.tag.test_value test_dist_args,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
eval_obj_mode=False, # No python impl eval_obj_mode=False, # No python impl
) )
...@@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
( (
ptr.cauchy, ptr.cauchy,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
( (
ptr.gumbel, ptr.gumbel,
[ [
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( (
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
...@@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
) )
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers.""" """Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
dist_args, test_dist_args = zip(*dist_args, strict=True)
rng = shared(np.random.default_rng(29402)) rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=(2000, *base_size), rng=rng) g = rv_op(*dist_args, size=(2000, *base_size), rng=rng)
g_fn = function(dist_args, g, mode=numba_mode) g_fn = function(dist_args, g, mode=numba_mode)
samples = g_fn( samples = g_fn(*test_dist_args)
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
if not isinstance(i, SharedVariable | Constant)
]
)
bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args]) bcast_dist_args = np.broadcast_arrays(*test_dist_args)
for idx in np.ndindex(*base_size): for idx in np.ndindex(*base_size):
cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args)) cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args))
...@@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
"a, size, cm", "a, size, cm",
[ [
pytest.param( pytest.param(
set_test_value( (
pt.dvector(), pt.dvector(),
np.array([100000, 1, 1], dtype=np.float64), np.array([100000, 1, 1], dtype=np.float64),
), ),
...@@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
np.array( np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
...@@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
set_test_value( (
pt.dmatrix(), pt.dmatrix(),
np.array( np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
...@@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
], ],
) )
def test_DirichletRV(a, size, cm): def test_DirichletRV(a, size, cm):
a, a_val = a
rng = shared(np.random.default_rng(29402)) rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng) g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode) g_fn = function([a], g, mode=numba_mode)
with cm: with cm:
a_val = a.tag.test_value
all_samples = [] all_samples = []
for i in range(1000): for i in range(1000):
samples = g_fn(a_val) samples = g_fn(a_val)
......
...@@ -5,13 +5,10 @@ import pytensor.scalar as ps ...@@ -5,13 +5,10 @@ import pytensor.scalar as ps
import pytensor.scalar.basic as psb import pytensor.scalar.basic as psb
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849) ...@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
"x, y", "x, y",
[ [
( (
set_test_value(pt.lvector(), np.arange(4, dtype="int64")), (pt.lvector(), np.arange(4, dtype="int64")),
set_test_value(pt.dvector(), np.arange(4, dtype="float64")), (pt.dvector(), np.arange(4, dtype="float64")),
), ),
( (
set_test_value(pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), (pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))),
set_test_value(pt.lscalar(), np.array(4, dtype="int64")), (pt.lscalar(), np.array(4, dtype="int64")),
), ),
], ],
) )
def test_Second(x, y): def test_Second(x, y):
x, x_test = x
y, y_test = y
# We use the `Elemwise`-wrapped version of `Second` # We use the `Elemwise`-wrapped version of `Second`
g = pt.second(x, y) g = pt.second(x, y)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [x, y],
[ g,
i.tag.test_value [x_test, y_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, min, max", "v, min, max",
[ [
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), ((pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), ((pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), ((pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0),
], ],
) )
def test_Clip(v, min, max): def test_Clip(v, min, max):
v, v_test = v
g = ps.clip(v, min, max) g = ps.clip(v, min, max)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -100,46 +92,39 @@ def test_Clip(v, min, max): ...@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
def test_Composite(inputs, input_values, scalar_fn): def test_Composite(inputs, input_values, scalar_fn):
composite_inputs = [ps.ScalarType(config.floatX)(name=i.name) for i in inputs] composite_inputs = [ps.ScalarType(config.floatX)(name=i.name) for i in inputs]
comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) compare_numba_and_py(inputs, [comp_op(*inputs)], input_values)
compare_numba_and_py(out_fg, input_values)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, dtype", "v, dtype",
[ [
(set_test_value(pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64), ((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
(set_test_value(pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32), ((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
], ],
) )
def test_Cast(v, dtype): def test_Cast(v, dtype):
v, v_test = v
g = psb.Cast(dtype)(v) g = psb.Cast(dtype)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"v, dtype", "v, dtype",
[ [
(set_test_value(pt.iscalar(), np.array(10, dtype="int32")), psb.float64), ((pt.iscalar(), np.array(10, dtype="int32")), psb.float64),
], ],
) )
def test_reciprocal(v, dtype): def test_reciprocal(v, dtype):
v, v_test = v
g = psb.reciprocal(v) g = psb.reciprocal(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ [g],
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -156,6 +141,7 @@ def test_isnan(composite): ...@@ -156,6 +141,7 @@ def test_isnan(composite):
out = pt.isnan(x) out = pt.isnan(x)
compare_numba_and_py( compare_numba_and_py(
([x], [out]), [x],
[out],
[np.array([1, 0], dtype="float64")], [np.array([1, 0], dtype="float64")],
) )
...@@ -5,7 +5,6 @@ import pytensor ...@@ -5,7 +5,6 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function, grad from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode from pytensor.compile.mode import Mode, get_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Log1p from pytensor.scalar import Log1p
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -147,7 +146,7 @@ def test_xit_xot_types( ...@@ -147,7 +146,7 @@ def test_xit_xot_types(
if output_vals is None: if output_vals is None:
compare_numba_and_py( compare_numba_and_py(
(sequences + non_sequences, res), input_vals, updates=updates sequences + non_sequences, res, input_vals, updates=updates
) )
else: else:
numba_mode = get_mode("NUMBA") numba_mode = get_mode("NUMBA")
...@@ -217,10 +216,7 @@ def test_scan_multiple_output(benchmark): ...@@ -217,10 +216,7 @@ def test_scan_multiple_output(benchmark):
logp_c_all.name = "C_t_logp" logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp" logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph( out = [st, et, it, logp_c_all, logp_d_all]
[pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25 s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX) logp_c0 = np.array(0.0, dtype=config.floatX)
...@@ -243,21 +239,21 @@ def test_scan_multiple_output(benchmark): ...@@ -243,21 +239,21 @@ def test_scan_multiple_output(benchmark):
gamma_val, gamma_val,
delta_val, delta_val,
] ]
scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals) scan_fn, _ = compare_numba_and_py(
[pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
out,
test_input_vals,
)
benchmark(scan_fn, *test_input_vals) benchmark(scan_fn, *test_input_vals)
@config.change_flags(compute_test_value="raise")
def test_scan_tap_output(): def test_scan_tap_output():
a_pt = pt.scalar("a") a_pt = pt.scalar("a")
a_pt.tag.test_value = 10.0
b_pt = pt.arange(11).astype(config.floatX) b_pt = pt.vector("b")
b_pt.name = "b"
c_pt = pt.arange(20, 31, dtype=config.floatX) c_pt = pt.vector("c")
c_pt.name = "c"
def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
x_tm1.name = "x_tm1" x_tm1.name = "x_tm1"
...@@ -301,14 +297,12 @@ def test_scan_tap_output(): ...@@ -301,14 +297,12 @@ def test_scan_tap_output():
strict=True, strict=True,
) )
out_fg = FunctionGraph([a_pt, b_pt, c_pt], scan_res)
test_input_vals = [ test_input_vals = [
np.array(10.0).astype(config.floatX), np.array(10.0).astype(config.floatX),
np.arange(11, dtype=config.floatX), np.arange(11, dtype=config.floatX),
np.arange(20, 31, dtype=config.floatX), np.arange(20, 31, dtype=config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py([a_pt, b_pt, c_pt], scan_res, test_input_vals)
def test_scan_while(): def test_scan_while():
...@@ -323,12 +317,10 @@ def test_scan_while(): ...@@ -323,12 +317,10 @@ def test_scan_while():
n_steps=1024, n_steps=1024,
) )
out_fg = FunctionGraph([max_value], [values])
test_input_vals = [ test_input_vals = [
np.array(45).astype(config.floatX), np.array(45).astype(config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py([max_value], [values], test_input_vals)
def test_scan_multiple_none_output(): def test_scan_multiple_none_output():
...@@ -343,11 +335,8 @@ def test_scan_multiple_none_output(): ...@@ -343,11 +335,8 @@ def test_scan_multiple_none_output():
outputs_info=[pt.ones_like(A), None, None], outputs_info=[pt.ones_like(A), None, None],
n_steps=3, n_steps=3,
) )
out_fg = FunctionGraph([A], result)
test_input_vals = (np.array([1.0, 2.0]),) test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py([A], result, test_input_vals)
compare_numba_and_py(out_fg, test_input_vals)
@pytest.mark.parametrize("n_steps_val", [1, 5]) @pytest.mark.parametrize("n_steps_val", [1, 5])
...@@ -372,11 +361,14 @@ def test_scan_save_mem_basic(n_steps_val): ...@@ -372,11 +361,14 @@ def test_scan_save_mem_basic(n_steps_val):
numba_mode = get_mode("NUMBA").including("scan_save_mem") numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output])
test_input_vals = (state_val, n_steps_val) test_input_vals = (state_val, n_steps_val)
compare_numba_and_py( compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode [init_x, n_steps],
[output],
test_input_vals,
numba_mode=numba_mode,
py_mode=py_mode,
) )
...@@ -410,14 +402,12 @@ def test_mitmots_basic(): ...@@ -410,14 +402,12 @@ def test_mitmots_basic():
numba_mode = get_mode("NUMBA").including("scan_save_mem") numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem") py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([seq, init_x], g_outs)
seq_val = np.arange(3) seq_val = np.arange(3)
init_x_val = np.r_[-2, -1] init_x_val = np.r_[-2, -1]
test_input_vals = (seq_val, init_x_val) test_input_vals = (seq_val, init_x_val)
compare_numba_and_py( compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode [seq, init_x], g_outs, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
) )
......
...@@ -9,14 +9,14 @@ from scipy import linalg as scipy_linalg ...@@ -9,14 +9,14 @@ from scipy import linalg as scipy_linalg
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph from pytensor import config
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
floatX = pytensor.config.floatX floatX = config.floatX
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -88,7 +88,12 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl ...@@ -88,7 +88,12 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL) np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()]) compiled_fgraph = f.maker.fgraph
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_func(A_val.copy()), b_val.copy()],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans): ...@@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans):
cov_ = cov cov_ = cov
chol = pt.linalg.cholesky(cov_, lower=lower) chol = pt.linalg.cholesky(cov_, lower=lower)
fg = FunctionGraph(outputs=[chol])
x = np.array([0.1, 0.2, 0.3]).astype(floatX) x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None] val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
compare_numba_and_py(fg, [val]) compare_numba_and_py([cov], [chol], [val])
def test_numba_Cholesky_raises_on_nan_input(): def test_numba_Cholesky_raises_on_nan_input():
...@@ -218,8 +221,7 @@ def test_block_diag(): ...@@ -218,8 +221,7 @@ def test_block_diag():
B_val = np.random.normal(size=(3, 3)).astype(floatX) B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX) C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX) D_val = np.random.normal(size=(4, 4)).astype(floatX)
out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X]) compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
def test_lamch(): def test_lamch():
...@@ -390,7 +392,7 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): ...@@ -390,7 +392,7 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
) )
op = f.maker.fgraph.outputs[0].owner.op op = f.maker.fgraph.outputs[0].owner.op
compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True) compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy = A_val.copy() A_val_copy = A_val.copy()
......
...@@ -100,4 +100,4 @@ def test_sparse_objmode(): ...@@ -100,4 +100,4 @@ def test_sparse_objmode():
UserWarning, UserWarning,
match="Numba will use object mode to run SparseDot's perform method", match="Numba will use object mode to run SparseDot's perform method",
): ):
compare_numba_and_py(((x, y), (out,)), [x_val, y_val]) compare_numba_and_py([x, y], out, [x_val, y_val])
...@@ -4,7 +4,6 @@ import numpy as np ...@@ -4,7 +4,6 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -44,8 +43,7 @@ def test_Subtensor(x, indices): ...@@ -44,8 +43,7 @@ def test_Subtensor(x, indices):
"""Test NumPy's basic indexing.""" """Test NumPy's basic indexing."""
out_pt = x[indices] out_pt = x[indices]
assert isinstance(out_pt.owner.op, Subtensor) assert isinstance(out_pt.owner.op, Subtensor)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices): ...@@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension.""" """Test NumPy's advanced indexing in one dimension."""
out_pt = advanced_subtensor1(x, *indices) out_pt = advanced_subtensor1(x, *indices)
assert isinstance(out_pt.owner.op, AdvancedSubtensor1) assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
def test_AdvancedSubtensor1_out_of_bounds(): def test_AdvancedSubtensor1_out_of_bounds():
out_pt = advanced_subtensor1(np.arange(3), [4]) out_pt = advanced_subtensor1(np.arange(3), [4])
assert isinstance(out_pt.owner.op, AdvancedSubtensor1) assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
with pytest.raises(IndexError): with pytest.raises(IndexError):
compare_numba_and_py(out_fg, []) compare_numba_and_py([], [out_pt], [])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): ...@@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
x_pt = x.type() x_pt = x.type()
out_pt = x_pt[indices] out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor) assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with ( with (
pytest.warns( pytest.warns(
UserWarning, UserWarning,
...@@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): ...@@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
compare_numba_and_py( compare_numba_and_py(
out_fg, [x_pt],
[out_pt],
[x.data], [x.data],
numba_mode=numba_mode.including("specialize"), numba_mode=numba_mode.including("specialize"),
) )
...@@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed): ...@@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
def test_IncSubtensor(x, y, indices): def test_IncSubtensor(x, y, indices):
out_pt = set_subtensor(x[indices], y) out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor) assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
out_pt = inc_subtensor(x[indices], y) out_pt = inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor) assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
x_pt = x.type() x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y, inplace=True) out_pt = set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, IncSubtensor) assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_numba_and_py([x_pt], [out_pt], [x.data])
compare_numba_and_py(out_fg, [x.data])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices): ...@@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices):
def test_AdvancedIncSubtensor1(x, y, indices): def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = advanced_set_subtensor1(x, y, *indices) out_pt = advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
out_pt = advanced_inc_subtensor1(x, y, *indices) out_pt = advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt]) compare_numba_and_py([], [out_pt], [])
compare_numba_and_py(out_fg, [])
# With symbolic inputs # With symbolic inputs
x_pt = x.type() x_pt = x.type()
...@@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices) out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
compare_numba_and_py(out_fg, [x.data, y.data])
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)( out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices x_pt, y_pt, *indices
) )
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1) assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt]) compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
compare_numba_and_py(out_fg, [x.data, y.data])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor( ...@@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode if set_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
...@@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor( ...@@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode if inc_requires_objmode
else contextlib.nullcontext() else contextlib.nullcontext()
): ):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode) fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
if inplace: if inplace:
# Test updates inplace # Test updates inplace
x_orig = x.copy() x_orig = x.copy()
......
...@@ -6,15 +6,11 @@ import pytensor.tensor as pt ...@@ -6,15 +6,11 @@ import pytensor.tensor as pt
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Add from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
compare_shape_dtype, compare_shape_dtype,
set_test_value,
) )
from tests.tensor.test_basic import check_alloc_runtime_broadcast from tests.tensor.test_basic import check_alloc_runtime_broadcast
...@@ -31,21 +27,18 @@ rng = np.random.default_rng(42849) ...@@ -31,21 +27,18 @@ rng = np.random.default_rng(42849)
[ [
(0.0, (2, 3)), (0.0, (2, 3)),
(1.1, (2, 3)), (1.1, (2, 3)),
(set_test_value(pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)), ((pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)),
(set_test_value(pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)), ((pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)),
], ],
) )
def test_Alloc(v, shape): def test_Alloc(v, shape):
v, v_test = v if isinstance(v, tuple) else (v, None)
g = pt.alloc(v, *shape) g = pt.alloc(v, *shape)
g_fg = FunctionGraph(outputs=[g])
_, (numba_res,) = compare_numba_and_py( _, (numba_res,) = compare_numba_and_py(
g_fg, [v] if v_test is not None else [],
[ [g],
i.tag.test_value [v_test] if v_test is not None else [],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
assert numba_res.shape == shape assert numba_res.shape == shape
...@@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast(): ...@@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast():
def test_AllocEmpty(): def test_AllocEmpty():
x = pt.empty((2, 3), dtype="float32") x = pt.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x])
# We cannot compare the values in the arrays, only the shapes and dtypes # We cannot compare the values in the arrays, only the shapes and dtypes
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) compare_numba_and_py([], x, [], assert_fn=compare_shape_dtype)
@pytest.mark.parametrize( def test_TensorFromScalar():
"v", [set_test_value(ps.float64(), np.array(1.0, dtype="float64"))] v, v_test = ps.float64(), np.array(1.0, dtype="float64")
)
def test_TensorFromScalar(v):
g = ptb.TensorFromScalar()(v) g = ptb.TensorFromScalar()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ g,
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( def test_ScalarFromTensor():
"v", v, v_test = pt.scalar(), np.array(1.0, dtype=config.floatX)
[
set_test_value(pt.scalar(), np.array(1.0, dtype=config.floatX)),
],
)
def test_ScalarFromTensor(v):
g = ptb.ScalarFromTensor()(v) g = ptb.ScalarFromTensor()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ g,
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
def test_Unbroadcast(): def test_Unbroadcast():
v = set_test_value(pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)) v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)
g = Unbroadcast(0)(v) g = Unbroadcast(0)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [v],
[ g,
i.tag.test_value [v_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -117,65 +90,52 @@ def test_Unbroadcast(): ...@@ -117,65 +90,52 @@ def test_Unbroadcast():
[ [
( (
( (
set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), (pt.scalar(), np.array(1, dtype=config.floatX)),
set_test_value(pt.scalar(), np.array(2, dtype=config.floatX)), (pt.scalar(), np.array(2, dtype=config.floatX)),
set_test_value(pt.scalar(), np.array(3, dtype=config.floatX)), (pt.scalar(), np.array(3, dtype=config.floatX)),
), ),
config.floatX, config.floatX,
), ),
( (
( (
set_test_value(pt.dscalar(), np.array(1, dtype=np.float64)), (pt.dscalar(), np.array(1, dtype=np.float64)),
set_test_value(pt.lscalar(), np.array(3, dtype=np.int32)), (pt.lscalar(), np.array(3, dtype=np.int32)),
), ),
"float64", "float64",
), ),
( (
(set_test_value(pt.iscalar(), np.array(1, dtype=np.int32)),), ((pt.iscalar(), np.array(1, dtype=np.int32)),),
"float64", "float64",
), ),
( (
(set_test_value(pt.scalar(dtype=bool), True),), ((pt.scalar(dtype=bool), True),),
bool, bool,
), ),
], ],
) )
def test_MakeVector(vals, dtype): def test_MakeVector(vals, dtype):
vals, vals_test = zip(*vals, strict=True)
g = ptb.MakeVector(dtype)(*vals) g = ptb.MakeVector(dtype)(*vals)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, vals,
[ [g],
i.tag.test_value vals_test,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
@pytest.mark.parametrize( def test_ARange():
"start, stop, step, dtype", start, start_test = pt.lscalar(), np.array(1)
[ stop, stop_tset = pt.lscalar(), np.array(10)
( step, step_test = pt.lscalar(), np.array(3)
set_test_value(pt.lscalar(), np.array(1)), dtype = config.floatX
set_test_value(pt.lscalar(), np.array(10)),
set_test_value(pt.lscalar(), np.array(3)),
config.floatX,
),
],
)
def test_ARange(start, stop, step, dtype):
g = ptb.ARange(dtype)(start, stop, step) g = ptb.ARange(dtype)(start, stop, step)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [start, stop, step],
[ g,
i.tag.test_value [start_test, stop_tset, step_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype): ...@@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype):
[ [
( (
( (
set_test_value( (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
),
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
), ),
0, 0,
), ),
( (
( (
set_test_value( (pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)),
pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) (pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX)),
),
set_test_value(
pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX)
),
), ),
0, 0,
), ),
( (
( (
set_test_value( (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) (pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
),
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
), ),
1, 1,
), ),
( (
( (
set_test_value( (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX) (pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)),
),
set_test_value(
pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)
),
), ),
1, 1,
), ),
], ],
) )
def test_Join(vals, axis): def test_Join(vals, axis):
vals, vals_test = zip(*vals, strict=True)
g = pt.join(axis, *vals) g = pt.join(axis, *vals)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, vals,
[ g,
i.tag.test_value vals_test,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
def test_Join_view(): def test_Join_view():
vals = ( vals, vals_test = zip(
set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), *(
set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), (pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
),
strict=True,
) )
g = ptb.Join(view=1)(1, *vals) g = ptb.Join(view=1)(1, *vals)
g_fg = FunctionGraph(outputs=[g])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
compare_numba_and_py( compare_numba_and_py(
g_fg, vals,
[ g,
i.tag.test_value vals_test,
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -267,57 +207,47 @@ def test_Join_view(): ...@@ -267,57 +207,47 @@ def test_Join_view():
( (
0, 0,
0, 0,
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)), (pt.vector(), rng.normal(size=20).astype(config.floatX)),
set_test_value(pt.vector(dtype="int64"), []), (pt.vector(dtype="int64"), []),
), ),
( (
5, 5,
0, 0,
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)), (pt.vector(), rng.normal(size=5).astype(config.floatX)),
set_test_value( (pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)),
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
),
), ),
( (
5, 5,
0, 0,
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)), (pt.vector(), rng.normal(size=10).astype(config.floatX)),
set_test_value( (pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)),
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
),
), ),
( (
5, 5,
-1, -1,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), (pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value( (pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)),
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
),
), ),
( (
5, 5,
-2, -2,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), (pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value( (pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)),
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
),
), ),
], ],
) )
def test_Split(n_splits, axis, values, sizes): def test_Split(n_splits, axis, values, sizes):
values, values_test = values
sizes, sizes_test = sizes
g = pt.split(values, sizes, n_splits, axis=axis) g = pt.split(values, sizes, n_splits, axis=axis)
assert len(g) == n_splits assert len(g) == n_splits
if n_splits == 0: if n_splits == 0:
return return
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)
compare_numba_and_py( compare_numba_and_py(
g_fg, [values, sizes],
[ g,
i.tag.test_value [values_test, sizes_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -349,34 +279,27 @@ def test_Split_view(): ...@@ -349,34 +279,27 @@ def test_Split_view():
"val, offset", "val, offset",
[ [
( (
set_test_value( (pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))),
pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
0, 0,
), ),
( (
set_test_value( (pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))),
pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
-1, -1,
), ),
( (
set_test_value(pt.vector(), np.arange(10, dtype=config.floatX)), (pt.vector(), np.arange(10, dtype=config.floatX)),
0, 0,
), ),
], ],
) )
def test_ExtractDiag(val, offset): def test_ExtractDiag(val, offset):
val, val_test = val
g = pt.diag(val, offset) g = pt.diag(val, offset)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [val],
[ g,
i.tag.test_value [val_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -407,30 +330,28 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis): ...@@ -407,30 +330,28 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"n, m, k, dtype", "n, m, k, dtype",
[ [
(set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None), ((pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None),
( (
set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), (pt.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)), (pt.lscalar(), np.array(2, dtype=np.int64)),
0, 0,
"float32", "float32",
), ),
( (
set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), (pt.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)), (pt.lscalar(), np.array(2, dtype=np.int64)),
1, 1,
"int64", "int64",
), ),
], ],
) )
def test_Eye(n, m, k, dtype): def test_Eye(n, m, k, dtype):
n, n_test = n
m, m_test = m if m is not None else (None, None)
g = pt.eye(n, m, k, dtype=dtype) g = pt.eye(n, m, k, dtype=dtype)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
g_fg, [n, m] if m is not None else [n],
[ g,
i.tag.test_value [n_test, m_test] if m is not None else [n_test],
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
) )
...@@ -9,10 +9,10 @@ import pytensor.tensor.basic as ptb ...@@ -9,10 +9,10 @@ import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import PYTORCH, Mode from pytensor.compile.mode import PYTORCH, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
...@@ -39,10 +39,10 @@ py_mode = Mode(linker="py", optimizer=None) ...@@ -39,10 +39,10 @@ py_mode = Mode(linker="py", optimizer=None)
def compare_pytorch_and_py( def compare_pytorch_and_py(
fgraph: FunctionGraph, graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable, test_inputs: Iterable,
assert_fn: Callable | None = None, assert_fn: Callable | None = None,
must_be_device_array: bool = True,
pytorch_mode=pytorch_mode, pytorch_mode=pytorch_mode,
py_mode=py_mode, py_mode=py_mode,
): ):
...@@ -50,8 +50,10 @@ def compare_pytorch_and_py( ...@@ -50,8 +50,10 @@ def compare_pytorch_and_py(
Parameters Parameters
---------- ----------
fgraph: FunctionGraph graph_inputs
PyTensor function Graph object Symbolic inputs to the graph
graph_outputs:
Symbolic outputs of the graph
test_inputs: iter test_inputs: iter
Numerical inputs for testing the function graph Numerical inputs for testing the function graph
assert_fn: func, opt assert_fn: func, opt
...@@ -63,24 +65,22 @@ def compare_pytorch_and_py( ...@@ -63,24 +65,22 @@ def compare_pytorch_and_py(
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose) assert_fn = partial(np.testing.assert_allclose)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] if any(inp.owner is not None for inp in graph_inputs):
raise ValueError("Inputs must be root variables")
pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode) pytensor_torch_fn = function(graph_inputs, graph_outputs, mode=pytorch_mode)
pytorch_res = pytensor_torch_fn(*test_inputs) pytorch_res = pytensor_torch_fn(*test_inputs)
if isinstance(pytorch_res, list): pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
assert all(isinstance(res, np.ndarray) for res in pytorch_res)
else:
assert isinstance(pytorch_res, np.ndarray)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs) py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1: if isinstance(graph_outputs, list | tuple):
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True): for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
assert not isinstance(pytorch_res_i, torch.Tensor)
assert_fn(pytorch_res_i, py_res_i) assert_fn(pytorch_res_i, py_res_i)
else: else:
assert_fn(pytorch_res[0], py_res[0]) assert not isinstance(pytorch_res, torch.Tensor)
assert_fn(pytorch_res, py_res)
return pytensor_torch_fn, pytorch_res return pytensor_torch_fn, pytorch_res
...@@ -231,7 +231,8 @@ def test_alloc_and_empty(): ...@@ -231,7 +231,8 @@ def test_alloc_and_empty():
v = vector("v", shape=(3,), dtype="float64") v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, dim0, dim1, 3) out = alloc(v, dim0, dim1, 3)
compare_pytorch_and_py( compare_pytorch_and_py(
FunctionGraph([v, dim1], [out]), [v, dim1],
[out],
[np.array([1, 2, 3]), np.array(7)], [np.array([1, 2, 3]), np.array(7)],
) )
...@@ -244,7 +245,8 @@ def test_arange(): ...@@ -244,7 +245,8 @@ def test_arange():
out = arange(start, stop, step, dtype="int16") out = arange(start, stop, step, dtype="int16")
compare_pytorch_and_py( compare_pytorch_and_py(
FunctionGraph([start, stop, step], [out]), [start, stop, step],
[out],
[np.array(1), np.array(10), np.array(2)], [np.array(1), np.array(10), np.array(2)],
) )
...@@ -254,16 +256,18 @@ def test_pytorch_Join(): ...@@ -254,16 +256,18 @@ def test_pytorch_Join():
b = matrix("b") b = matrix("b")
x = ptb.join(0, a, b) x = ptb.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_pytorch_and_py( compare_pytorch_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_pytorch_and_py( compare_pytorch_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX), np.c_[[4.0, 5.0]].astype(config.floatX),
...@@ -271,16 +275,18 @@ def test_pytorch_Join(): ...@@ -271,16 +275,18 @@ def test_pytorch_Join():
) )
x = ptb.join(1, a, b) x = ptb.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_pytorch_and_py( compare_pytorch_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
], ],
) )
compare_pytorch_and_py( compare_pytorch_and_py(
x_fg, [a, b],
[x],
[ [
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX), np.c_[[5.0, 6.0]].astype(config.floatX),
...@@ -309,9 +315,8 @@ def test_eye(dtype): ...@@ -309,9 +315,8 @@ def test_eye(dtype):
def test_pytorch_MakeVector(): def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3) x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, []) compare_pytorch_and_py([], [x], [])
def test_pytorch_ifelse(): def test_pytorch_ifelse():
...@@ -320,15 +325,13 @@ def test_pytorch_ifelse(): ...@@ -320,15 +325,13 @@ def test_pytorch_ifelse():
a = scalar("a") a = scalar("a")
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX)) compare_pytorch_and_py([a], x, np.array([0.2], dtype=config.floatX))
a = scalar("a") a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals])) x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX)) compare_pytorch_and_py([a], x, np.array([0.5], dtype=config.floatX))
def test_pytorch_OpFromGraph(): def test_pytorch_OpFromGraph():
...@@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph(): ...@@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph():
yv = np.ones((2, 2), dtype=config.floatX) * 3 yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5 zv = np.ones((2, 2), dtype=config.floatX) * 5
f = FunctionGraph([x, y, z], [out]) compare_pytorch_and_py([x, y, z], [out], [xv, yv, zv])
compare_pytorch_and_py(f, [xv, yv, zv])
def test_pytorch_link_references(): def test_pytorch_link_references():
...@@ -380,15 +382,13 @@ def test_pytorch_link_references(): ...@@ -380,15 +382,13 @@ def test_pytorch_link_references():
def test_pytorch_scipy(): def test_pytorch_scipy():
x = vector("a", shape=(3,)) x = vector("a", shape=(3,))
out = expit(x) out = expit(x)
f = FunctionGraph([x], [out]) compare_pytorch_and_py([x], [out], [np.random.rand(3)])
compare_pytorch_and_py(f, [np.random.rand(3)])
def test_pytorch_softplus(): def test_pytorch_softplus():
x = vector("a", shape=(3,)) x = vector("a", shape=(3,))
out = softplus(x) out = softplus(x)
f = FunctionGraph([x], [out]) compare_pytorch_and_py([x], [out], [np.random.rand(3)])
compare_pytorch_and_py(f, [np.random.rand(3)])
def test_ScalarLoop(): def test_ScalarLoop():
...@@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries(): ...@@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries():
x0 = pt.vector("x0", dtype="float32") x0 = pt.vector("x0", dtype="float32")
state, done = op(n_steps, x0) state, done = op(n_steps, x0)
f = FunctionGraph([n_steps, x0], [state, done])
args = [ args = [
np.array(10).astype("int32"), np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"), np.arange(0, 5).astype("float32"),
] ]
compare_pytorch_and_py( compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) [n_steps, x0],
[state, done],
args,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-6),
) )
...@@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries(): ...@@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries():
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1)) x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
*states, done = op(n_steps, x0, x1) *states, done = op(n_steps, x0, x1)
f = FunctionGraph([n_steps, x0, x1], [*states, done])
args = [ args = [
np.array(10).astype("int32"), np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"), np.arange(0, 5).astype("float32"),
np.random.rand(7, 3, 1).astype("float32"), np.random.rand(7, 3, 1).astype("float32"),
] ]
compare_pytorch_and_py( compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) [n_steps, x0, x1],
[*states, done],
args,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-6),
) )
...@@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes): ...@@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes):
assert len(g) == n_splits assert len(g) == n_splits
if n_splits == 0: if n_splits == 0:
return return
g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
compare_pytorch_and_py(g_fg, [values, sizes]) compare_pytorch_and_py([i, s], g, [values, sizes])
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import blas as pt_blas from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3 from pytensor.tensor.type import tensor3
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -15,8 +14,8 @@ def test_pytorch_BatchedDot(): ...@@ -15,8 +14,8 @@ def test_pytorch_BatchedDot():
b = tensor3("b") b = tensor3("b")
b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
out = pt_blas.BatchedDot()(a, b) out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test]) pytensor_pytorch_fn, _ = compare_pytorch_and_py([a, b], [out], [a_test, b_test])
# A dimension mismatch should raise a TypeError for compatibility # A dimension mismatch should raise a TypeError for compatibility
inputs = [a_test[:-1], b_test] inputs = [a_test[:-1], b_test]
......
...@@ -5,7 +5,6 @@ import pytensor ...@@ -5,7 +5,6 @@ import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
import pytensor.tensor.math as ptm import pytensor.tensor.math as ptm
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarOp, get_scalar_type from pytensor.scalar.basic import ScalarOp, get_scalar_type
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
...@@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle(): ...@@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle():
a_pt = matrix("a") a_pt = matrix("a")
x = a_pt.T x = a_pt.T
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
x = a_pt.dimshuffle([0, 1, "x"]) x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
a_pt = tensor(dtype=config.floatX, shape=(None, 1)) a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,)) x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]
)
def test_multiple_input_output(): def test_multiple_input_output():
...@@ -38,24 +43,21 @@ def test_multiple_input_output(): ...@@ -38,24 +43,21 @@ def test_multiple_input_output():
y = vector("y") y = vector("y")
out = pt.mul(x, y) out = pt.mul(x, y)
fg = FunctionGraph(outputs=[out], clone=False) compare_pytorch_and_py([x, y], [out], [[1.5], [2.5]])
compare_pytorch_and_py(fg, [[1.5], [2.5]])
x = vector("x") x = vector("x")
y = vector("y") y = vector("y")
div = pt.int_div(x, y) div = pt.int_div(x, y)
pt_sum = pt.add(y, x) pt_sum = pt.add(y, x)
fg = FunctionGraph(outputs=[div, pt_sum], clone=False) compare_pytorch_and_py([x, y], [div, pt_sum], [[1.5], [2.5]])
compare_pytorch_and_py(fg, [[1.5], [2.5]])
def test_pytorch_elemwise(): def test_pytorch_elemwise():
x = pt.vector("x") x = pt.vector("x")
out = pt.log(1 - x) out = pt.log(1 - x)
fg = FunctionGraph([x], [out]) compare_pytorch_and_py([x], [out], [[0.9, 0.9]])
compare_pytorch_and_py(fg, [[0.9, 0.9]])
@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min]) @pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
...@@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis): ...@@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis):
).astype(config.floatX) ).astype(config.floatX)
x = fn(a_pt, axis=axis) x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [test_value]) compare_pytorch_and_py([a_pt], [x], [test_value])
@pytest.mark.parametrize("fn", [ptm.any, ptm.all]) @pytest.mark.parametrize("fn", [ptm.any, ptm.all])
...@@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis): ...@@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis):
test_value = np.array([[True, False, True], [False, True, True]]) test_value = np.array([[True, False, True], [False, True, True]])
x = fn(a_pt, axis=axis) x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [test_value]) compare_pytorch_and_py([a_pt], [x], [test_value])
@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("dtype", ["float64", "int64"])
...@@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis): ...@@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis):
def test_softmax(axis, dtype): def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype) x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis) out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64": if dtype == "int64":
...@@ -111,9 +110,9 @@ def test_softmax(axis, dtype): ...@@ -111,9 +110,9 @@ def test_softmax(axis, dtype):
NotImplementedError, NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.", match="Pytorch Softmax is not currently implemented for non-float types.",
): ):
compare_pytorch_and_py(fgraph, [test_input]) compare_pytorch_and_py([x], [out], [test_input])
else: else:
compare_pytorch_and_py(fgraph, [test_input]) compare_pytorch_and_py([x], [out], [test_input])
@pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("dtype", ["float64", "int64"])
...@@ -121,7 +120,6 @@ def test_softmax(axis, dtype): ...@@ -121,7 +120,6 @@ def test_softmax(axis, dtype):
def test_logsoftmax(axis, dtype): def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype) x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis) out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3) test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64": if dtype == "int64":
...@@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype): ...@@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype):
NotImplementedError, NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.", match="Pytorch LogSoftmax is not currently implemented for non-float types.",
): ):
compare_pytorch_and_py(fgraph, [test_input]) compare_pytorch_and_py([x], [out], [test_input])
else: else:
compare_pytorch_and_py(fgraph, [test_input]) compare_pytorch_and_py([x], [out], [test_input])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
...@@ -141,16 +139,14 @@ def test_softmax_grad(axis): ...@@ -141,16 +139,14 @@ def test_softmax_grad(axis):
sm = matrix("sm") sm = matrix("sm")
sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3) sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm) out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out]) compare_pytorch_and_py([dy, sm], [out], [dy_value, sm_value])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])
def test_cast(): def test_cast():
x = matrix("x", dtype="float32") x = matrix("x", dtype="float32")
out = pt.cast(x, "int32") out = pt.cast(x, "int32")
fgraph = FunctionGraph([x], [out])
_, [res] = compare_pytorch_and_py( _, [res] = compare_pytorch_and_py(
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)] [x], [out], [np.arange(6, dtype="float32").reshape(2, 3)]
) )
assert res.dtype == np.int32 assert res.dtype == np.int32
......
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype): ...@@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype):
out = pt.cumprod(a, axis=axis) out = pt.cumprod(a, axis=axis)
else: else:
out = pt.cumsum(a, axis=axis) out = pt.cumsum(a, axis=axis)
# Create a PyTensor `FunctionGraph`
fgraph = FunctionGraph([a], [out])
# Pass the graph and inputs to the testing function # Pass the inputs and outputs to the testing function
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py([a], [out], [test_value])
# For the second mode of CumOp # For the second mode of CumOp
out = pt.cumprod(a, axis=axis) out = pt.cumprod(a, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)]) @pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
...@@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats): ...@@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats):
test_value = np.arange(6, dtype="float64").reshape((3, 2)) test_value = np.arange(6, dtype="float64").reshape((3, 2))
out = pt.repeat(a, repeats, axis=axis) out = pt.repeat(a, repeats, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("axis", [None, 0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1])
...@@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis): ...@@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis):
) )
out = pt.unique(a, axis=axis) out = pt.unique(a, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("return_inverse", [False, True]) @pytest.mark.parametrize("return_inverse", [False, True])
...@@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts): ...@@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
return_counts=return_counts, return_counts=return_counts,
axis=0, axis=0,
) )
fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
compare_pytorch_and_py(fgraph, [test_value]) compare_pytorch_and_py(
[a], [out[0] if isinstance(out, list) else out], [test_value]
)
import numpy as np import numpy as np
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix, scalar, vector from pytensor.tensor.type import matrix, scalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -20,10 +19,12 @@ def test_pytorch_dot(): ...@@ -20,10 +19,12 @@ def test_pytorch_dot():
# 2D * 2D # 2D * 2D
out = A.dot(A * alpha) + beta * A out = A.dot(A * alpha) + beta * A
fgraph = FunctionGraph([A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test]) compare_pytorch_and_py([A, alpha, beta], [out], [A_test, alpha_test, beta_test])
# 1D * 2D and 1D * 1D # 1D * 2D and 1D * 1D
out = y.dot(alpha * A).dot(x) + beta * y out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test]) compare_pytorch_and_py(
[y, x, A, alpha, beta], [out], [y_test, x_test, A_test, alpha_test, beta_test]
)
from collections.abc import Sequence
import numpy as np import numpy as np
import pytest import pytest
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nla from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -29,13 +26,12 @@ def matrix_test(): ...@@ -29,13 +26,12 @@ def matrix_test():
def test_lin_alg_no_params(func, matrix_test): def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test x, test_value = matrix_test
out = func(x) outs = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])
def assert_fn(x, y): def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3) np.testing.assert_allclose(x, y, rtol=1e-3)
compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn) compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -50,8 +46,8 @@ def test_lin_alg_no_params(func, matrix_test): ...@@ -50,8 +46,8 @@ def test_lin_alg_no_params(func, matrix_test):
def test_qr(mode, matrix_test): def test_qr(mode, matrix_test):
x, test_value = matrix_test x, test_value = matrix_test
outs = pt_nla.qr(x, mode=mode) outs = pt_nla.qr(x, mode=mode)
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
compare_pytorch_and_py(out_fg, [test_value]) compare_pytorch_and_py([x], outs, [test_value])
@pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("compute_uv", [True, False])
...@@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test): ...@@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test):
x, test_value = matrix_test x, test_value = matrix_test
out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
compare_pytorch_and_py(out_fg, [test_value]) compare_pytorch_and_py([x], out, [test_value])
def test_pinv(): def test_pinv():
x = matrix("x") x = matrix("x")
x_inv = pt_nla.pinv(x) x_inv = pt_nla.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np]) compare_pytorch_and_py([x], [x_inv], [x_np])
@pytest.mark.parametrize("hermitian", [False, True]) @pytest.mark.parametrize("hermitian", [False, True])
...@@ -106,8 +100,7 @@ def test_kron(): ...@@ -106,8 +100,7 @@ def test_kron():
y = matrix("y") y = matrix("y")
z = pt_nla.kron(x, y) z = pt_nla.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np, y_np]) compare_pytorch_and_py([x, y], [z], [x_np, y_np])
...@@ -2,7 +2,6 @@ import numpy as np ...@@ -2,7 +2,6 @@ import numpy as np
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -11,29 +10,27 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py ...@@ -11,29 +10,27 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_shape_ops(): def test_pytorch_shape_ops():
x_np = np.zeros((20, 3)) x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np)) x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False) compare_pytorch_and_py([], [x], [])
x = Shape_i(1)(pt.as_tensor_variable(x_np)) x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False) compare_pytorch_and_py([], [x], [])
def test_pytorch_specify_shape(): def test_pytorch_specify_shape():
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None)) x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x]) compare_pytorch_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes # When used to assert two arrays have similar shapes
in_pt = pt.matrix("in") in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape") shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape) x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py( compare_pytorch_and_py(
x_fg, [in_pt, shape_pt],
[x],
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
) )
...@@ -41,21 +38,22 @@ def test_pytorch_specify_shape(): ...@@ -41,21 +38,22 @@ def test_pytorch_specify_shape():
def test_pytorch_Reshape_constant(): def test_pytorch_Reshape_constant():
a = vector("a") a = vector("a")
x = reshape(a, (2, 2)) x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) compare_pytorch_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_pytorch_Reshape_dynamic(): def test_pytorch_Reshape_dynamic():
a = vector("a") a = vector("a")
shape_pt = iscalar("b") shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt)) x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) compare_pytorch_and_py(
[a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_pytorch_unbroadcast(): def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np)) x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, []) compare_pytorch_and_py([], [x], [])
import numpy as np import numpy as np
import pytest import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort from pytensor.tensor.sort import argsort, sort
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -12,6 +11,5 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py ...@@ -12,6 +11,5 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_sort(func, axis): def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64") x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis) out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_pytorch_and_py(fgraph, [arr]) compare_pytorch_and_py([x], [out], [arr])
...@@ -6,7 +6,6 @@ import pytest ...@@ -6,7 +6,6 @@ import pytest
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import inc_subtensor, set_subtensor from pytensor.tensor import inc_subtensor, set_subtensor
from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import subtensor as pt_subtensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -19,38 +18,33 @@ def test_pytorch_Subtensor(): ...@@ -19,38 +18,33 @@ def test_pytorch_Subtensor():
out_pt = x_pt[1, 2, 0] out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:, 1, :] out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[:2, 1, :] out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[1:2, 1, :] out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
# symbolic index # symbolic index
a_pt = ps.int64("a") a_pt = ps.int64("a")
a_np = 1 a_np = 1
out_pt = x_pt[a_pt, 2, a_pt:2] out_pt = x_pt[a_pt, 2, a_pt:2]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor) assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
compare_pytorch_and_py(out_fg, [x_np, a_np])
with pytest.raises( with pytest.raises(
NotImplementedError, match="Negative step sizes are not supported in Pytorch" NotImplementedError, match="Negative step sizes are not supported in Pytorch"
): ):
out_pt = x_pt[::-1] out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
def test_pytorch_AdvSubtensor(): def test_pytorch_AdvSubtensor():
...@@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor(): ...@@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor():
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2]) out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], [2, 3]] out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], 1:] out_pt = x_pt[[1, 2], 1:]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]] out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
out_pt = x_pt[[1, 2], None] out_pt = x_pt[[1, 2], None]
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
a_pt = ps.int64("a") a_pt = ps.int64("a")
a_np = 2 a_np = 2
out_pt = x_pt[[1, a_pt], a_pt] out_pt = x_pt[[1, a_pt], a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
compare_pytorch_and_py(out_fg, [x_np, a_np])
# boolean indices # boolean indices
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)] out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
compare_pytorch_and_py(out_fg, [x_np])
a_pt = pt.tensor3("a", dtype="bool") a_pt = pt.tensor3("a", dtype="bool")
a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool) a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)
out_pt = x_pt[a_pt] out_pt = x_pt[a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
compare_pytorch_and_py(out_fg, [x_np, a_np])
with pytest.raises( with pytest.raises(
NotImplementedError, match="Negative step sizes are not supported in Pytorch" NotImplementedError, match="Negative step sizes are not supported in Pytorch"
): ):
out_pt = x_pt[[1, 2], ::-1] out_pt = x_pt[[1, 2], ::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
compare_pytorch_and_py(out_fg, [x_np]) compare_pytorch_and_py([x_pt], [out_pt], [x_np])
@pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor]) @pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor])
...@@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op): ...@@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op):
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = subtensor_op(x_pt[1, 2, 3], st_pt) out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
# Test different type update # Test different type update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt) out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt) out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
def inc_subtensor_ignore_duplicates(x, y): def inc_subtensor_ignore_duplicates(x, y):
...@@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op): ...@@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
) )
out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt) out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
# Repeated indices # Repeated indices
out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt) out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
# Mixing advanced and basic indexing # Mixing advanced and basic indexing
if advsubtensor_op is inc_subtensor: if advsubtensor_op is inc_subtensor:
...@@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op): ...@@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3]) st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3])
out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt) out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with expectation: with expectation:
compare_pytorch_and_py(out_fg, [x_test]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Test different dtype update # Test different dtype update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt) out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
# Boolean indices # Boolean indices
out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0) out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt]) compare_pytorch_and_py([x_pt], [out_pt], [x_test])
compare_pytorch_and_py(out_fg, [x_test])
...@@ -63,11 +63,6 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH ...@@ -63,11 +63,6 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
from tests import unittest_tools as utt from tests import unittest_tools as utt
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_cpu_contiguous(): def test_cpu_contiguous():
a = fmatrix("a") a = fmatrix("a")
i = iscalar("i") i = iscalar("i")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论