提交 b28a3a79 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Only run required rewrites in JAX and PyTorch tests

Only run required rewrites in JAX tests Several tests ended up not testing the backend Op implementations due to constant folding of inputs.
上级 171bb8a4
......@@ -15,7 +15,7 @@ def pytorch_funcify_Reshape(op, node, **kwargs):
@pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs):
def shape(x):
return x.shape
return torch.tensor(x.shape)
return shape
......
......@@ -34,8 +34,13 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(*x):
return slice(x)
def makeslice(start, stop, step):
# Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice
......
......@@ -6,13 +6,15 @@ import pytest
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, matrices, scalar, vector
......@@ -26,9 +28,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")
optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)
def compare_jax_and_py(
......
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
......@@ -19,9 +20,8 @@ def test_jax_einsum():
pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
)
out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
compare_jax_and_py(fg, [x, y, z])
@pytest.mark.xfail(raises=NotImplementedError)
......@@ -33,6 +33,5 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt)
f = pytensor.function([x_pt, y_pt], out, mode="JAX")
np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
fg = FunctionGraph([x_pt, y_pt], [out])
compare_jax_and_py(fg, [x, y])
import numpy as np
import pytest
from packaging.version import parse as version_parse
import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_extra_ops():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
c = ptb.as_tensor(5)
out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
......@@ -63,40 +56,30 @@ def test_extra_ops():
)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="JAX Numpy API does not support dynamic shapes",
)
def test_extra_ops_dynamic_shapes():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
# This function also cannot take symbolic input.
c = ptb.as_tensor(5)
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
c = tensor(shape=(), dtype=int)
out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [np.array(5)])
multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
# The inputs are "concrete", yet it still has problems?
out = pt_extra_ops.Unique()(
ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
)
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_ravel_multi_index_dynamic_shape():
x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
compare_jax_and_py(fgraph, [x_test, y_test])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_unique_nonconcrete():
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_unique_dynamic_shape():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
compare_jax_and_py(fgraph, [a_test])
......@@ -705,7 +705,7 @@ def test_multinomial():
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose(
......
......@@ -32,7 +32,7 @@ def test_scan_sit_sot(view):
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
......@@ -47,7 +47,7 @@ def test_scan_mit_sot(view):
xs = xs[view]
fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
......@@ -74,7 +74,7 @@ def test_scan_multiple_mit_sot(view_x, view_y):
fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
......@@ -283,7 +283,7 @@ def test_scan_SEIR():
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals)
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
def test_scan_mitsot_with_nonseq():
......@@ -316,7 +316,7 @@ def test_scan_mitsot_with_nonseq():
out_fg = FunctionGraph([a_pt], [y_scan_pt])
test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("x0_func", [dvector, dmatrix])
......@@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func):
non_sequences=[A],
outputs_info=[x0],
n_steps=n_steps,
mode=get_mode("JAX"),
)
x0_val = (
......@@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func):
fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_seq():
......@@ -362,7 +361,6 @@ def test_nd_scan_sit_sot_with_seq():
non_sequences=[A],
sequences=[x],
n_steps=n_steps,
mode=get_mode("JAX"),
)
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
......@@ -370,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq():
fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_mit_sot():
......@@ -384,7 +382,6 @@ def test_nd_scan_mit_sot():
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B],
n_steps=10,
mode=get_mode("JAX"),
)
fg = FunctionGraph([x0, A, B], [xs])
......@@ -393,7 +390,7 @@ def test_nd_scan_mit_sot():
B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_carry():
......@@ -417,7 +414,7 @@ def test_nd_scan_sit_sot_with_carry():
A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals)
compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_default_mode_excludes_incompatible_rewrites():
......@@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites():
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
def test_dynamic_sequence_length():
......
......@@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt])
compare_jax_and_py(fgraph, test_values)
compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
def test_sparse_dot_non_const_raises():
......
......@@ -74,7 +74,7 @@ def test_arange_of_shape():
x = vector("x")
out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [np.zeros((5,))])
compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
def test_arange_nonconcrete():
......
......@@ -7,13 +7,15 @@ import pytest
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.mode import PYTORCH, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrices, matrix, scalar, vector
......@@ -22,8 +24,13 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector
torch = pytest.importorskip("torch")
pytorch_mode = get_mode("PYTORCH")
py_mode = get_mode("FAST_COMPILE")
optimizer = RewriteDatabaseQuery(
# While we don't have a PyTorch implementation of Blockwise
include=["local_useless_unbatched_blockwise"],
exclude=PYTORCH._optimizer.exclude,
)
pytorch_mode = Mode(linker=PytorchLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)
def compare_pytorch_and_py(
......@@ -220,7 +227,7 @@ def test_alloc_and_empty():
assert res.dtype == torch.float32
v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, (dim0, dim1, 3))
out = alloc(v, dim0, dim1, 3)
compare_pytorch_and_py(
FunctionGraph([v, dim1], [out]),
[np.array([1, 2, 3]), np.array(7)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论