Commit cc8c4992 authored by Adv, committed by Ricardo Vieira

Stop using FunctionGraph and tag.test_value in linker tests

Parent 51ea1a0b
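In short, tests that previously wrapped outputs in a FunctionGraph and stashed numeric values on tag.test_value now pass symbolic inputs, symbolic outputs, and test values to compare_jax_and_py directly. A hedged before/after sketch of the pattern (illustrative variables, not lifted verbatim from the diff):

import numpy as np
from pytensor.tensor.type import vector

x = vector("x")
out = x * 2

# Before: FunctionGraph plus tag.test_value
#     x.tag.test_value = np.r_[1.0, 2.0]
#     fgraph = FunctionGraph([x], [out])
#     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

# After: explicit inputs, outputs, and numeric test values
compare_jax_and_py([x], [out], [np.r_[1.0, 2.0]])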
......@@ -7,12 +7,12 @@ import pytest
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
......@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
def compare_jax_and_py(
fgraph: FunctionGraph,
graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable,
*,
assert_fn: Callable | None = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
py_mode=py_mode,
):
"""Function to compare python graph output and jax compiled output for testing equality
"""Function to compare python function output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
this function which then compiles the graphs in both jax and python, runs the calculation
in both and checks if the results are the same
The inputs and outputs are passed to this function, which compiles the given graph in both
jax and python, runs the calculation in both, and checks that the results are the same.
Parameters
----------
fgraph: FunctionGraph
PyTensor function Graph object
graph_inputs:
Symbolic inputs to the graph
graph_outputs:
Symbolic outputs of the graph
test_inputs: iter
Numerical inputs for testing the function graph
Numerical inputs for testing the function.
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided, uses np.testing.assert_allclose.
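A minimal usage sketch of this signature, assuming the illustrative names below (multiple outputs are compared pairwise, as in the list/tuple branch further down):

import numpy as np
from pytensor.configdefaults import config
from pytensor.tensor.type import matrix

a = matrix("a")
outs = [a.sum(axis=0), a.max(axis=1)]
# Compiles the graph with both the JAX and Python linkers, evaluates each
# on the test value, and checks the paired results with assert_fn.
compare_jax_and_py([a], outs, [np.ones((2, 3)).astype(config.floatX)])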
......@@ -68,8 +71,10 @@ def compare_jax_and_py(
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
pytensor_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
if any(inp.owner is not None for inp in graph_inputs):
raise ValueError("Inputs must be root variables")
pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
jax_res = pytensor_jax_fn(*test_inputs)
if must_be_device_array:
......@@ -78,10 +83,10 @@ def compare_jax_and_py(
else:
assert isinstance(jax_res, jax.Array)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1:
if isinstance(graph_outputs, list | tuple):
for j, p in zip(jax_res, py_res, strict=True):
assert_fn(j, p)
else:
......@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals = np.r_[-1, -2, -3]
x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
a = dscalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX)
a_test = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
compare_jax_and_py([a], [x], [a_test])
def test_jax_checkandraise():
......@@ -209,11 +212,6 @@ def test_jax_checkandraise():
function((p,), res, mode=jax_mode)
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
......@@ -221,10 +219,9 @@ def test_OpFromGraph():
o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
out_fg = FunctionGraph([x, y, z], [out])
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_jax_and_py(out_fg, [xv, yv, zv])
compare_jax_and_py([x, y, z], [out], [xv, yv, zv])
......@@ -4,8 +4,6 @@ import pytest
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
......@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a.tag.test_value = (
a_test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
b.tag.test_value = (
b_test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([a, b], [out], [a_test_value, b_test_value])
# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
inputs = [a_test_value[:-1], b_test_value]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
pytensor_jax_fn = function([a, b], [out], mode=jax_mode)
with pytest.raises(TypeError):
pytensor_jax_fn(*inputs)
......@@ -2,7 +2,6 @@ import numpy as np
import pytest
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
......@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out])
fn, _ = compare_jax_and_py(fg, test_values)
fn, _ = compare_jax_and_py([a, b], [out], test_values)
# Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
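The jaxpr inspection relies on jax.make_jaxpr, which traces a callable into JAX's intermediate representation so the emitted primitives can be checked as text (a minimal standalone sketch):

import jax
import jax.numpy as jnp

def f(a, b):
    return a @ b

# Prints the traced IR; a plain matmul should lower to a single dot_general
print(jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 2))))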
......
......@@ -2,7 +2,6 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -22,8 +21,7 @@ def test_jax_einsum():
}
x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items())
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
compare_jax_and_py(fg, [x, y, z])
compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z])
def test_ellipsis_einsum():
......@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
fg = FunctionGraph([x_pt, y_pt], [out])
compare_jax_and_py(fg, [x, y])
compare_jax_and_py([x_pt, y_pt], [out], [x, y])
......@@ -6,8 +6,6 @@ import pytensor
import pytensor.tensor as pt
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import prod
......@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
a_pt = matrix("a")
x = a_pt.T
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_CAReduce():
......@@ -49,64 +47,58 @@ def test_jax_CAReduce():
a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
x = pt_sum(a_pt, axis=None)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = pt_sum(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_sum(a_pt, axis=1)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
a_pt = matrix("a")
a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = prod(a_pt, axis=0)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = pt_all(a_pt)
x_fg = FunctionGraph([a_pt], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x], [out], [x_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x], [out], [x_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
......@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def test_multiple_input_multiply():
x, y, z = vectors("xyz")
out = pt.mul(x, y, z)
fg = FunctionGraph(outputs=[out], clone=False)
compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])
......@@ -3,8 +3,6 @@ import pytest
import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor
......@@ -19,57 +17,45 @@ def test_extra_ops():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
c = ptb.as_tensor(5)
out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
compare_jax_and_py([], out, [], must_be_device_array=False)
v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel())
out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
c = tensor(shape=(), dtype=int)
out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [np.array(5)])
compare_jax_and_py([], [out], [np.array(5)])
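These xfails encode a general jax.jit constraint: an output shape cannot depend on a traced runtime value. A minimal sketch of the failure mode, using jax.numpy.bartlett (which needs a concrete window length):

import jax
import jax.numpy as jnp

@jax.jit
def window(n):
    # Fails when n is a tracer: the output shape depends on its runtime value
    return jnp.bartlett(n)

# window(5) raises a concretization error; jax.jit(..., static_argnums=0)
# would instead specialize the compiled function on each concrete n.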
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
......@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [x_test, y_test])
compare_jax_and_py([], [out], [x_test, y_test])
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
......@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])
compare_jax_and_py([a], [out], [a_test])
......@@ -2,8 +2,6 @@ import numpy as np
import pytest
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, vector
......@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
compare_jax_and_py([x], [out], [np.r_[1, 2]])
def test_dot():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
y_test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
x_test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
A_test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
alpha_test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
beta_test_value = np.array(5.0, dtype=config.floatX)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(
[y, x, A, alpha, beta],
out,
[
y_test_value,
x_test_value,
A_test_value,
alpha_test_value,
beta_test_value,
],
)
out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([y, x], [out], [y_test_value, x_test_value])
out = pt_max(y)
fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([y], [out], [y_test_value])
......@@ -3,7 +3,6 @@ import pytest
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
x = matrix("x")
outs = pt_nlinalg.eig(x)
out_fg = FunctionGraph([x], outs)
def assert_fn(x, y):
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.eigh(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="full")
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.qr(x, mode="reduced")
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.svd(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = pt_nlinalg.slogdet(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
def test_pinv():
x = matrix("x")
x_inv = pt_nlinalg.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np])
compare_jax_and_py([x], [x_inv], [x_np])
def test_pinv_hermitian():
......@@ -94,8 +86,7 @@ def test_kron():
y = matrix("y")
z = pt_nlinalg.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_jax_and_py(fgraph, [x_np, y_np])
compare_jax_and_py([x, y], [z], [x_np, y_np])
......@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_jax_and_py(
res_fg,
[x_pt],
[res],
[x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN",
......
......@@ -7,13 +7,11 @@ import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import clone_replace
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.compile.sharedvalue import shared
from pytensor.tensor.random.basic import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode, set_test_value
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode
from tests.tensor.random.test_basic import (
batched_permutation_tester,
batched_unweighted_choice_without_replacement_tester,
......@@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.beta,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.cauchy,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.exponential,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
......@@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr._gamma,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
......@@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.gumbel,
[
set_test_value(
(
pt.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.laplace,
[
set_test_value(pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)),
set_test_value(pt.dscalar(), np.array(1.0, dtype=np.float64)),
(pt.dvector(), np.array([1.0, 2.0], dtype=np.float64)),
(pt.dscalar(), np.array(1.0, dtype=np.float64)),
],
(2,),
"laplace",
......@@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.logistic,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.lognormal,
[
set_test_value(
(
pt.lvector(),
np.array([0, 0], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.normal,
[
set_test_value(
(
pt.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.pareto,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([2.0, 10.0], dtype=np.float64),
),
......@@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.poisson,
[
set_test_value(
(
pt.dvector(),
np.array([100000.0, 200000.0], dtype=np.float64),
),
......@@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.integers,
[
set_test_value(
(
pt.lscalar(),
np.array(0, dtype=np.int64),
),
set_test_value( # high-value necessary since test on cdf
( # a high value is necessary since the test checks the cdf
pt.lscalar(),
np.array(1000, dtype=np.int64),
),
......@@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.t,
[
set_test_value(
(
pt.dscalar(),
np.array(2.0, dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.uniform,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1000.0, dtype=np.float64),
),
......@@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.halfnormal,
[
set_test_value(
(
pt.dvector(),
np.array([-1.0, 200.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1000.0, dtype=np.float64),
),
......@@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.invgamma,
[
set_test_value(
(
pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([3.4, 7.3], dtype=np.float64),
),
......@@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.chisquare,
[
set_test_value(
(
pt.dvector(),
np.array([2.4, 4.9], dtype=np.float64),
),
......@@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.gengamma,
[
set_test_value(
(
pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([3.4, 7.3], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([0.9, 2.0], dtype=np.float64),
),
......@@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
ptr.wald,
[
set_test_value(
(
pt.dvector(),
np.array([10.4, 2.8], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([4.5, 2.0], dtype=np.float64),
),
......@@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
pytest.param(
ptr.vonmises,
[
set_test_value(
(
pt.dvector(),
np.array([-0.5, 1.3], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([5.5, 13.0], dtype=np.float64),
),
......@@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The transpiled `RandomVariable` `Op`.
dist_params
The parameters passed to the op.
"""
dist_params, test_values = (
zip(*dist_params, strict=True) if dist_params else ([], [])
)
rng = shared(np.random.default_rng(29403))
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn(
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
if not isinstance(i, SharedVariable | Constant)
]
)
samples = g_fn(*test_values)
bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_params])
bcast_dist_args = np.broadcast_arrays(*test_values)
for idx in np.ndindex(*base_size):
cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args))
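With the set_test_value helper removed, each parametrization entry is now a plain (variable, test_value) tuple, and zip(*dist_params, strict=True) above transposes the list of pairs into two parallel tuples (a minimal sketch with stand-in values):

pairs = [("mu", 1.0), ("sigma", 2.0)]      # stand-ins for (variable, value)
params, values = zip(*pairs, strict=True)  # strict=True requires Python >= 3.10
assert params == ("mu", "sigma")
assert values == (1.0, 2.0)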
......@@ -775,13 +769,12 @@ def test_random_unimplemented():
nonexistentrv = NonExistentRV()
rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.raises(NotImplementedError):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])
compare_jax_and_py([], [out], [])
def test_random_custom_implementation():
......@@ -810,11 +803,10 @@ def test_random_custom_implementation():
nonexistentrv = CustomRV()
rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
compare_jax_and_py(fgraph, [])
compare_jax_and_py([], [out], [])
def test_random_concrete_shape():
......
......@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.scalar.basic import Composite
from pytensor.tensor import as_tensor
from pytensor.tensor.elemwise import Elemwise
......@@ -51,20 +50,19 @@ def test_second():
b = scalar("b")
out = ps.second(a0, b)
fgraph = FunctionGraph([a0, b], [out])
compare_jax_and_py(fgraph, [10.0, 5.0])
compare_jax_and_py([a0, b], [out], [10.0, 5.0])
a1 = vector("a1")
out = pt.second(a1, b)
fgraph = FunctionGraph([a1, b], [out])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
compare_jax_and_py([a1, b], [out], [np.zeros([5], dtype=config.floatX), 5.0])
a2 = matrix("a2", shape=(1, None), dtype="float64")
b2 = matrix("b2", shape=(None, 1), dtype="int32")
out = pt.second(a2, b2)
fgraph = FunctionGraph([a2, b2], [out])
compare_jax_and_py(
fgraph, [np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")]
[a2, b2],
[out],
[np.zeros((1, 3), dtype="float64"), np.ones((5, 1), dtype="int32")],
)
......@@ -81,11 +79,10 @@ def test_second_constant_scalar():
def test_identity():
a = scalar("a")
a.tag.test_value = 10
a_test_value = 10
out = ps.identity(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([a], [out], [a_test_value])
@pytest.mark.parametrize(
......@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
x_val.astype(config.floatX),
y_val.astype(config.floatX),
]
_ = compare_jax_and_py(out_fg, test_input_vals)
_ = compare_jax_and_py([x, y], [out], test_input_vals)
def test_jax_Composite_multi_output():
......@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
x_s = ps.float64("xs")
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
fgraph = FunctionGraph([x], outs)
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
compare_jax_and_py([x], outs, [np.arange(10, dtype=config.floatX)])
def test_erf():
x = scalar("x")
out = erf(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
compare_jax_and_py([x], [out], [1.0])
def test_erfc():
x = scalar("x")
out = erfc(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
compare_jax_and_py([x], [out], [1.0])
def test_erfinv():
x = scalar("x")
out = erfinv(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [0.95])
compare_jax_and_py([x], [out], [0.95])
@pytest.mark.parametrize(
......@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
inputs = [as_tensor(test_value).type() for test_value in test_values]
output = op(*inputs)
fg = FunctionGraph(inputs, [output])
compare_jax_and_py(fg, test_values)
compare_jax_and_py(inputs, [output], test_values)
def test_betaincinv():
......@@ -175,9 +165,10 @@ def test_betaincinv():
b = vector("b", dtype="float64")
x = vector("x", dtype="float64")
out = betaincinv(a, b, x)
fg = FunctionGraph([a, b, x], [out])
compare_jax_and_py(
fg,
[a, b, x],
[out],
[
np.array([5.5, 7.0]),
np.array([5.5, 7.0]),
......@@ -190,39 +181,40 @@ def test_gammaincinv():
k = vector("k", dtype="float64")
x = vector("x", dtype="float64")
out = gammaincinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_gammainccinv():
k = vector("k", dtype="float64")
x = vector("x", dtype="float64")
out = gammainccinv(k, x)
fg = FunctionGraph([k, x], [out])
compare_jax_and_py(fg, [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
compare_jax_and_py([k, x], [out], [np.array([5.5, 7.0]), np.array([0.25, 0.7])])
def test_psi():
x = scalar("x")
out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0])
compare_jax_and_py([x], [out], [3.0])
def test_tri_gamma():
x = vector("x", dtype="float64")
out = tri_gamma(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [np.array([3.0, 5.0])])
compare_jax_and_py([x], [out], [np.array([3.0, 5.0])])
def test_polygamma():
n = vector("n", dtype="int32")
x = vector("x", dtype="float32")
out = polygamma(n, x)
fg = FunctionGraph([n, x], [out])
compare_jax_and_py(
fg,
[n, x],
[out],
[
np.array([0, 1, 2]).astype("int32"),
np.array([0.5, 0.9, 2.5]).astype("float32"),
......@@ -233,41 +225,34 @@ def test_polygamma():
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])
compare_jax_and_py([x], [out], [[-1.0, -0.75, -0.5, -0.25]])
def test_nnet():
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = sigmoid(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x], [out], [x_test_value])
out = softplus(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x], [out], [x_test_value])
def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
mu_test_value = np.r_[0.1, 1.1].astype(config.floatX)
tau = vector("tau", dtype=config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
tau_test_value = np.r_[1.0, 2.0].astype(config.floatX)
res = -tau * mu
fgraph = FunctionGraph([mu, tau], [res])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
res = -tau * (tau - mu) ** 2
fgraph = FunctionGraph([mu, tau], [res])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([mu, tau], [res], [mu_test_value, tau_test_value])
def test_add_scalars():
......@@ -275,8 +260,7 @@ def test_add_scalars():
size = x.shape[0] + x.shape[0] + x.shape[1]
out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
def test_mul_scalars():
......@@ -284,8 +268,7 @@ def test_mul_scalars():
size = x.shape[0] * x.shape[0] * x.shape[1]
out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.ones((2, 3)).astype(config.floatX)])
compare_jax_and_py([x], [out], [np.ones((2, 3)).astype(config.floatX)])
def test_div_scalars():
......@@ -293,8 +276,7 @@ def test_div_scalars():
size = x.shape[0] // x.shape[1]
out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
def test_mod_scalars():
......@@ -302,39 +284,43 @@ def test_mod_scalars():
size = x.shape[0] % x.shape[1]
out = pt.ones(size).astype(config.floatX)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.ones((12, 3)).astype(config.floatX)])
compare_jax_and_py([x], [out], [np.ones((12, 3)).astype(config.floatX)])
def test_jax_multioutput():
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x_test_value = np.r_[1.0, 2.0].astype(config.floatX)
y = vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
y_test_value = np.r_[3.0, 4.0].astype(config.floatX)
w = cosh(x**2 + y / 3.0)
v = cosh(x / 3.0 + y**2)
fgraph = FunctionGraph([x, y], [w, v])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([x, y], [w, v], [x_test_value, y_test_value])
def test_jax_logp():
mu = vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX)
mu_test_value = np.r_[0.0, 0.0].astype(config.floatX)
tau = vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX)
tau_test_value = np.r_[1.0, 1.0].astype(config.floatX)
sigma = vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX)
sigma_test_value = (1.0 / tau_test_value).astype(config.floatX)
value = vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(config.floatX)
value_test_value = np.r_[0.1, -10].astype(config.floatX)
logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0]
alltrue = pt_all([pt_all(1 * val) for val in conditions])
normal_logp = pt.switch(alltrue, logp, -np.inf)
fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(
[mu, tau, sigma, value],
[normal_logp],
[
mu_test_value,
tau_test_value,
sigma_test_value,
value_test_value,
],
)
......@@ -7,7 +7,6 @@ import pytensor.tensor as pt
from pytensor import function, shared
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import until
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
......@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
)
if view:
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
......@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
)
if view:
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0], [xs], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
......@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
if view_y:
ys = ys[view_y]
fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0, y0], [xs, ys], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
......@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
)
if view:
ys = ys[view]
fg = FunctionGraph([xs], [ys])
test_input_vals = [rng.normal(size=10)]
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
jax_fn, _ = compare_jax_and_py(
fg, test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
[xs], [ys], test_input_vals, jax_mode=get_mode("JAX").excluding("scan_pushout")
)
scan_nodes = [
node for node in jax_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
......@@ -112,8 +107,7 @@ def test_scan_mit_mot():
n_steps=10,
)
grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
fg = FunctionGraph([xs], [grads_wrt_xs])
compare_jax_and_py(fg, [np.arange(10)])
compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
def test_scan_update():
......@@ -192,8 +186,7 @@ def test_scan_while():
n_steps=100,
)
fg = FunctionGraph([], [xs])
compare_jax_and_py(fg, [])
compare_jax_and_py([], [xs], [])
def test_scan_SEIR():
......@@ -257,11 +250,6 @@ def test_scan_SEIR():
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX)
......@@ -283,7 +271,12 @@ def test_scan_SEIR():
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
test_input_vals,
jax_mode="JAX",
)
def test_scan_mitsot_with_nonseq():
......@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all"
out_fg = FunctionGraph([a_pt], [y_scan_pt])
test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([a_pt], [y_scan_pt], test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
......@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
)
A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_seq():
......@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
A_val = np.eye(k, dtype=config.floatX)
fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x, A], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_mit_sot():
......@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
n_steps=10,
)
fg = FunctionGraph([x0, A, B], [xs])
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
A_val = np.eye(3, dtype=config.floatX)
B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0, A, B], [xs], test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_carry():
......@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
mode=get_mode("JAX"),
)
fg = FunctionGraph([x0, A], xs)
x0_val = np.arange(3, dtype=config.floatX)
A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
compare_jax_and_py([x0, A], xs, test_input_vals, jax_mode="JAX")
def test_default_mode_excludes_incompatible_rewrites():
......@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
A = matrix("A")
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
def test_dynamic_sequence_length():
......
......@@ -4,7 +4,6 @@ import pytest
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
compare_jax_and_py([], [x], [], must_be_device_array=False)
x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
compare_jax_and_py([], [x], [], must_be_device_array=False)
def test_jax_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_jax_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
compare_jax_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_jax_and_py(
x_fg,
[in_pt, shape_pt],
[x],
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)
......@@ -43,20 +40,17 @@ def test_jax_specify_shape():
def test_jax_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_jax_Reshape_concrete_shape():
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
a = vector("a")
x = reshape(a, a.shape)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_jax_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
@pytest.mark.xfail(
......@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
compare_jax_and_py(
[a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_jax_compile_ops():
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
x = ViewOp()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
......@@ -6,7 +6,6 @@ import pytest
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor import slinalg as pt_slinalg
from pytensor.tensor import subtensor as pt_subtensor
......@@ -30,13 +29,11 @@ def test_jax_basic():
out = pt_subtensor.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3]
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
]
_, [jax_res] = compare_jax_and_py(out_fg, test_input_vals)
_, [jax_res] = compare_jax_and_py([x, y], [out], test_input_vals)
# Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3)
......@@ -46,19 +43,17 @@ def test_jax_basic():
assert jax_res[0, 1] == -8.0
out = clip(x, y, 5)
out_fg = FunctionGraph([x, y], [out])
compare_jax_and_py(out_fg, test_input_vals)
compare_jax_and_py([x, y], [out], test_input_vals)
out = pt.diagonal(x, 0)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
[x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
)
out = pt_slinalg.cholesky(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[x],
[out],
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
......@@ -68,9 +63,9 @@ def test_jax_basic():
# not sure why this isn't working yet with lower=False
out = pt_slinalg.Cholesky(lower=False)(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[x],
[out],
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
......@@ -79,9 +74,9 @@ def test_jax_basic():
)
out = pt_slinalg.solve(x, b)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[x, b],
[out],
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
......@@ -89,19 +84,17 @@ def test_jax_basic():
)
out = pt.diag(b)
out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
compare_jax_and_py([b], [out], [np.arange(10).astype(config.floatX)])
out = pt_nlinalg.det(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
[x], [out], [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
)
out = pt_nlinalg.matrix_inverse(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[x],
[out],
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
......@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
lower=lower,
check_finite=check_finite,
)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[x, b],
[out],
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
......@@ -141,10 +134,10 @@ def test_jax_block_diag():
D = matrix("D")
out = pt_slinalg.block_diag(A, B, C, D)
out_fg = FunctionGraph([A, B, C, D], [out])
compare_jax_and_py(
out_fg,
[A, B, C, D],
[out],
[
np.random.normal(size=(5, 5)).astype(config.floatX),
np.random.normal(size=(3, 3)).astype(config.floatX),
......@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
A = pt.tensor3("A")
B = pt.tensor3("B")
out = pt_slinalg.block_diag(A, B)
out_fg = FunctionGraph([A, B], [out])
compare_jax_and_py(
out_fg,
[A, B],
[out],
[
np.random.normal(size=(5, 5, 5)).astype(config.floatX),
np.random.normal(size=(5, 3, 3)).astype(config.floatX),
......@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
B = matrix("B")
out = pt_slinalg.eigvalsh(A, B, lower=lower)
out_fg = FunctionGraph([A, B], [out])
with pytest.raises(NotImplementedError):
compare_jax_and_py(
out_fg,
[A, B],
[out],
[
np.array(
[[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
......@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
],
)
compare_jax_and_py(
out_fg,
[A, B],
[out],
[
np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
config.floatX
......@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
A = pt.tensor(name="A", shape=shape)
B = pt.tensor(name="B", shape=shape)
out = pt_slinalg.solve_discrete_lyapunov(A, B, method=method)
out_fg = FunctionGraph([A, B], [out])
atol = rtol = 1e-8 if config.floatX == "float64" else 1e-3
compare_jax_and_py(
out_fg,
[A, B],
[out],
[
np.random.normal(size=shape).astype(config.floatX),
np.random.normal(size=shape).astype(config.floatX),
......
import numpy as np
import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_jax_and_py(fgraph, [arr])
compare_jax_and_py([x], [out], [arr])
......@@ -5,7 +5,6 @@ import scipy.sparse
import pytensor.sparse as ps
import pytensor.tensor as pt
from pytensor import function
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
test_values.append(y_test)
dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt])
compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
compare_jax_and_py(inputs, [dot_pt], test_values, jax_mode="JAX")
def test_sparse_dot_non_const_raises():
......
......@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
# Basic indices
out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced indexing
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
# Advanced and basic indexing
out_pt = x_pt[[1, 2], :]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
# Flipping
out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
# Boolean indexing should work if indexes are constant
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
@pytest.mark.xfail(reason="`a` should be specified as static when JIT-compiling")
......@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
x = pt.arange(3)
out_pt = x[:a]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([a], [out_pt])
compare_jax_and_py(out_fg, [1])
compare_jax_and_py([a], [out_pt], [1])
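The xfail above reflects the same trace-time constraint: a slice bound must be concrete under jax.jit. Marking the argument static is the usual workaround (a minimal sketch):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def take_first(a, x):
    # `a` is a compile-time constant here, so the output shape is fixed
    return x[:a]

take_first(2, jnp.arange(3))  # recompiles for each distinct value of `a`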
def test_jax_Subtensor_dynamic_boolean_mask():
......@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
out_pt = x_pt[x_pt < 0]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
x_pt_test = np.arange(-5, 5)
with pytest.raises(NonConcreteBooleanIndexError):
compare_jax_and_py(out_fg, [x_pt_test])
compare_jax_and_py([x_pt], [out_pt], [x_pt_test])
def test_jax_Subtensor_boolean_mask_reexpressible():
......@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
"""
x_pt = pt.matrix("x")
out_pt = x_pt[x_pt < 0].sum()
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [np.arange(25).reshape(5, 5).astype(config.floatX)])
compare_jax_and_py(
[x_pt], [out_pt], [np.arange(25).reshape(5, 5).astype(config.floatX)]
)
def test_boolean_indexing_sum_not_applicable():
......@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
# "Set" advanced indices
st_pt = pt.as_tensor_variable(
......@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
)
out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
# "Set" boolean indices
mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
# "Increment" basic indices
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
# "Increment" advanced indices
st_pt = pt.as_tensor_variable(
......@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
)
out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
# "Increment" boolean indices
mask_pt = pt.constant(x_np > 0)
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out_pt], [])
def test_jax_IncSubtensor_boolean_indexing_reexpressible():
......@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
mask_pt = pt.as_tensor(x_pt) > 0
out_pt = pt_subtensor.inc_subtensor(x_pt[mask_pt], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_jax_and_py(out_fg, [x_np])
compare_jax_and_py([x_pt], [out_pt], [x_np])
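A sketch of the re-expression these graphs rely on (semantics only, assumed): boolean set/inc with a static-shape result can be written with `jnp.where`, which is what makes them JAX-compatible.

import numpy as np
import jax.numpy as jnp

x = np.arange(-3.0, 3.0)
mask = x > 0
set_res = jnp.where(mask, 0.0, x)      # like set_subtensor(x[mask], 0.0)
inc_res = jnp.where(mask, x + 1.0, x)  # like inc_subtensor(x[mask], 1.0)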
def test_boolean_indexing_set_or_inc_not_applicable():
......
......@@ -10,8 +10,6 @@ from jax import errors
import pytensor
import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import check_alloc_runtime_broadcast
......@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
def test_jax_Alloc():
x = ptb.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x])
_, [jax_res] = compare_jax_and_py(x_fg, [])
_, [jax_res] = compare_jax_and_py([], [x], [])
assert jax_res.shape == (2, 3)
x = ptb.alloc(1.1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
x = ptb.AllocEmpty("float32")(2, 3)
x_fg = FunctionGraph([], [x])
def compare_shape_dtype(x, y):
(x,) = x
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype
np.testing.assert_array_equal(x, y, strict=True)
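# strict=True makes assert_array_equal check shape and dtype as well, covering the old manual comparison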
compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype)
compare_jax_and_py([], [x], [], assert_fn=compare_shape_dtype)
a = scalar("a")
x = ptb.alloc(a, 20)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [10.0])
compare_jax_and_py([a], [x], [10.0])
a = vector("a")
x = ptb.alloc(a, 20, 10)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
compare_jax_and_py([a], [x], [np.ones(10, dtype=config.floatX)])
def test_alloc_runtime_broadcast():
......@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
def test_jax_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
compare_jax_and_py([], [x], [])
def test_arange():
out = ptb.arange(1, 10, 2)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
compare_jax_and_py([], [out], [])
def test_arange_of_shape():
x = vector("x")
out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
compare_jax_and_py([x], [out], [np.zeros((5,))], jax_mode="JAX")
def test_arange_nonconcrete():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a = scalar("a")
a.tag.test_value = 10
a_test_value = 10
out = ptb.arange(a)
with pytest.raises(NotImplementedError):
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([a], [out], [a_test_value])
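An illustration in plain JAX (an assumption about the failure mode, not test code): inside `jax.jit` the argument becomes a tracer, and `jnp.arange` cannot produce an array whose length depends on a traced value.

import jax
import jax.numpy as jnp

try:
    jax.jit(jnp.arange)(10)  # 10 is traced, so the output length is unknown
except jax.errors.ConcretizationTypeError:
    print("arange needs a concrete stop value under jit")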
def test_jax_Join():
......@@ -94,16 +82,17 @@ def test_jax_Join():
b = matrix("b")
x = ptb.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX),
......@@ -111,16 +100,17 @@ def test_jax_Join():
)
x = ptb.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX),
......@@ -132,9 +122,9 @@ class TestJaxSplit:
def test_basic(self):
a = matrix("a")
a_splits = ptb.split(a, splits_size=[1, 2, 3], n_splits=3, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py(
fg,
[a],
a_splits,
[
np.zeros((6, 4)).astype(config.floatX),
],
......@@ -142,9 +132,9 @@ class TestJaxSplit:
a = matrix("a", shape=(6, None))
a_splits = ptb.split(a, splits_size=[2, a.shape[0] - 2], n_splits=2, axis=0)
fg = FunctionGraph([a], a_splits)
compare_jax_and_py(
fg,
[a],
a_splits,
[
np.zeros((6, 4)).astype(config.floatX),
],
......@@ -207,15 +197,14 @@ class TestJaxSplit:
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = ptb.eye(3)
out_fg = FunctionGraph([], [out])
compare_jax_and_py(out_fg, [])
compare_jax_and_py([], [out], [])
def test_tri():
out = ptb.tri(10, 10, 0)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
compare_jax_and_py([], [out], [])
@pytest.mark.skipif(
......@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
scalar("n", dtype="int64"),
scalar("k", dtype="int64"),
)
m.tag.test_value = 10
n.tag.test_value = 10
k.tag.test_value = 0
m_test_value = 10
n_test_value = 10
k_test_value = 0
out = ptb.tri(m, n, k)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an AttributeError first, so that is what this test expects
with pytest.raises(AttributeError):
fgraph = FunctionGraph([m, n, k], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py([m, n, k], [out], [m_test_value, n_test_value, k_test_value])
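The same limitation sketched directly in JAX (illustrative only; the exact exception surfaced can vary by JAX version):

import jax
import jax.numpy as jnp

try:
    jax.jit(jnp.tri)(10, 10, 0)  # m, n, k all become tracers
except (jax.errors.ConcretizationTypeError, TypeError):
    print("tri requires concrete shape arguments under jit")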
import contextlib
import inspect
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from unittest import mock
......@@ -21,10 +21,8 @@ from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.ops import ViewOp
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
......@@ -39,7 +37,6 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
if TYPE_CHECKING:
from pytensor.graph.basic import Variable
from pytensor.tensor import TensorLike
class MyType(Type):
......@@ -128,11 +125,6 @@ py_mode = Mode("py", opts)
rng = np.random.default_rng(42849)
def set_test_value(x, v):
x.tag.test_value = v
return x
def compare_shape_dtype(x, y):
return x.shape == y.shape and x.dtype == y.dtype
......@@ -225,28 +217,30 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
def compare_numba_and_py(
fgraph: FunctionGraph | tuple[Sequence["Variable"], Sequence["Variable"]],
inputs: Sequence["TensorLike"],
assert_fn: Callable | None = None,
graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable,
*,
assert_fn: Callable | None = None,
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
inplace: bool = False,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality
"""Function to compare python function output and Numba compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
this function which then compiles the graphs in both Numba and python, runs the calculation
in both and checks if the results are the same
The inputs and outputs are passed to this function, which compiles the given graph in both
Numba and Python modes, runs the calculation in both, and checks that the results are the same.
Parameters
----------
fgraph
`FunctionGraph` or tuple(inputs, outputs) to compare.
inputs
Numeric inputs to be passed to the compiled graphs.
graph_inputs:
Symbolic inputs to the graph.
graph_outputs:
Symbolic outputs of the graph.
test_inputs:
Numerical inputs with which to evaluate the graph.
assert_fn
Assert function used to check for equality between python and Numba. If not
provided uses `np.testing.assert_allclose`.
......@@ -267,42 +261,38 @@ def compare_numba_and_py(
x, y
)
if isinstance(fgraph, FunctionGraph):
fn_inputs = fgraph.inputs
fn_outputs = fgraph.outputs
else:
fn_inputs, fn_outputs = fgraph
fn_inputs = [i for i in fn_inputs if not isinstance(i, SharedVariable)]
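# Guard: graph inputs must be root variables (no owner), e.g. not `x + 1`.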
if any(inp.owner is not None for inp in graph_inputs):
raise ValueError("Inputs must be root variables")
pytensor_py_fn = function(
fn_inputs, fn_outputs, mode=py_mode, accept_inplace=True, updates=updates
graph_inputs, graph_outputs, mode=py_mode, accept_inplace=True, updates=updates
)
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
py_res = pytensor_py_fn(*test_inputs)
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
py_res = pytensor_py_fn(*test_inputs_copy)
# Get some coverage (and catch errors in python mode before unreadable numba ones)
if eval_obj_mode:
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
eval_python_only(fn_inputs, fn_outputs, test_inputs, mode=numba_mode)
test_inputs_copy = (
(inp.copy() for inp in test_inputs) if inplace else test_inputs
)
eval_python_only(graph_inputs, graph_outputs, test_inputs_copy, mode=numba_mode)
pytensor_numba_fn = function(
fn_inputs,
fn_outputs,
graph_inputs,
graph_outputs,
mode=numba_mode,
accept_inplace=True,
updates=updates,
)
test_inputs_copy = (inp.copy() for inp in test_inputs) if inplace else test_inputs
numba_res = pytensor_numba_fn(*test_inputs_copy)
test_inputs = (inp.copy() for inp in inputs) if inplace else inputs
numba_res = pytensor_numba_fn(*test_inputs)
if len(fn_outputs) > 1:
if isinstance(graph_outputs, tuple | list):
for j, p in zip(numba_res, py_res, strict=True):
assert_fn(j, p)
else:
assert_fn(numba_res[0], py_res[0])
assert_fn(numba_res, py_res)
return pytensor_numba_fn, numba_res
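A hypothetical usage sketch of the new signature (the graph and values here are invented for illustration):

import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config

x = pt.vector("x")
y = pt.vector("y")
out = (x * y).sum()
fn, res = compare_numba_and_py(
    [x, y], [out], [np.ones(3, dtype=config.floatX), np.arange(3, dtype=config.floatX)]
)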
......@@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar):
)
def test_Shape(x, i):
g = Shape()(pt.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, [])
compare_numba_and_py([], [g], [])
g = Shape_i(i)(pt.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, [])
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
"v, shape, ndim",
[
(set_test_value(pt.vector(), np.array([4], dtype=config.floatX)), (), 0),
(set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2),
((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
(
set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)),
set_test_value(pt.lvector(), np.array([2, 2], dtype="int64")),
(pt.vector(), np.arange(4, dtype=config.floatX)),
(pt.lvector(), np.array([2, 2], dtype="int64")),
2,
),
],
)
def test_Reshape(v, shape, ndim):
v, v_test_value = v
shape, shape_test_value = shape
g = Reshape(ndim)(v, shape)
g_fg = FunctionGraph(outputs=[g])
inputs = [v] if not isinstance(shape, Variable) else [v, shape]
test_values = (
[v_test_value]
if not isinstance(shape, Variable)
else [v_test_value, shape_test_value]
)
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
inputs,
[g],
test_values,
)
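The parametrization convention assumed throughout the rewritten tests, in miniature: each case is a plain `(symbolic_variable, test_value)` tuple unpacked inside the test, replacing `set_test_value` and `tag.test_value`.

import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config

case = (pt.vector(), np.arange(4, dtype=config.floatX))
v, v_test_value = case  # unpack where set_test_value used to attach the value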
def test_Reshape_scalar():
v = pt.vector()
v.tag.test_value = np.array([1.0], dtype=config.floatX)
v_test_value = np.array([1.0], dtype=config.floatX)
g = Reshape(1)(v[0], (1,))
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
g,
[v_test_value],
)
......@@ -434,53 +424,44 @@ def test_Reshape_scalar():
"v, shape, fails",
[
(
set_test_value(pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
(pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
(1, 1),
False,
),
(
set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, 1),
True,
),
(
set_test_value(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, None),
False,
),
],
)
def test_SpecifyShape(v, shape, fails):
v, v_test_value = v
g = SpecifyShape()(v, *shape)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test_value],
)
@pytest.mark.parametrize(
"v",
[
set_test_value(pt.vector(), np.arange(4, dtype=config.floatX)),
],
)
def test_ViewOp(v):
def test_ViewOp():
v = pt.vector()
v_test_value = np.arange(4, dtype=config.floatX)
g = ViewOp()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test_value],
)
......@@ -489,20 +470,16 @@ def test_ViewOp(v):
[
(
[
set_test_value(
pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.lmatrix(), rng.poisson(size=(2, 3))),
],
MySingleOut,
UserWarning,
),
(
[
set_test_value(
pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(pt.lmatrix(), rng.poisson(size=(2, 3))),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.lmatrix(), rng.poisson(size=(2, 3))),
],
MyMultiOut,
UserWarning,
......@@ -510,38 +487,32 @@ def test_ViewOp(v):
],
)
def test_perform(inputs, op, exc):
inputs, test_values = zip(*inputs, strict=True)
g = op()(*inputs)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
outputs = g
else:
g_fg = FunctionGraph(outputs=[g])
outputs = [g]
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
inputs,
outputs,
test_values,
)
def test_perform_params():
"""This tests for `Op.perform` implementations that require the `params` arguments."""
x = pt.vector()
x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX)
x = pt.vector(shape=(2,))
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x, np.array(True))
if not isinstance(out, list | tuple):
out = [out]
out_fg = FunctionGraph([x], out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
compare_numba_and_py([x], out, [x_test_value])
def test_perform_type_convert():
......@@ -552,59 +523,50 @@ def test_perform_type_convert():
"""
x = pt.vector()
x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX)
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
out = assert_op(x.sum(), np.array(True))
if not isinstance(out, list | tuple):
out = [out]
out_fg = FunctionGraph([x], out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
compare_numba_and_py([x], out, [x_test_value])
@pytest.mark.parametrize(
"x, y, exc",
[
(
set_test_value(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
),
(
set_test_value(
pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")
),
set_test_value(
pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")
),
(pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
(pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
None,
),
(
set_test_value(pt.lmatrix(), rng.poisson(size=(3, 2))),
set_test_value(pt.fvector(), rng.random(size=(2,)).astype("float32")),
(pt.lmatrix(), rng.poisson(size=(3, 2))),
(pt.fvector(), rng.random(size=(2,)).astype("float32")),
None,
),
(
set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
set_test_value(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
None,
),
],
)
def test_Dot(x, y, exc):
x, x_test_value = x
y, y_test_value = y
g = ptm.Dot()(x, y)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x, y],
[g],
[x_test_value, y_test_value],
)
......@@ -612,44 +574,41 @@ def test_Dot(x, y, exc):
"x, exc",
[
(
set_test_value(ps.float64(), np.array(0.0, dtype="float64")),
(ps.float64(), np.array(0.0, dtype="float64")),
None,
),
(
set_test_value(ps.float64(), np.array(-32.0, dtype="float64")),
(ps.float64(), np.array(-32.0, dtype="float64")),
None,
),
(
set_test_value(ps.float64(), np.array(-40.0, dtype="float64")),
(ps.float64(), np.array(-40.0, dtype="float64")),
None,
),
(
set_test_value(ps.float64(), np.array(32.0, dtype="float64")),
(ps.float64(), np.array(32.0, dtype="float64")),
None,
),
(
set_test_value(ps.float64(), np.array(40.0, dtype="float64")),
(ps.float64(), np.array(40.0, dtype="float64")),
None,
),
(
set_test_value(ps.int64(), np.array(32, dtype="int64")),
(ps.int64(), np.array(32, dtype="int64")),
None,
),
],
)
def test_Softplus(x, exc):
x, x_test_value = x
g = psm.Softplus(ps.upgrade_to_float)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
[g],
[x_test_value],
)
......@@ -657,22 +616,22 @@ def test_Softplus(x, exc):
"x, y, exc",
[
(
set_test_value(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
set_test_value(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
None,
),
(
set_test_value(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
set_test_value(
(
pt.ltensor3(),
rng.poisson(size=(2, 3, 3)).astype("int64"),
),
......@@ -681,22 +640,17 @@ def test_Softplus(x, exc):
],
)
def test_BatchedDot(x, y, exc):
g = blas.BatchedDot()(x, y)
x, x_test_value = x
y, y_test_value = y
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
g = blas.BatchedDot()(x, y)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x, y],
g,
[x_test_value, y_test_value],
)
......@@ -767,15 +721,15 @@ y = np.array(
[
([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
(
[set_test_value(pt.dscalar(), np.array(0.2, dtype=np.float64))],
[(pt.dscalar(), np.array(0.2, dtype=np.float64))],
lambda x: x < 0.5,
np.r_[1, 2, 3],
np.r_[-1, -2, -3],
),
(
[
set_test_value(pt.dscalar(), np.array(0.3, dtype=np.float64)),
set_test_value(pt.dscalar(), np.array(0.5, dtype=np.float64)),
(pt.dscalar(), np.array(0.3, dtype=np.float64)),
(pt.dscalar(), np.array(0.5, dtype=np.float64)),
],
lambda x, y: x > y,
x,
......@@ -783,8 +737,8 @@ y = np.array(
),
(
[
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
],
lambda x, y: pt.all(x > y),
x,
......@@ -792,8 +746,8 @@ y = np.array(
),
(
[
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
],
lambda x, y: pt.all(x > y),
[x, 2 * x],
......@@ -801,8 +755,8 @@ y = np.array(
),
(
[
set_test_value(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
set_test_value(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
(pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
(pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
],
lambda x, y: pt.all(x > y),
[x, 2 * x],
......@@ -811,14 +765,9 @@ y = np.array(
],
)
def test_IfElse(inputs, cond_fn, true_vals, false_vals):
inputs, test_values = zip(*inputs, strict=True) if inputs else ([], [])
out = ifelse(cond_fn(*inputs), true_vals, false_vals)
if not isinstance(out, list):
out = [out]
out_fg = FunctionGraph(inputs, out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
compare_numba_and_py(inputs, out, test_values)
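For multi-input parametrizations the same pairs are split with `zip` (values invented for illustration):

import numpy as np
import pytensor.tensor as pt

pairs = [(pt.dscalar(), np.array(0.3)), (pt.dscalar(), np.array(0.5))]
inputs, test_values = zip(*pairs, strict=True)  # parallel tuples of vars and values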
@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409")
......@@ -883,7 +832,7 @@ def test_OpFromGraph():
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error")
......
......@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
)
x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
compare_numba_and_py(
([x], outs),
[x],
outs,
[x_test],
numba_mode=mode,
eval_obj_mode=False,
......
......@@ -11,10 +11,7 @@ import pytensor.tensor.math as ptm
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import float64
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
......@@ -22,7 +19,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
compare_numba_and_py,
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import (
careduce_benchmark_tester,
......@@ -116,13 +112,13 @@ rng = np.random.default_rng(42849)
def test_Elemwise(inputs, input_vals, output_fn, exc):
outputs = output_fn(*inputs)
out_fg = FunctionGraph(
outputs=[outputs] if not isinstance(outputs, list) else outputs
)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(out_fg, input_vals)
compare_numba_and_py(
inputs,
outputs,
input_vals,
)
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
......@@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast():
[
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
(
set_test_value(
(
pt.lscalar(name="a"),
np.array(1, dtype=np.int64),
),
......@@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.T`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
(
set_test_value(
pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)),
(1, 0),
),
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
(
set_test_value(
pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(pt.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)),
(1, 0, "x"),
),
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
(
set_test_value(
(
pt.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
),
......@@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.dimshuffle((0,))`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
(
set_test_value(
(
pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
),
(0,),
),
(
set_test_value(
(
pt.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
),
(0,),
),
(
set_test_value(
(
pt.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
np.array([[[1.0]]], dtype=config.floatX),
),
......@@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast():
],
)
def test_Dimshuffle(v, new_order):
v, v_test_value = v
g = v.dimshuffle(new_order)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test_value],
)
......@@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous():
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1),
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0),
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(), # Empty axes would normally be rewritten away, but we want to test that it still works
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.scalar(), np.array(99.0, dtype=config.floatX)
(
pt.scalar(),
np.array(99.0, dtype=config.floatX),
), # Scalar input would normally be rewritten away, but we want to test that it still works
),
(
......@@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous():
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
(pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
(pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))),
),
],
)
def test_CAReduce(careduce_fn, axis, v):
v, v_test_value = v
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
fn, _ = compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test_value],
)
# Confirm CAReduce is in the compiled function
fn.dprint()
# fn.dprint()
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, CAReduce)
......@@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v):
def test_scalar_Elemwise_Clip():
a = pt.scalar("a")
b = pt.scalar("b")
inputs = [a, b]
z = pt.switch(1, a, b)
c = pt.clip(z, 1, 3)
c_fg = FunctionGraph(outputs=[c])
compare_numba_and_py(c_fg, [1, 1])
compare_numba_and_py(inputs, [c], [1, 1])
@pytest.mark.parametrize(
"dy, sm, axis, exc",
[
(
set_test_value(
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None,
None,
),
(
set_test_value(
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
(
set_test_value(
pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
),
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1,
None,
),
],
)
def test_SoftmaxGrad(dy, sm, axis, exc):
dy, dy_test_value = dy
sm, sm_test_value = sm
g = SoftmaxGrad(axis=axis)(dy, sm)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[dy, sm],
[g],
[dy_test_value, sm_test_value],
)
def test_SoftMaxGrad_constant_dy():
dy = pt.constant(np.zeros((3,), dtype=config.floatX))
sm = pt.vector(shape=(3,))
inputs = [sm]
g = SoftmaxGrad(axis=None)(dy, sm)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(g_fg, [np.ones((3,), dtype=config.floatX)])
compare_numba_and_py(inputs, [g], [np.ones((3,), dtype=config.floatX)])
@pytest.mark.parametrize(
"x, axis, exc",
[
(
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
None,
),
(
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
None,
None,
),
(
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
],
)
def test_Softmax(x, axis, exc):
x, x_test_value = x
g = Softmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
[g],
[x_test_value],
)
......@@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc):
"x, axis, exc",
[
(
set_test_value(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
None,
),
(
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
0,
None,
),
(
set_test_value(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
(pt.matrix(), rng.random(size=(2, 3)).astype(config.floatX)),
1,
None,
),
],
)
def test_LogSoftmax(x, axis, exc):
x, x_test_value = x
g = LogSoftmax(axis=axis)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
[g],
[x_test_value],
)
......@@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc):
"x, axes, exc",
[
(
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
(pt.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
(pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1],
None,
),
],
)
def test_Max(x, axes, exc):
x, x_test_value = x
g = ptm.Max(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
[g],
[x_test_value],
)
......@@ -571,44 +513,37 @@ def test_Max(x, axes, exc):
"x, axes, exc",
[
(
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
(pt.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
(pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1],
None,
),
],
)
def test_Argmax(x, axes, exc):
x, x_test_value = x
g = ptm.Argmax(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
[g],
[x_test_value],
)
......@@ -636,7 +571,8 @@ def test_scalar_loop():
with pytest.warns(UserWarning, match="object mode"):
compare_numba_and_py(
([x], [elemwise_loop]),
[x],
[elemwise_loop],
(np.array([1, 2, 3], dtype="float64"),),
)
......
......@@ -5,11 +5,8 @@ import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import extra_ops
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
......@@ -18,20 +15,17 @@ rng = np.random.default_rng(42849)
@pytest.mark.parametrize(
"val",
[
set_test_value(pt.lscalar(), np.array(6, dtype="int64")),
(pt.lscalar(), np.array(6, dtype="int64")),
],
)
def test_Bartlett(val):
val, test_val = val
g = extra_ops.bartlett(val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[val],
g,
[test_val],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, atol=1e-15),
)
......@@ -40,97 +34,71 @@ def test_Bartlett(val):
"val, axis, mode",
[
(
set_test_value(
pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))
),
(pt.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1))),
1,
"add",
),
(
set_test_value(
pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))
),
(pt.dtensor3(), np.arange(30, dtype=config.floatX).reshape((2, 3, 5))),
-1,
"add",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0,
"add",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
1,
"add",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"add",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
0,
"mul",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
1,
"mul",
),
(
set_test_value(
pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))
),
(pt.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2))),
None,
"mul",
),
],
)
def test_CumOp(val, axis, mode):
val, test_val = val
g = extra_ops.CumOp(axis=axis, mode=mode)(val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[val],
g,
[test_val],
)
@pytest.mark.parametrize(
"a, val",
[
(
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
)
],
)
def test_FillDiagonal(a, val):
def test_FillDiagonal():
a = pt.lmatrix("a")
test_a = np.zeros((10, 2), dtype="int64")
val = pt.lscalar("val")
test_val = np.array(1, dtype="int64")
g = extra_ops.FillDiagonal()(a, val)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[a, val],
g,
[test_a, test_val],
)
......@@ -138,33 +106,32 @@ def test_FillDiagonal(a, val):
"a, val, offset",
[
(
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(-1, dtype="int64")),
(pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(-1, dtype="int64")),
),
(
set_test_value(pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(0, dtype="int64")),
(pt.lmatrix(), np.zeros((10, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(0, dtype="int64")),
),
(
set_test_value(pt.lmatrix(), np.zeros((10, 3), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lmatrix(), np.zeros((10, 3), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
),
],
)
def test_FillDiagonalOffset(a, val, offset):
a, test_a = a
val, test_val = val
offset, test_offset = offset
g = extra_ops.FillDiagonalOffset()(a, val, offset)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[a, val, offset],
g,
[test_a, test_val, test_offset],
)
......@@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset):
"arr, shape, mode, order, exc",
[
(
tuple(set_test_value(pt.lscalar(), v) for v in np.array([0])),
set_test_value(pt.lvector(), np.array([2])),
tuple((pt.lscalar(), v) for v in np.array([0])),
(pt.lvector(), np.array([2])),
"raise",
"C",
None,
),
(
tuple(set_test_value(pt.lscalar(), v) for v in np.array([0, 0, 3])),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
tuple((pt.lscalar(), v) for v in np.array([0, 0, 3])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
None,
),
(
tuple(
set_test_value(pt.lvector(), v)
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
None,
),
(
tuple(
set_test_value(pt.lvector(), v)
for v in np.array([[0, 1], [2, 0], [1, 3]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
tuple((pt.lvector(), v) for v in np.array([[0, 1], [2, 0], [1, 3]])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"F",
NotImplementedError,
),
(
tuple(
set_test_value(pt.lvector(), v)
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
(pt.lvector(), np.array([2, 3, 4])),
"raise",
"C",
ValueError,
),
(
tuple(
set_test_value(pt.lvector(), v)
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
(pt.lvector(), np.array([2, 3, 4])),
"wrap",
"C",
None,
),
(
tuple(
set_test_value(pt.lvector(), v)
for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
(pt.lvector(), v) for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]])
),
set_test_value(pt.lvector(), np.array([2, 3, 4])),
(pt.lvector(), np.array([2, 3, 4])),
"clip",
"C",
None,
......@@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset):
],
)
def test_RavelMultiIndex(arr, shape, mode, order, exc):
g = extra_ops.RavelMultiIndex(mode, order)(*((*arr, shape)))
g_fg = FunctionGraph(outputs=[g])
arr, test_arr = zip(*arr, strict=True)
shape, test_shape = shape
g = extra_ops.RavelMultiIndex(mode, order)(*arr, shape)
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[*arr, shape],
g,
[*test_arr, test_shape],
)
......@@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
"x, repeats, axis, exc",
[
(
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
set_test_value(pt.lscalar(), np.array(0, dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(0, dtype="int64")),
None,
None,
),
(
set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
None,
None,
),
(
set_test_value(pt.lvector(), np.arange(2, dtype="int64")),
set_test_value(pt.lvector(), np.array([1, 1], dtype="int64")),
(pt.lvector(), np.arange(2, dtype="int64")),
(pt.lvector(), np.array([1, 1], dtype="int64")),
None,
None,
),
(
set_test_value(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
0,
UserWarning,
),
],
)
def test_Repeat(x, repeats, axis, exc):
x, test_x = x
repeats, test_repeats = repeats
g = extra_ops.Repeat(axis)(x, repeats)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x, repeats],
g,
[test_x, test_repeats],
)
......@@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc):
"x, axis, return_index, return_inverse, return_counts, exc",
[
(
set_test_value(pt.lscalar(), np.array(1, dtype="int64")),
(pt.lscalar(), np.array(1, dtype="int64")),
None,
False,
False,
......@@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc):
None,
),
(
set_test_value(pt.lvector(), np.array([1, 1, 2], dtype="int64")),
(pt.lvector(), np.array([1, 1, 2], dtype="int64")),
None,
False,
False,
......@@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc):
None,
),
(
set_test_value(pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")),
(pt.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")),
None,
False,
False,
......@@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc):
None,
),
(
set_test_value(
pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
),
(pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
0,
False,
False,
......@@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc):
UserWarning,
),
(
set_test_value(
pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")
),
(pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
0,
True,
True,
......@@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc):
],
)
def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
x, test_x = x
g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
"arr, shape, order, exc",
[
(
set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
pt.as_tensor([2, 3, 4]),
"C",
None,
),
(
set_test_value(pt.lvector(), np.array([1, 0], dtype="int64")),
(pt.lvector(), np.array([1, 0], dtype="int64")),
pt.as_tensor([2]),
"C",
None,
),
(
set_test_value(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
(pt.lvector(), np.array([9, 15, 1], dtype="int64")),
pt.as_tensor([2, 3, 4]),
"F",
NotImplementedError,
......@@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
],
)
def test_UnravelIndex(arr, shape, order, exc):
arr, test_arr = arr
g = extra_ops.UnravelIndex(order)(arr, shape)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.raises(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[arr],
g,
[test_arr],
)
......@@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc):
"a, v, side, sorter, exc",
[
(
set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
(pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
"left",
None,
None,
),
pytest.param(
set_test_value(
(
pt.vector(),
np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX),
),
set_test_value(
(
pt.matrix(),
np.array(
[
......@@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc):
None,
),
(
set_test_value(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
set_test_value(pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
(pt.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)),
(pt.matrix(), rng.random((3, 2)).astype(config.floatX)),
"right",
set_test_value(pt.lvector(), np.array([0, 2, 1])),
(pt.lvector(), np.array([0, 2, 1])),
UserWarning,
),
],
)
def test_Searchsorted(a, v, side, sorter, exc):
a, test_a = a
v, test_v = v
if sorter is not None:
sorter, test_sorter = sorter
g = extra_ops.SearchsortedOp(side)(a, v, sorter)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[a, v] if sorter is None else [a, v, sorter],
g,
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
)
......@@ -4,11 +4,8 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
......@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
"x, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
......@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
],
)
def test_Det(x, exc):
x, test_x = x
g = nlinalg.Det()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -53,14 +47,14 @@ def test_Det(x, exc):
"x, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")),
),
......@@ -69,18 +63,15 @@ def test_Det(x, exc):
],
)
def test_SLogDet(x, exc):
x, test_x = x
g = nlinalg.SLogDet()(x)
g_fg = FunctionGraph(outputs=g)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -112,21 +103,21 @@ y = np.array(
"x, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(x),
),
None,
),
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(y),
),
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -137,22 +128,15 @@ y = np.array(
],
)
def test_Eig(x, exc):
x, test_x = x
g = nlinalg.Eig()(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -160,7 +144,7 @@ def test_Eig(x, exc):
"x, uplo, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -168,7 +152,7 @@ def test_Eig(x, exc):
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -180,22 +164,15 @@ def test_Eig(x, exc):
],
)
def test_Eigh(x, uplo, exc):
x, test_x = x
g = nlinalg.Eigh(uplo)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
[
(
nlinalg.MatrixInverse,
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
),
(
nlinalg.MatrixInverse,
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
),
(
nlinalg.MatrixPinv,
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
),
(
nlinalg.MatrixPinv,
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
],
)
def test_matrix_inverses(op, x, exc, op_args):
x, test_x = x
g = op(*op_args)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
"x, mode, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None,
),
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
],
)
def test_QRFull(x, mode, exc):
x, test_x = x
g = nlinalg.QRFull(mode)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
"x, full_matrices, compute_uv, exc",
[
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
None,
),
(
set_test_value(
(
pt.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
......@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
None,
),
(
set_test_value(
(
pt.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
......@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
],
)
def test_SVD(x, full_matrices, compute_uv, exc):
x, test_x = x
g = nlinalg.SVD(full_matrices, compute_uv)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x],
g,
[test_x],
)
......@@ -3,7 +3,6 @@ import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor.pad import PadMode
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
x = np.random.normal(size=(3, 3))
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
res_fg = FunctionGraph([x_pt], [res])
compare_numba_and_py(
res_fg,
[x_pt],
[res],
[x],
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
py_mode="FAST_RUN",
......
......@@ -10,13 +10,9 @@ import pytensor.tensor.random.basic as ptr
from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from tests.link.numba.test_basic import (
compare_numba_and_py,
numba_mode,
set_test_value,
)
from tests.tensor.random.test_basic import (
batched_permutation_tester,
......@@ -159,11 +155,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.uniform,
[
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
......@@ -173,15 +169,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.triangular,
[
set_test_value(
(
pt.dscalar(),
np.array(-5.0, dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(5.0, dtype=np.float64),
),
......@@ -191,11 +187,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.lognormal,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -205,11 +201,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.pareto,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([2.0, 10.0], dtype=np.float64),
),
......@@ -219,7 +215,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.exponential,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
......@@ -229,7 +225,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.weibull,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
......@@ -239,11 +235,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.logistic,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -253,7 +249,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.geometric,
[
set_test_value(
(
pt.dvector(),
np.array([0.3, 0.4], dtype=np.float64),
),
......@@ -263,15 +259,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest.param(
ptr.hypergeometric,
[
set_test_value(
(
pt.lscalar(),
np.array(7, dtype=np.int64),
),
set_test_value(
(
pt.lscalar(),
np.array(8, dtype=np.int64),
),
set_test_value(
(
pt.lscalar(),
np.array(15, dtype=np.int64),
),
......@@ -282,11 +278,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.wald,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -296,11 +292,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.laplace,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -310,11 +306,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.binomial,
[
set_test_value(
(
pt.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(0.9, dtype=np.float64),
),
......@@ -324,21 +320,21 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.normal,
[
set_test_value(
(
pt.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [3, 2])),
pt.as_tensor([3, 2]),
),
(
ptr.poisson,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
......@@ -348,11 +344,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.halfnormal,
[
set_test_value(
(
pt.lvector(),
np.array([1, 2], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -362,7 +358,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.bernoulli,
[
set_test_value(
(
pt.dvector(),
np.array([0.1, 0.9], dtype=np.float64),
),
......@@ -372,11 +368,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.beta,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -386,11 +382,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr._gamma,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
......@@ -400,7 +396,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.chisquare,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
......@@ -410,11 +406,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.negative_binomial,
[
set_test_value(
(
pt.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
(
pt.dscalar(),
np.array(0.09, dtype=np.float64),
),
......@@ -424,11 +420,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.vonmises,
[
set_test_value(
(
pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -438,14 +434,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
ptr.permutation,
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
(),
),
(
partial(ptr.choice, replace=True),
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
pt.as_tensor([2]),
),
......@@ -455,17 +451,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=True, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
(pt.dmatrix(), np.eye(3, dtype=np.float64)),
(pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
],
(pt.as_tensor([2, 3])),
),
pytest.param(
partial(ptr.choice, replace=False),
[
set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)),
(pt.dvector(), np.arange(5, dtype=np.float64)),
],
pt.as_tensor([2]),
marks=pytest.mark.xfail(
......@@ -476,7 +470,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest.param(
partial(ptr.choice, replace=False),
[
set_test_value(pt.dmatrix(), np.eye(5, dtype=np.float64)),
(pt.dmatrix(), np.eye(5, dtype=np.float64)),
],
pt.as_tensor([2]),
marks=pytest.mark.xfail(
......@@ -490,8 +484,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.vector(), np.arange(5, dtype=np.float64)),
set_test_value(
(pt.vector(), np.arange(5, dtype=np.float64)),
(
pt.dvector(),
np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64),
),
......@@ -504,10 +498,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
(pt.dmatrix(), np.eye(3, dtype=np.float64)),
(pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
],
(),
),
......@@ -517,10 +509,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
(pt.dmatrix(), np.eye(3, dtype=np.float64)),
(pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)),
],
(pt.as_tensor([2, 1])),
),
......@@ -529,17 +519,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
)
def test_aligned_RandomVariable(rv_op, dist_args, size):
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
dist_args, test_dist_args = zip(*dist_args, strict=True)
rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
dist_args,
[g],
test_dist_args,
eval_obj_mode=False, # No python impl
)
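
Each parametrize case above is now a plain (symbolic variable, test value) tuple, and the test body unzips them into parallel sequences. The pattern, sketched with hypothetical arguments:

import numpy as np
import pytensor.tensor as pt

dist_args = [
    (pt.dvector(), np.array([1.0, 2.0])),
    (pt.dscalar(), np.array(1.0)),
]
# Symbolic args build the graph; the test values feed the compiled functions
dist_args, test_dist_args = zip(*dist_args, strict=True)
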
......@@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
(
ptr.cauchy,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
(
ptr.gumbel,
[
set_test_value(
(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
......@@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
)
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
dist_args, test_dist_args = zip(*dist_args, strict=True)
rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=(2000, *base_size), rng=rng)
g_fn = function(dist_args, g, mode=numba_mode)
samples = g_fn(
*[
i.tag.test_value
for i in g_fn.maker.fgraph.inputs
if not isinstance(i, SharedVariable | Constant)
]
)
samples = g_fn(*test_dist_args)
bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args])
bcast_dist_args = np.broadcast_arrays(*test_dist_args)
for idx in np.ndindex(*base_size):
cdf_params = params_conv(*(arg[idx] for arg in bcast_dist_args))
......@@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
"a, size, cm",
[
pytest.param(
set_test_value(
(
pt.dvector(),
np.array([100000, 1, 1], dtype=np.float64),
),
......@@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(),
),
pytest.param(
set_test_value(
(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
......@@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib.suppress(),
),
pytest.param(
set_test_value(
(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
......@@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
],
)
def test_DirichletRV(a, size, cm):
a, a_val = a
rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode)
with cm:
a_val = a.tag.test_value
all_samples = []
for i in range(1000):
samples = g_fn(a_val)
......
......@@ -5,13 +5,10 @@ import pytensor.scalar as ps
import pytensor.scalar.basic as psb
import pytensor.tensor as pt
from pytensor import config
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value
from tests.link.numba.test_basic import compare_numba_and_py
rng = np.random.default_rng(42849)
......@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
"x, y",
[
(
set_test_value(pt.lvector(), np.arange(4, dtype="int64")),
set_test_value(pt.dvector(), np.arange(4, dtype="float64")),
(pt.lvector(), np.arange(4, dtype="int64")),
(pt.dvector(), np.arange(4, dtype="float64")),
),
(
set_test_value(pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))),
set_test_value(pt.lscalar(), np.array(4, dtype="int64")),
(pt.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))),
(pt.lscalar(), np.array(4, dtype="int64")),
),
],
)
def test_Second(x, y):
x, x_test = x
y, y_test = y
# We use the `Elemwise`-wrapped version of `Second`
g = pt.second(x, y)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[x, y],
g,
[x_test, y_test],
)
@pytest.mark.parametrize(
"v, min, max",
[
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0),
(set_test_value(pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0),
((pt.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0),
((pt.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0),
((pt.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0),
],
)
def test_Clip(v, min, max):
v, v_test = v
g = ps.clip(v, min, max)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test],
)
......@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
def test_Composite(inputs, input_values, scalar_fn):
composite_inputs = [ps.ScalarType(config.floatX)(name=i.name) for i in inputs]
comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values)
compare_numba_and_py(inputs, [comp_op(*inputs)], input_values)
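
For context, Composite fuses a scalar subgraph into a single scalar Op, and Elemwise broadcasts it over tensors. A minimal sketch of the construction exercised above, with a made-up scalar function:

import numpy as np
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.scalar.basic import Composite
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py

xs = ps.float64("xs")
comp_op = Elemwise(Composite([xs], [xs * 2 + 1]))  # fused scalar graph, broadcast over tensors

x = pt.dvector("x")
compare_numba_and_py([x], [comp_op(x)], [np.arange(3, dtype="float64")])
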
@pytest.mark.parametrize(
"v, dtype",
[
(set_test_value(pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
(set_test_value(pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
((pt.fscalar(), np.array(1.0, dtype="float32")), psb.float64),
((pt.dscalar(), np.array(1.0, dtype="float64")), psb.float32),
],
)
def test_Cast(v, dtype):
v, v_test = v
g = psb.Cast(dtype)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test],
)
@pytest.mark.parametrize(
"v, dtype",
[
(set_test_value(pt.iscalar(), np.array(10, dtype="int32")), psb.float64),
((pt.iscalar(), np.array(10, dtype="int32")), psb.float64),
],
)
def test_reciprocal(v, dtype):
v, v_test = v
g = psb.reciprocal(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
[g],
[v_test],
)
......@@ -156,6 +141,7 @@ def test_isnan(composite):
out = pt.isnan(x)
compare_numba_and_py(
([x], [out]),
[x],
[out],
[np.array([1, 0], dtype="float64")],
)
......@@ -5,7 +5,6 @@ import pytensor
import pytensor.tensor as pt
from pytensor import config, function, grad
from pytensor.compile.mode import Mode, get_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Log1p
from pytensor.scan.basic import scan
from pytensor.scan.op import Scan
......@@ -147,7 +146,7 @@ def test_xit_xot_types(
if output_vals is None:
compare_numba_and_py(
(sequences + non_sequences, res), input_vals, updates=updates
sequences + non_sequences, res, input_vals, updates=updates
)
else:
numba_mode = get_mode("NUMBA")
......@@ -217,10 +216,7 @@ def test_scan_multiple_output(benchmark):
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
out = [st, et, it, logp_c_all, logp_d_all]
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX)
......@@ -243,21 +239,21 @@ def test_scan_multiple_output(benchmark):
gamma_val,
delta_val,
]
scan_fn, _ = compare_numba_and_py(out_fg, test_input_vals)
scan_fn, _ = compare_numba_and_py(
[pt_C, pt_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
out,
test_input_vals,
)
benchmark(scan_fn, *test_input_vals)
@config.change_flags(compute_test_value="raise")
def test_scan_tap_output():
a_pt = pt.scalar("a")
a_pt.tag.test_value = 10.0
b_pt = pt.arange(11).astype(config.floatX)
b_pt.name = "b"
b_pt = pt.vector("b")
c_pt = pt.arange(20, 31, dtype=config.floatX)
c_pt.name = "c"
c_pt = pt.vector("c")
def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a):
x_tm1.name = "x_tm1"
......@@ -301,14 +297,12 @@ def test_scan_tap_output():
strict=True,
)
out_fg = FunctionGraph([a_pt, b_pt, c_pt], scan_res)
test_input_vals = [
np.array(10.0).astype(config.floatX),
np.arange(11, dtype=config.floatX),
np.arange(20, 31, dtype=config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
compare_numba_and_py([a_pt, b_pt, c_pt], scan_res, test_input_vals)
def test_scan_while():
......@@ -323,12 +317,10 @@ def test_scan_while():
n_steps=1024,
)
out_fg = FunctionGraph([max_value], [values])
test_input_vals = [
np.array(45).astype(config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
compare_numba_and_py([max_value], [values], test_input_vals)
def test_scan_multiple_none_output():
......@@ -343,11 +335,8 @@ def test_scan_multiple_none_output():
outputs_info=[pt.ones_like(A), None, None],
n_steps=3,
)
out_fg = FunctionGraph([A], result)
test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py(out_fg, test_input_vals)
compare_numba_and_py([A], result, test_input_vals)
@pytest.mark.parametrize("n_steps_val", [1, 5])
......@@ -372,11 +361,14 @@ def test_scan_save_mem_basic(n_steps_val):
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([init_x, n_steps], [output])
test_input_vals = (state_val, n_steps_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
[init_x, n_steps],
[output],
test_input_vals,
numba_mode=numba_mode,
py_mode=py_mode,
)
......@@ -410,14 +402,12 @@ def test_mitmots_basic():
numba_mode = get_mode("NUMBA").including("scan_save_mem")
py_mode = Mode("py").including("scan_save_mem")
out_fg = FunctionGraph([seq, init_x], g_outs)
seq_val = np.arange(3)
init_x_val = np.r_[-2, -1]
test_input_vals = (seq_val, init_x_val)
compare_numba_and_py(
out_fg, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
[seq, init_x], g_outs, test_input_vals, numba_mode=numba_mode, py_mode=py_mode
)
......
......@@ -9,14 +9,14 @@ from scipy import linalg as scipy_linalg
import pytensor
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor import config
from tests import unittest_tools as utt
from tests.link.numba.test_basic import compare_numba_and_py
numba = pytest.importorskip("numba")
floatX = pytensor.config.floatX
floatX = config.floatX
rng = np.random.default_rng(42849)
......@@ -88,7 +88,12 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
np.testing.assert_allclose(test_input @ X_np, b_val, atol=ATOL, rtol=RTOL)
compare_numba_and_py(f.maker.fgraph, [A_func(A_val.copy()), b_val.copy()])
compiled_fgraph = f.maker.fgraph
compare_numba_and_py(
compiled_fgraph.inputs,
compiled_fgraph.outputs,
[A_func(A_val.copy()), b_val.copy()],
)
@pytest.mark.parametrize(
......@@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans):
cov_ = cov
chol = pt.linalg.cholesky(cov_, lower=lower)
fg = FunctionGraph(outputs=[chol])
x = np.array([0.1, 0.2, 0.3]).astype(floatX)
val = np.eye(3).astype(floatX) + x[None, :] * x[:, None]
compare_numba_and_py(fg, [val])
compare_numba_and_py([cov], [chol], [val])
def test_numba_Cholesky_raises_on_nan_input():
......@@ -218,8 +221,7 @@ def test_block_diag():
B_val = np.random.normal(size=(3, 3)).astype(floatX)
C_val = np.random.normal(size=(2, 2)).astype(floatX)
D_val = np.random.normal(size=(4, 4)).astype(floatX)
out_fg = pytensor.graph.FunctionGraph([A, B, C, D], [X])
compare_numba_and_py(out_fg, [A_val, B_val, C_val, D_val])
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
def test_lamch():
......@@ -390,7 +392,7 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
)
op = f.maker.fgraph.outputs[0].owner.op
compare_numba_and_py(([A, b], [X]), inputs=[A_val, b_val], inplace=True)
compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy = A_val.copy()
......
......@@ -100,4 +100,4 @@ def test_sparse_objmode():
UserWarning,
match="Numba will use object mode to run SparseDot's perform method",
):
compare_numba_and_py(((x, y), (out,)), [x_val, y_val])
compare_numba_and_py([x, y], out, [x_val, y_val])
......@@ -4,7 +4,6 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -44,8 +43,7 @@ def test_Subtensor(x, indices):
"""Test NumPy's basic indexing."""
out_pt = x[indices]
assert isinstance(out_pt.owner.op, Subtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
@pytest.mark.parametrize(
......@@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension."""
out_pt = advanced_subtensor1(x, *indices)
assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
def test_AdvancedSubtensor1_out_of_bounds():
out_pt = advanced_subtensor1(np.arange(3), [4])
assert isinstance(out_pt.owner.op, AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_pt])
with pytest.raises(IndexError):
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
@pytest.mark.parametrize(
......@@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
x_pt = x.type()
out_pt = x_pt[indices]
assert isinstance(out_pt.owner.op, AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with (
pytest.warns(
UserWarning,
......@@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
else contextlib.nullcontext()
):
compare_numba_and_py(
out_fg,
[x_pt],
[out_pt],
[x.data],
numba_mode=numba_mode.including("specialize"),
)
......@@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
def test_IncSubtensor(x, y, indices):
out_pt = set_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
out_pt = inc_subtensor(x[indices], y)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
x_pt = x.type()
out_pt = set_subtensor(x_pt[indices], y, inplace=True)
assert isinstance(out_pt.owner.op, IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data])
compare_numba_and_py([x_pt], [out_pt], [x.data])
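
As a reminder of the semantics under test: set_subtensor overwrites the indexed region while inc_subtensor accumulates into it. A sketch with made-up values:

import pytensor.tensor as pt
from pytensor.tensor import inc_subtensor, set_subtensor

x = pt.dvector("x")
set_out = set_subtensor(x[:2], 5.0)  # first two entries become 5.0
inc_out = inc_subtensor(x[:2], 5.0)  # first two entries become x[:2] + 5.0
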
@pytest.mark.parametrize(
......@@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices):
def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = advanced_set_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
out_pt = advanced_inc_subtensor1(x, y, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([], [out_pt])
compare_numba_and_py(out_fg, [])
compare_numba_and_py([], [out_pt], [])
# With symbolic inputs
x_pt = x.type()
......@@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_pt = AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
out_pt = AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
x_pt, y_pt, *indices
)
assert isinstance(out_pt.owner.op, AdvancedIncSubtensor1)
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
compare_numba_and_py(out_fg, [x.data, y.data])
compare_numba_and_py([x_pt, y_pt], [out_pt], [x.data, y.data])
@pytest.mark.parametrize(
......@@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor(
if set_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
if inplace:
# Test updates inplace
......@@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor(
if inc_requires_objmode
else contextlib.nullcontext()
):
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
fn, _ = compare_numba_and_py([x_pt, y_pt], out_pt, [x, y], numba_mode=mode)
if inplace:
# Test updates inplace
x_orig = x.copy()
......
......@@ -6,15 +6,11 @@ import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import (
compare_numba_and_py,
compare_shape_dtype,
set_test_value,
)
from tests.tensor.test_basic import check_alloc_runtime_broadcast
......@@ -31,21 +27,18 @@ rng = np.random.default_rng(42849)
[
(0.0, (2, 3)),
(1.1, (2, 3)),
(set_test_value(pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)),
(set_test_value(pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)),
((pt.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)),
((pt.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)),
],
)
def test_Alloc(v, shape):
v, v_test = v if isinstance(v, tuple) else (v, None)
g = pt.alloc(v, *shape)
g_fg = FunctionGraph(outputs=[g])
_, (numba_res,) = compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v] if v_test is not None else [],
[g],
[v_test] if v_test is not None else [],
)
assert numba_res.shape == shape
......@@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast():
def test_AllocEmpty():
x = pt.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x])
# We cannot compare the values in the arrays, only the shapes and dtypes
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)
compare_numba_and_py([], x, [], assert_fn=compare_shape_dtype)
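
Because empty allocations hold arbitrary memory, only metadata can be checked. compare_shape_dtype is the project's helper for this; a minimal equivalent would look like the following (a sketch, not the helper's actual definition):

def shape_dtype_assert(numba_res, py_res):
    # Values are uninitialized, so only shape and dtype are deterministic
    assert numba_res.shape == py_res.shape
    assert numba_res.dtype == py_res.dtype
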
@pytest.mark.parametrize(
"v", [set_test_value(ps.float64(), np.array(1.0, dtype="float64"))]
)
def test_TensorFromScalar(v):
def test_TensorFromScalar():
v, v_test = ps.float64(), np.array(1.0, dtype="float64")
g = ptb.TensorFromScalar()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
g,
[v_test],
)
@pytest.mark.parametrize(
"v",
[
set_test_value(pt.scalar(), np.array(1.0, dtype=config.floatX)),
],
)
def test_ScalarFromTensor(v):
def test_ScalarFromTensor():
v, v_test = pt.scalar(), np.array(1.0, dtype=config.floatX)
g = ptb.ScalarFromTensor()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
g,
[v_test],
)
def test_Unbroadcast():
v = set_test_value(pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX))
v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)
g = Unbroadcast(0)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[v],
g,
[v_test],
)
......@@ -117,65 +90,52 @@ def test_Unbroadcast():
[
(
(
set_test_value(pt.scalar(), np.array(1, dtype=config.floatX)),
set_test_value(pt.scalar(), np.array(2, dtype=config.floatX)),
set_test_value(pt.scalar(), np.array(3, dtype=config.floatX)),
(pt.scalar(), np.array(1, dtype=config.floatX)),
(pt.scalar(), np.array(2, dtype=config.floatX)),
(pt.scalar(), np.array(3, dtype=config.floatX)),
),
config.floatX,
),
(
(
set_test_value(pt.dscalar(), np.array(1, dtype=np.float64)),
set_test_value(pt.lscalar(), np.array(3, dtype=np.int32)),
(pt.dscalar(), np.array(1, dtype=np.float64)),
(pt.lscalar(), np.array(3, dtype=np.int32)),
),
"float64",
),
(
(set_test_value(pt.iscalar(), np.array(1, dtype=np.int32)),),
((pt.iscalar(), np.array(1, dtype=np.int32)),),
"float64",
),
(
(set_test_value(pt.scalar(dtype=bool), True),),
((pt.scalar(dtype=bool), True),),
bool,
),
],
)
def test_MakeVector(vals, dtype):
vals, vals_test = zip(*vals, strict=True)
g = ptb.MakeVector(dtype)(*vals)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
vals,
[g],
vals_test,
)
@pytest.mark.parametrize(
"start, stop, step, dtype",
[
(
set_test_value(pt.lscalar(), np.array(1)),
set_test_value(pt.lscalar(), np.array(10)),
set_test_value(pt.lscalar(), np.array(3)),
config.floatX,
),
],
)
def test_ARange(start, stop, step, dtype):
def test_ARange():
start, start_test = pt.lscalar(), np.array(1)
stop, stop_test = pt.lscalar(), np.array(10)
step, step_test = pt.lscalar(), np.array(3)
dtype = config.floatX
g = ptb.ARange(dtype)(start, stop, step)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[start, stop, step],
g,
[start_test, stop_test, step_test],
)
......@@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype):
[
(
(
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
(pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
),
0,
),
(
(
set_test_value(
pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)
),
set_test_value(
pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX)
),
(pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(3, 1)).astype(config.floatX)),
),
0,
),
(
(
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
set_test_value(
pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)
),
(pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(1, 2)).astype(config.floatX)),
),
1,
),
(
(
set_test_value(
pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)
),
set_test_value(
pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)
),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 1)).astype(config.floatX)),
),
1,
),
],
)
def test_Join(vals, axis):
vals, vals_test = zip(*vals, strict=True)
g = pt.join(axis, *vals)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
vals,
g,
vals_test,
)
def test_Join_view():
vals = (
set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
set_test_value(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
vals, vals_test = zip(
*(
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
),
strict=True,
)
g = ptb.Join(view=1)(1, *vals)
g_fg = FunctionGraph(outputs=[g])
with pytest.raises(NotImplementedError):
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
vals,
g,
vals_test,
)
......@@ -267,57 +207,47 @@ def test_Join_view():
(
0,
0,
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)),
set_test_value(pt.vector(dtype="int64"), []),
(pt.vector(), rng.normal(size=20).astype(config.floatX)),
(pt.vector(dtype="int64"), []),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
),
(pt.vector(), rng.normal(size=5).astype(config.floatX)),
(pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
),
(pt.vector(), rng.normal(size=10).astype(config.floatX)),
(pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)),
),
(
5,
-1,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
),
(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
(pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)),
),
(
5,
-2,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
),
(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
(pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)),
),
],
)
def test_Split(n_splits, axis, values, sizes):
values, values_test = values
sizes, sizes_test = sizes
g = pt.split(values, sizes, n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[values, sizes],
g,
[values_test, sizes_test],
)
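
Note that graph_outputs may be a single Variable or a sequence, and both spellings occur throughout these hunks; schematically (op and values assumed):

import numpy as np
import pytensor.tensor as pt
from tests.link.numba.test_basic import compare_numba_and_py

x = pt.dvector("x")
out = pt.cumsum(x)
x_val = np.arange(3, dtype="float64")

compare_numba_and_py([x], out, [x_val])    # single output Variable
compare_numba_and_py([x], [out], [x_val])  # equivalent singleton list
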
......@@ -349,34 +279,27 @@ def test_Split_view():
"val, offset",
[
(
set_test_value(
pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
(pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))),
0,
),
(
set_test_value(
pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
(pt.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))),
-1,
),
(
set_test_value(pt.vector(), np.arange(10, dtype=config.floatX)),
(pt.vector(), np.arange(10, dtype=config.floatX)),
0,
),
],
)
def test_ExtractDiag(val, offset):
val, val_test = val
g = pt.diag(val, offset)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[val],
g,
[val_test],
)
......@@ -407,30 +330,28 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
@pytest.mark.parametrize(
"n, m, k, dtype",
[
(set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None),
((pt.lscalar(), np.array(1, dtype=np.int64)), None, 0, None),
(
set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)),
(pt.lscalar(), np.array(1, dtype=np.int64)),
(pt.lscalar(), np.array(2, dtype=np.int64)),
0,
"float32",
),
(
set_test_value(pt.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(pt.lscalar(), np.array(2, dtype=np.int64)),
(pt.lscalar(), np.array(1, dtype=np.int64)),
(pt.lscalar(), np.array(2, dtype=np.int64)),
1,
"int64",
),
],
)
def test_Eye(n, m, k, dtype):
n, n_test = n
m, m_test = m if m is not None else (None, None)
g = pt.eye(n, m, k, dtype=dtype)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
[n, m] if m is not None else [n],
g,
[n_test, m_test] if m is not None else [n_test],
)
......@@ -9,10 +9,10 @@ import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import PYTORCH, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
......@@ -39,10 +39,10 @@ py_mode = Mode(linker="py", optimizer=None)
def compare_pytorch_and_py(
fgraph: FunctionGraph,
graph_inputs: Iterable[Variable],
graph_outputs: Variable | Iterable[Variable],
test_inputs: Iterable,
assert_fn: Callable | None = None,
must_be_device_array: bool = True,
pytorch_mode=pytorch_mode,
py_mode=py_mode,
):
......@@ -50,8 +50,10 @@ def compare_pytorch_and_py(
Parameters
----------
fgraph: FunctionGraph
PyTensor function Graph object
graph_inputs:
Symbolic inputs to the graph
graph_outputs:
Symbolic outputs of the graph
test_inputs: iter
Numerical inputs for testing the function graph
assert_fn: func, opt
......@@ -63,24 +65,22 @@ def compare_pytorch_and_py(
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
if any(inp.owner is not None for inp in graph_inputs):
raise ValueError("Inputs must be root variables")
pytensor_torch_fn = function(fn_inputs, fgraph.outputs, mode=pytorch_mode)
pytensor_torch_fn = function(graph_inputs, graph_outputs, mode=pytorch_mode)
pytorch_res = pytensor_torch_fn(*test_inputs)
if isinstance(pytorch_res, list):
assert all(isinstance(res, np.ndarray) for res in pytorch_res)
else:
assert isinstance(pytorch_res, np.ndarray)
pytensor_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
py_res = pytensor_py_fn(*test_inputs)
if len(fgraph.outputs) > 1:
if isinstance(graph_outputs, list | tuple):
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res, strict=True):
assert not isinstance(pytorch_res_i, torch.Tensor)
assert_fn(pytorch_res_i, py_res_i)
else:
assert_fn(pytorch_res[0], py_res[0])
assert not isinstance(pytorch_res, torch.Tensor)
assert_fn(pytorch_res, py_res)
return pytensor_torch_fn, pytorch_res
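
For reference, a minimal usage sketch of the updated helper; the guard above requires every graph input to be a root variable (one with no owner), so derived expressions must stay on the output side:

import numpy as np
import pytensor.tensor as pt
from tests.link.pytorch.test_basic import compare_pytorch_and_py

x = pt.dvector("x")  # root variable: x.owner is None
out = pt.exp(x)      # derived output
compare_pytorch_and_py([x], [out], [np.array([0.0, 1.0])])

# A derived input is rejected:
#   compare_pytorch_and_py([x + 1], [out], [np.array([0.0, 1.0])])  -> ValueError
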
......@@ -231,7 +231,8 @@ def test_alloc_and_empty():
v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, dim0, dim1, 3)
compare_pytorch_and_py(
FunctionGraph([v, dim1], [out]),
[v, dim1],
[out],
[np.array([1, 2, 3]), np.array(7)],
)
......@@ -244,7 +245,8 @@ def test_arange():
out = arange(start, stop, step, dtype="int16")
compare_pytorch_and_py(
FunctionGraph([start, stop, step], [out]),
[start, stop, step],
[out],
[np.array(1), np.array(10), np.array(2)],
)
......@@ -254,16 +256,18 @@ def test_pytorch_Join():
b = matrix("b")
x = ptb.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_pytorch_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_pytorch_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX),
......@@ -271,16 +275,18 @@ def test_pytorch_Join():
)
x = ptb.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_pytorch_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_pytorch_and_py(
x_fg,
[a, b],
[x],
[
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX),
......@@ -309,9 +315,8 @@ def test_eye(dtype):
def test_pytorch_MakeVector():
x = ptb.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [])
compare_pytorch_and_py([], [x], [])
def test_pytorch_ifelse():
......@@ -320,15 +325,13 @@ def test_pytorch_ifelse():
a = scalar("a")
x = ifelse(a < 0.5, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.2], dtype=config.floatX))
compare_pytorch_and_py([a], x, np.array([0.2], dtype=config.floatX))
a = scalar("a")
x = ifelse(a < 0.4, tuple(np.r_[p1_vals, p2_vals]), tuple(np.r_[p2_vals, p1_vals]))
x_fg = FunctionGraph([a], x)
compare_pytorch_and_py(x_fg, np.array([0.5], dtype=config.floatX))
compare_pytorch_and_py([a], x, np.array([0.5], dtype=config.floatX))
def test_pytorch_OpFromGraph():
......@@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph():
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])
compare_pytorch_and_py([x, y, z], [out], [xv, yv, zv])
def test_pytorch_link_references():
......@@ -380,15 +382,13 @@ def test_pytorch_link_references():
def test_pytorch_scipy():
x = vector("a", shape=(3,))
out = expit(x)
f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)])
compare_pytorch_and_py([x], [out], [np.random.rand(3)])
def test_pytorch_softplus():
x = vector("a", shape=(3,))
out = softplus(x)
f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)])
compare_pytorch_and_py([x], [out], [np.random.rand(3)])
def test_ScalarLoop():
......@@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries():
x0 = pt.vector("x0", dtype="float32")
state, done = op(n_steps, x0)
f = FunctionGraph([n_steps, x0], [state, done])
args = [
np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"),
]
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
[n_steps, x0],
[state, done],
args,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-6),
)
......@@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries():
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
*states, done = op(n_steps, x0, x1)
f = FunctionGraph([n_steps, x0, x1], [*states, done])
args = [
np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"),
np.random.rand(7, 3, 1).astype("float32"),
]
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
[n_steps, x0, x1],
[*states, done],
args,
assert_fn=partial(np.testing.assert_allclose, rtol=1e-6),
)
......@@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes):
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(inputs=[i, s], outputs=[g] if n_splits == 1 else g)
compare_pytorch_and_py(g_fg, [values, sizes])
compare_pytorch_and_py([i, s], g, [values, sizes])
......@@ -2,7 +2,6 @@ import numpy as np
import pytest
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -15,8 +14,8 @@ def test_pytorch_BatchedDot():
b = tensor3("b")
b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test])
pytensor_pytorch_fn, _ = compare_pytorch_and_py([a, b], [out], [a_test, b_test])
# A dimension mismatch should raise a TypeError for compatibility
inputs = [a_test[:-1], b_test]
......
......@@ -5,7 +5,6 @@ import pytensor
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarOp, get_scalar_type
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
......@@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle():
a_pt = matrix("a")
x = a_pt.T
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
x = a_pt.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
)
a_pt = tensor(dtype=config.floatX, shape=(None, 1))
x = a_pt.dimshuffle((0,))
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
compare_pytorch_and_py(
[a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]
)
def test_multiple_input_output():
......@@ -38,24 +43,21 @@ def test_multiple_input_output():
y = vector("y")
out = pt.mul(x, y)
fg = FunctionGraph(outputs=[out], clone=False)
compare_pytorch_and_py(fg, [[1.5], [2.5]])
compare_pytorch_and_py([x, y], [out], [[1.5], [2.5]])
x = vector("x")
y = vector("y")
div = pt.int_div(x, y)
pt_sum = pt.add(y, x)
fg = FunctionGraph(outputs=[div, pt_sum], clone=False)
compare_pytorch_and_py(fg, [[1.5], [2.5]])
compare_pytorch_and_py([x, y], [div, pt_sum], [[1.5], [2.5]])
def test_pytorch_elemwise():
x = pt.vector("x")
out = pt.log(1 - x)
fg = FunctionGraph([x], [out])
compare_pytorch_and_py(fg, [[0.9, 0.9]])
compare_pytorch_and_py([x], [out], [[0.9, 0.9]])
@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
......@@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis):
).astype(config.floatX)
x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [test_value])
compare_pytorch_and_py([a_pt], [x], [test_value])
@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
......@@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis):
test_value = np.array([[True, False, True], [False, True, True]])
x = fn(a_pt, axis=axis)
x_fg = FunctionGraph([a_pt], [x])
compare_pytorch_and_py(x_fg, [test_value])
compare_pytorch_and_py([a_pt], [x], [test_value])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
......@@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis):
def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
......@@ -111,9 +110,9 @@ def test_softmax(axis, dtype):
NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
compare_pytorch_and_py([x], [out], [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
compare_pytorch_and_py([x], [out], [test_input])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
......@@ -121,7 +120,6 @@ def test_softmax(axis, dtype):
def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
......@@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype):
NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
compare_pytorch_and_py([x], [out], [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
compare_pytorch_and_py([x], [out], [test_input])
@pytest.mark.parametrize("axis", [None, 0, 1])
......@@ -141,16 +139,14 @@ def test_softmax_grad(axis):
sm = matrix("sm")
sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])
compare_pytorch_and_py([dy, sm], [out], [dy_value, sm_value])
def test_cast():
x = matrix("x", dtype="float32")
out = pt.cast(x, "int32")
fgraph = FunctionGraph([x], [out])
_, [res] = compare_pytorch_and_py(
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
[x], [out], [np.arange(6, dtype="float32").reshape(2, 3)]
)
assert res.dtype == np.int32
......
......@@ -2,7 +2,6 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype):
out = pt.cumprod(a, axis=axis)
else:
out = pt.cumsum(a, axis=axis)
# Create a PyTensor `FunctionGraph`
fgraph = FunctionGraph([a], [out])
# Pass the graph and inputs to the testing function
compare_pytorch_and_py(fgraph, [test_value])
# Pass the inputs and outputs to the testing function
compare_pytorch_and_py([a], [out], [test_value])
# For the second mode of CumOp
out = pt.cumprod(a, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value])
compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("axis, repeats", [(0, (1, 2, 3)), (1, (3, 3)), (None, 3)])
......@@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats):
test_value = np.arange(6, dtype="float64").reshape((3, 2))
out = pt.repeat(a, repeats, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value])
compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("axis", [None, 0, 1])
......@@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis):
)
out = pt.unique(a, axis=axis)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [test_value])
compare_pytorch_and_py([a], [out], [test_value])
@pytest.mark.parametrize("return_inverse", [False, True])
......@@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
return_counts=return_counts,
axis=0,
)
fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
compare_pytorch_and_py(fgraph, [test_value])
compare_pytorch_and_py(
[a], [out[0] if isinstance(out, list) else out], [test_value]
)
import numpy as np
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix, scalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -20,10 +19,12 @@ def test_pytorch_dot():
# 2D * 2D
out = A.dot(A * alpha) + beta * A
fgraph = FunctionGraph([A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test])
compare_pytorch_and_py([A, alpha, beta], [out], [A_test, alpha_test, beta_test])
# 1D * 2D and 1D * 1D
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test])
compare_pytorch_and_py(
[y, x, A, alpha, beta], [out], [y_test, x_test, A_test, alpha_test, beta_test]
)
from collections.abc import Sequence
import numpy as np
import pytest
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -29,13 +26,12 @@ def matrix_test():
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test
out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])
outs = func(x)
def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)
compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
@pytest.mark.parametrize(
......@@ -50,8 +46,8 @@ def test_lin_alg_no_params(func, matrix_test):
def test_qr(mode, matrix_test):
x, test_value = matrix_test
outs = pt_nla.qr(x, mode=mode)
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
compare_pytorch_and_py(out_fg, [test_value])
compare_pytorch_and_py([x], outs, [test_value])
@pytest.mark.parametrize("compute_uv", [True, False])
......@@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test):
x, test_value = matrix_test
out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
compare_pytorch_and_py(out_fg, [test_value])
compare_pytorch_and_py([x], out, [test_value])
def test_pinv():
x = matrix("x")
x_inv = pt_nla.pinv(x)
fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np])
compare_pytorch_and_py([x], [x_inv], [x_np])
@pytest.mark.parametrize("hermitian", [False, True])
......@@ -106,8 +100,7 @@ def test_kron():
y = matrix("y")
z = pt_nla.kron(x, y)
fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np, y_np])
compare_pytorch_and_py([x, y], [z], [x_np, y_np])
......@@ -2,7 +2,6 @@ import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -11,29 +10,27 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
compare_pytorch_and_py([], [x], [])
x = Shape_i(1)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [], must_be_device_array=False)
compare_pytorch_and_py([], [x], [])
def test_pytorch_specify_shape():
in_pt = pt.matrix("in")
x = pt.specify_shape(in_pt, (4, None))
x_fg = FunctionGraph([in_pt], [x])
compare_pytorch_and_py(x_fg, [np.ones((4, 5)).astype(config.floatX)])
compare_pytorch_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
# When used to assert two arrays have similar shapes
in_pt = pt.matrix("in")
shape_pt = pt.matrix("shape")
x = pt.specify_shape(in_pt, shape_pt.shape)
x_fg = FunctionGraph([in_pt, shape_pt], [x])
compare_pytorch_and_py(
x_fg,
[in_pt, shape_pt],
[x],
[np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
)
......@@ -41,21 +38,22 @@ def test_pytorch_specify_shape():
def test_pytorch_Reshape_constant():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
compare_pytorch_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_pytorch_Reshape_dynamic():
a = vector("a")
shape_pt = iscalar("b")
x = reshape(a, (shape_pt, shape_pt))
x_fg = FunctionGraph([a, shape_pt], [x])
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
compare_pytorch_and_py(
[a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_pytorch_and_py(x_fg, [])
compare_pytorch_and_py([], [x], [])
import numpy as np
import pytest
from pytensor.graph import FunctionGraph
from pytensor.tensor import matrix
from pytensor.tensor.sort import argsort, sort
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -12,6 +11,5 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_sort(func, axis):
x = matrix("x", shape=(2, 2), dtype="float64")
out = func(x, axis=axis)
fgraph = FunctionGraph([x], [out])
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
compare_pytorch_and_py(fgraph, [arr])
compare_pytorch_and_py([x], [out], [arr])
......@@ -6,7 +6,6 @@ import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import inc_subtensor, set_subtensor
from pytensor.tensor import subtensor as pt_subtensor
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -19,38 +18,33 @@ def test_pytorch_Subtensor():
out_pt = x_pt[1, 2, 0]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[1:2, 1, :]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
# symbolic index
a_pt = ps.int64("a")
a_np = 1
out_pt = x_pt[a_pt, 2, a_pt:2]
assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
with pytest.raises(
NotImplementedError, match="Negative step sizes are not supported in Pytorch"
):
out_pt = x_pt[::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
def test_pytorch_AdvSubtensor():
......@@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor():
out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], [2, 3]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], 1:]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], :, [3, 4]]
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
out_pt = x_pt[[1, 2], None]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
a_pt = ps.int64("a")
a_np = 2
out_pt = x_pt[[1, a_pt], a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
# boolean indices
out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
a_pt = pt.tensor3("a", dtype="bool")
a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)
out_pt = x_pt[a_pt]
out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_np, a_np])
compare_pytorch_and_py([x_pt, a_pt], [out_pt], [x_np, a_np])
with pytest.raises(
NotImplementedError, match="Negative step sizes are not supported in Pytorch"
):
out_pt = x_pt[[1, 2], ::-1]
out_fg = FunctionGraph([x_pt], [out_pt])
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
compare_pytorch_and_py(out_fg, [x_np])
compare_pytorch_and_py([x_pt], [out_pt], [x_np])
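Both the basic and the advanced subtensor tests pin down the same limitation: the PyTorch linker raises NotImplementedError for negative slice steps because torch tensors themselves reject them. A hedged illustration (the exact torch error message may vary between versions):

import torch

t = torch.arange(5)
try:
    t[::-1]  # negative steps are rejected by torch indexing
except ValueError as exc:
    print(exc)  # e.g. "step must be greater than zero"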
@pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor])
......@@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op):
st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Test different type update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
def inc_subtensor_ignore_duplicates(x, y):
......@@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
)
out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Repeated indices
out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Mixing advanced and basic indexing
if advsubtensor_op is inc_subtensor:
......@@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3])
out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
with expectation:
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Test different dtype update
st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32"))
out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
# Boolean indices
out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0)
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_pt], [out_pt])
compare_pytorch_and_py(out_fg, [x_test])
compare_pytorch_and_py([x_pt], [out_pt], [x_test])
......@@ -63,11 +63,6 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
from tests import unittest_tools as utt
def set_test_value(x, v):
x.tag.test_value = v
return x
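The deleted set_test_value helper existed only to attach numbers through tag.test_value, which this commit stops using. The replacement idiom is to keep the numeric value in an ordinary local variable and pass it explicitly, roughly as below (an illustrative sketch, not a line from this diff):

import numpy as np
from pytensor.tensor import fmatrix

a = fmatrix("a")
a_test = np.zeros((3, 4), dtype="float32")  # explicit test value

# Instead of: a.tag.test_value = a_test
# the tests now hand a_test to the compare helper, e.g.:
# compare_pytorch_and_py([a], [some_output], [a_test])  # some_output is hypothetical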
def test_cpu_contiguous():
a = fmatrix("a")
i = iscalar("i")
......