提交 b28a3a79 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Only run required rewrites in JAX and PyTorch tests

Only run required rewrites in JAX tests Several tests ended up not testing the backend Op implementations due to constant folding of inputs.
上级 171bb8a4
...@@ -15,7 +15,7 @@ def pytorch_funcify_Reshape(op, node, **kwargs): ...@@ -15,7 +15,7 @@ def pytorch_funcify_Reshape(op, node, **kwargs):
@pytorch_funcify.register(Shape) @pytorch_funcify.register(Shape)
def pytorch_funcify_Shape(op, **kwargs): def pytorch_funcify_Shape(op, **kwargs):
def shape(x): def shape(x):
return x.shape return torch.tensor(x.shape)
return shape return shape
......
...@@ -34,8 +34,13 @@ def pytorch_funcify_Subtensor(op, node, **kwargs): ...@@ -34,8 +34,13 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
@pytorch_funcify.register(MakeSlice) @pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs): def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(*x): def makeslice(start, stop, step):
return slice(x) # Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice return makeslice
......
...@@ -6,13 +6,15 @@ import pytest ...@@ -6,13 +6,15 @@ import pytest
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_mode from pytensor.compile.mode import JAX, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value from pytensor.graph.op import Op, get_test_value
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, matrices, scalar, vector from pytensor.tensor.type import dscalar, matrices, scalar, vector
...@@ -26,9 +28,9 @@ def set_pytensor_flags(): ...@@ -26,9 +28,9 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
jax_mode = get_mode("JAX") jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
py_mode = get_mode("FAST_COMPILE") py_mode = Mode(linker="py", optimizer=None)
def compare_jax_and_py( def compare_jax_and_py(
......
import numpy as np import numpy as np
import pytest import pytest
import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
...@@ -19,9 +20,8 @@ def test_jax_einsum(): ...@@ -19,9 +20,8 @@ def test_jax_einsum():
pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes) pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
) )
out = pt.einsum(subscripts, x_pt, y_pt, z_pt) out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX") fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
compare_jax_and_py(fg, [x, y, z])
np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
@pytest.mark.xfail(raises=NotImplementedError) @pytest.mark.xfail(raises=NotImplementedError)
...@@ -33,6 +33,5 @@ def test_ellipsis_einsum(): ...@@ -33,6 +33,5 @@ def test_ellipsis_einsum():
x_pt = pt.tensor("x", shape=x.shape) x_pt = pt.tensor("x", shape=x.shape)
y_pt = pt.tensor("y", shape=y.shape) y_pt = pt.tensor("y", shape=y.shape)
out = pt.einsum(subscripts, x_pt, y_pt) out = pt.einsum(subscripts, x_pt, y_pt)
f = pytensor.function([x_pt, y_pt], out, mode="JAX") fg = FunctionGraph([x_pt, y_pt], [out])
compare_jax_and_py(fg, [x, y])
np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
import numpy as np import numpy as np
import pytest import pytest
from packaging.version import parse as version_parse
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_extra_ops(): def test_extra_ops():
a = matrix("a") a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.cumsum(a, axis=0) out = pt_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.cumprod(a, axis=1) out = pt_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.diff(a, n=2, axis=1) out = pt_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
out = pt_extra_ops.repeat(a, (3, 3), axis=1) out = pt_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
c = ptb.as_tensor(5) c = ptb.as_tensor(5)
out = pt_extra_ops.fill_diagonal(a, c) out = pt_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.fill_diagonal_offset(a, c, c) out = pt_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
out = pt_extra_ops.Unique(axis=1)(a) out = pt_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
indices = np.arange(np.prod((3, 4))) indices = np.arange(np.prod((3, 4)))
out = pt_extra_ops.unravel_index(indices, (3, 4), order="C") out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
...@@ -63,40 +56,30 @@ def test_extra_ops(): ...@@ -63,40 +56,30 @@ def test_extra_ops():
) )
@pytest.mark.xfail( @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
version_parse(jax.__version__) >= version_parse("0.2.12"), def test_bartlett_dynamic_shape():
reason="JAX Numpy API does not support dynamic shapes", c = tensor(shape=(), dtype=int)
)
def test_extra_ops_dynamic_shapes():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
# This function also cannot take symbolic input.
c = ptb.as_tensor(5)
out = pt_extra_ops.bartlett(c) out = pt_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out]) fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [np.array(5)])
multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
# The inputs are "concrete", yet it still has problems? @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
out = pt_extra_ops.Unique()( def test_ravel_multi_index_dynamic_shape():
ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2))) x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
)
x = tensor(shape=(None,), dtype=int)
y = tensor(shape=(None,), dtype=int)
out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
fgraph = FunctionGraph([], [out]) fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [x_test, y_test])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_unique_nonconcrete(): def test_unique_dynamic_shape():
a = matrix("a") a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = pt_extra_ops.Unique()(a) out = pt_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [a_test])
...@@ -705,7 +705,7 @@ def test_multinomial(): ...@@ -705,7 +705,7 @@ def test_multinomial():
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode="JAX")
samples = g_fn() samples = g_fn()
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
np.testing.assert_allclose( np.testing.assert_allclose(
......
...@@ -32,7 +32,7 @@ def test_scan_sit_sot(view): ...@@ -32,7 +32,7 @@ def test_scan_sit_sot(view):
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs]) fg = FunctionGraph([x0], [xs])
test_input_vals = [np.e] test_input_vals = [np.e]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
...@@ -47,7 +47,7 @@ def test_scan_mit_sot(view): ...@@ -47,7 +47,7 @@ def test_scan_mit_sot(view):
xs = xs[view] xs = xs[view]
fg = FunctionGraph([x0], [xs]) fg = FunctionGraph([x0], [xs])
test_input_vals = [np.full((3,), np.e)] test_input_vals = [np.full((3,), np.e)]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
...@@ -74,7 +74,7 @@ def test_scan_multiple_mit_sot(view_x, view_y): ...@@ -74,7 +74,7 @@ def test_scan_multiple_mit_sot(view_x, view_y):
fg = FunctionGraph([x0, y0], [xs, ys]) fg = FunctionGraph([x0, y0], [xs, ys])
test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)] test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)]) @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
...@@ -283,7 +283,7 @@ def test_scan_SEIR(): ...@@ -283,7 +283,7 @@ def test_scan_SEIR():
gamma_val, gamma_val,
delta_val, delta_val,
] ]
compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
def test_scan_mitsot_with_nonseq(): def test_scan_mitsot_with_nonseq():
...@@ -316,7 +316,7 @@ def test_scan_mitsot_with_nonseq(): ...@@ -316,7 +316,7 @@ def test_scan_mitsot_with_nonseq():
out_fg = FunctionGraph([a_pt], [y_scan_pt]) out_fg = FunctionGraph([a_pt], [y_scan_pt])
test_input_vals = [np.array(10.0).astype(config.floatX)] test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals) compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
@pytest.mark.parametrize("x0_func", [dvector, dmatrix]) @pytest.mark.parametrize("x0_func", [dvector, dmatrix])
...@@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func):
non_sequences=[A], non_sequences=[A],
outputs_info=[x0], outputs_info=[x0],
n_steps=n_steps, n_steps=n_steps,
mode=get_mode("JAX"),
) )
x0_val = ( x0_val = (
...@@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func):
fg = FunctionGraph([x0, A], [xs]) fg = FunctionGraph([x0, A], [xs])
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_seq(): def test_nd_scan_sit_sot_with_seq():
...@@ -362,7 +361,6 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -362,7 +361,6 @@ def test_nd_scan_sit_sot_with_seq():
non_sequences=[A], non_sequences=[A],
sequences=[x], sequences=[x],
n_steps=n_steps, n_steps=n_steps,
mode=get_mode("JAX"),
) )
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
...@@ -370,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -370,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq():
fg = FunctionGraph([x, A], [xs]) fg = FunctionGraph([x, A], [xs])
test_input_vals = [x_val, A_val] test_input_vals = [x_val, A_val]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_mit_sot(): def test_nd_scan_mit_sot():
...@@ -384,7 +382,6 @@ def test_nd_scan_mit_sot(): ...@@ -384,7 +382,6 @@ def test_nd_scan_mit_sot():
outputs_info=[{"initial": x0, "taps": [-3, -1]}], outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B], non_sequences=[A, B],
n_steps=10, n_steps=10,
mode=get_mode("JAX"),
) )
fg = FunctionGraph([x0, A, B], [xs]) fg = FunctionGraph([x0, A, B], [xs])
...@@ -393,7 +390,7 @@ def test_nd_scan_mit_sot(): ...@@ -393,7 +390,7 @@ def test_nd_scan_mit_sot():
B_val = np.eye(3, dtype=config.floatX) B_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val, B_val] test_input_vals = [x0_val, A_val, B_val]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_nd_scan_sit_sot_with_carry(): def test_nd_scan_sit_sot_with_carry():
...@@ -417,7 +414,7 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -417,7 +414,7 @@ def test_nd_scan_sit_sot_with_carry():
A_val = np.eye(3, dtype=config.floatX) A_val = np.eye(3, dtype=config.floatX)
test_input_vals = [x0_val, A_val] test_input_vals = [x0_val, A_val]
compare_jax_and_py(fg, test_input_vals) compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
def test_default_mode_excludes_incompatible_rewrites(): def test_default_mode_excludes_incompatible_rewrites():
...@@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites():
B = matrix("B") B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
fg = FunctionGraph([A, B], [out]) fg = FunctionGraph([A, B], [out])
compare_jax_and_py(fg, [np.eye(3), np.eye(3)]) compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
def test_dynamic_sequence_length(): def test_dynamic_sequence_length():
......
...@@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op): ...@@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
dot_pt = op(x_pt, y_pt) dot_pt = op(x_pt, y_pt)
fgraph = FunctionGraph(inputs, [dot_pt]) fgraph = FunctionGraph(inputs, [dot_pt])
compare_jax_and_py(fgraph, test_values) compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
def test_sparse_dot_non_const_raises(): def test_sparse_dot_non_const_raises():
......
...@@ -74,7 +74,7 @@ def test_arange_of_shape(): ...@@ -74,7 +74,7 @@ def test_arange_of_shape():
x = vector("x") x = vector("x")
out = ptb.arange(1, x.shape[-1], 2) out = ptb.arange(1, x.shape[-1], 2)
fgraph = FunctionGraph([x], [out]) fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [np.zeros((5,))]) compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
def test_arange_nonconcrete(): def test_arange_nonconcrete():
......
...@@ -7,13 +7,15 @@ import pytest ...@@ -7,13 +7,15 @@ import pytest
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import get_mode from pytensor.compile.mode import PYTORCH, Mode
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrices, matrix, scalar, vector from pytensor.tensor.type import matrices, matrix, scalar, vector
...@@ -22,8 +24,13 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector ...@@ -22,8 +24,13 @@ from pytensor.tensor.type import matrices, matrix, scalar, vector
torch = pytest.importorskip("torch") torch = pytest.importorskip("torch")
pytorch_mode = get_mode("PYTORCH") optimizer = RewriteDatabaseQuery(
py_mode = get_mode("FAST_COMPILE") # While we don't have a PyTorch implementation of Blockwise
include=["local_useless_unbatched_blockwise"],
exclude=PYTORCH._optimizer.exclude,
)
pytorch_mode = Mode(linker=PytorchLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)
def compare_pytorch_and_py( def compare_pytorch_and_py(
...@@ -220,7 +227,7 @@ def test_alloc_and_empty(): ...@@ -220,7 +227,7 @@ def test_alloc_and_empty():
assert res.dtype == torch.float32 assert res.dtype == torch.float32
v = vector("v", shape=(3,), dtype="float64") v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, (dim0, dim1, 3)) out = alloc(v, dim0, dim1, 3)
compare_pytorch_and_py( compare_pytorch_and_py(
FunctionGraph([v, dim1], [out]), FunctionGraph([v, dim1], [out]),
[np.array([1, 2, 3]), np.array(7)], [np.array([1, 2, 3]), np.array(7)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论