提交 abedb7fb 作者: ricardoV94 提交者: Ricardo Vieira

Start using new API in tests that don't involve shared updates

上级 78293400
......@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_scan_sit_sot(view):
x0 = pt.scalar("x0", dtype="float64")
xs, _ = scan(
xs = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
return_updates=False,
)
if view:
xs = xs[view]
......@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_scan_mit_sot(view):
x0 = pt.vector("x0", dtype="float64", shape=(3,))
xs, _ = scan(
xs = scan(
lambda xtm3, xtm1: xtm3 + xtm1 + 1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
n_steps=10,
return_updates=False,
)
if view:
xs = xs[view]
......@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def step(xtm3, xtm1, ytm4, ytm2):
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
[xs, ys], _ = scan(
[xs, ys] = scan(
fn=step,
outputs_info=[
{"initial": x0, "taps": [-3, -1]},
{"initial": y0, "taps": [-4, -2]},
],
n_steps=10,
return_updates=False,
)
if view_x:
xs = xs[view_x]
......@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs = pt.vector("x0", dtype="float64", shape=(10,))
ys, _ = scan(
lambda x: pt.exp(x),
outputs_info=[None],
sequences=[xs],
ys = scan(
lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False
)
if view:
ys = ys[view]
......@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan(
[outs, _] = scan(
step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho],
n_steps=10,
return_updates=False,
)
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py(
......@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_while():
xs, _ = scan(
xs = scan(
lambda x: (x + 1, until(x < 10)),
outputs_info=[pt.zeros(())],
n_steps=100,
return_updates=False,
)
compare_jax_and_py([], [xs], [])
......@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res.name = "y_t"
return res
y_scan_pt, _ = scan(
y_scan_pt = scan(
fn=input_step_fn,
outputs_info=[
{
......@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences=[a_pt],
n_steps=10,
name="y_scan",
return_updates=False,
)
y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all"
......@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k = 3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda X, A: A @ X,
non_sequences=[A],
outputs_info=[x0],
n_steps=n_steps,
return_updates=False,
)
x0_val = (
......@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A = pt.matrix("A", shape=(k, k))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda X, A: A @ X,
non_sequences=[A],
sequences=[x],
n_steps=n_steps,
return_updates=False,
)
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
......@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B = pt.matrix("B", shape=(3, 3))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B],
n_steps=10,
return_updates=False,
)
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
......@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return A @ x, x.sum()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
step,
outputs_info=[x0, None],
non_sequences=[A],
n_steps=10,
mode=get_mode("JAX"),
return_updates=False,
)
x0_val = np.arange(3, dtype=config.floatX)
......@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
A = matrix("A")
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
out = scan(
lambda a, b: a @ b,
outputs_info=[A],
non_sequences=[B],
n_steps=2,
return_updates=False,
)
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
......@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None, 3))
out, _ = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
out = scan(
lambda x: inc_without_static_shape(x),
outputs_info=[None],
sequences=[x],
return_updates=False,
)
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
......@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan(
out2 = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
sequences=[x],
return_updates=False,
)
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
......@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
(st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step,
sequences=[C_t, D_t],
outputs_info=[st0, et0, it0, None, None],
non_sequences=[beta, gamma, delta],
return_updates=False,
)
st.name = "S_t"
et.name = "E_t"
......@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter = 100
tol = 1e-7
(*_, A1_hat, norm, _n_steps), _ = scan(
(*_, A1_hat, norm, _n_steps) = scan(
step,
outputs_info=[A, B, C, B, norm, step_num],
non_sequences=[tol],
n_steps=max_iter,
return_updates=False,
)
A1_hat = A1_hat[-1]
......
......@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
(st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step,
sequences=[pt_C, pt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
return_updates=False,
)
st.name = "S_t"
et.name = "E_t"
......@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t.name = "y_t"
return x_t, y_t, pt.fill((10,), z_t)
scan_res, _ = scan(
scan_res = scan(
fn=input_step_fn,
sequences=[
{
......@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps=5,
name="yz_scan",
strict=True,
return_updates=False,
)
test_input_vals = [
......@@ -312,11 +314,12 @@ def test_scan_while():
return previous_power * 2, until(previous_power * 2 > max_value)
max_value = pt.scalar()
values, _ = scan(
values = scan(
power_of_2,
outputs_info=pt.constant(1.0),
non_sequences=max_value,
n_steps=1024,
return_updates=False,
)
test_input_vals = [
......@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x
result, _ = scan(
result = scan(
power_step,
non_sequences=[A],
outputs_info=[pt.ones_like(A), None, None],
n_steps=3,
return_updates=False,
)
test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py([A], result, test_input_vals)
......@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def test_grad_sitsot():
def get_sum_of_grad(inp):
scan_outputs, _updates = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
scan_outputs = scan(
fn=lambda x: x * 2,
outputs_info=[inp],
n_steps=5,
mode="NUMBA",
return_updates=False,
)
return grad(scan_outputs.sum(), inp).sum()
......@@ -362,8 +370,11 @@ def test_mitmots_basic():
def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq
out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
out = scan(
inner_fct,
sequences=seq,
outputs_info={"initial": init_x, "taps": [-2, -1]},
return_updates=False,
)
g_outs = grad(out.sum(), [seq, init_x])
......@@ -383,10 +394,11 @@ def test_mitmots_basic():
def test_inner_graph_optimized():
"""Test that inner graph of Scan is optimized"""
xs = vector("xs")
seq, _ = scan(
seq = scan(
fn=lambda x: log(1 + x),
sequences=[xs],
mode=get_mode("NUMBA"),
return_updates=False,
)
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
......@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2)
return mitsot3, sitsot2
outs, _ = scan(
outs = scan(
fn=step,
sequences=[seq1, seq2],
outputs_info=[
dict(initial=mitsot_init, taps=[-2, -1]),
dict(initial=sitsot_init, taps=[-1]),
],
return_updates=False,
)
rng = np.random.default_rng(474)
......@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y = ytm1 + 1 + ytm2 + a
return z, x, z + x + y, y
[zs, xs, ws, ys], _ = scan(
[zs, xs, ws, ys] = scan(
fn=step,
outputs_info=[
dict(initial=z0, taps=[-3, -1]),
......@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
],
non_sequences=[a],
n_steps=n_steps,
return_updates=False,
)
numba_fn, _ = compare_numba_and_py(
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
......@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class TestScanSITSOTBuffer:
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
x0 = pt.vector(shape=(op_size,), dtype="float64")
xs, _ = pytensor.scan(
xs = pytensor.scan(
fn=lambda xtm1: (xtm1 + 1),
outputs_info=[x0],
n_steps=n_steps - 1, # 1- makes it easier to align/misalign
return_updates=False,
)
if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used
......@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x = pt.vector("init_x", shape=(2,))
n_steps = pt.iscalar("n_steps")
output, _ = scan(
output = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps_val if constant_n_steps else n_steps,
return_updates=False,
)
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)
......
差异被折叠。
......@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A = tensor("A", shape=(3, 3))
x0 = tensor("b", shape=(3, 4))
xs, _ = scan(
xs = scan(
lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
outputs_info=[x0],
non_sequences=[A],
n_steps=10,
return_updates=False,
)
fn_no_opt = function(
......
......@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def test_scan_gradient_core_type():
n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan(
out = scan(
lambda s: s,
sequences=[seq],
n_steps=n_steps,
return_updates=False,
)
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论