提交 abedb7fb authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Start using new API in tests that don't involve shared updates

上级 78293400
......@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_scan_sit_sot(view):
x0 = pt.scalar("x0", dtype="float64")
xs, _ = scan(
xs = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
return_updates=False,
)
if view:
xs = xs[view]
......@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_scan_mit_sot(view):
x0 = pt.vector("x0", dtype="float64", shape=(3,))
xs, _ = scan(
xs = scan(
lambda xtm3, xtm1: xtm3 + xtm1 + 1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
n_steps=10,
return_updates=False,
)
if view:
xs = xs[view]
......@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def step(xtm3, xtm1, ytm4, ytm2):
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
[xs, ys], _ = scan(
[xs, ys] = scan(
fn=step,
outputs_info=[
{"initial": x0, "taps": [-3, -1]},
{"initial": y0, "taps": [-4, -2]},
],
n_steps=10,
return_updates=False,
)
if view_x:
xs = xs[view_x]
......@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs = pt.vector("x0", dtype="float64", shape=(10,))
ys, _ = scan(
lambda x: pt.exp(x),
outputs_info=[None],
sequences=[xs],
ys = scan(
lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False
)
if view:
ys = ys[view]
......@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan(
[outs, _] = scan(
step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho],
n_steps=10,
return_updates=False,
)
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py(
......@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail(raises=NotImplementedError)
def test_scan_while():
xs, _ = scan(
xs = scan(
lambda x: (x + 1, until(x < 10)),
outputs_info=[pt.zeros(())],
n_steps=100,
return_updates=False,
)
compare_jax_and_py([], [xs], [])
......@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res.name = "y_t"
return res
y_scan_pt, _ = scan(
y_scan_pt = scan(
fn=input_step_fn,
outputs_info=[
{
......@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences=[a_pt],
n_steps=10,
name="y_scan",
return_updates=False,
)
y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all"
......@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k = 3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda X, A: A @ X,
non_sequences=[A],
outputs_info=[x0],
n_steps=n_steps,
return_updates=False,
)
x0_val = (
......@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A = pt.matrix("A", shape=(k, k))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda X, A: A @ X,
non_sequences=[A],
sequences=[x],
n_steps=n_steps,
return_updates=False,
)
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
......@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B = pt.matrix("B", shape=(3, 3))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B],
n_steps=10,
return_updates=False,
)
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
......@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return A @ x, x.sum()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan(
xs = scan(
step,
outputs_info=[x0, None],
non_sequences=[A],
n_steps=10,
mode=get_mode("JAX"),
return_updates=False,
)
x0_val = np.arange(3, dtype=config.floatX)
......@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
A = matrix("A")
B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
out = scan(
lambda a, b: a @ b,
outputs_info=[A],
non_sequences=[B],
n_steps=2,
return_updates=False,
)
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
......@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None, 3))
out, _ = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
out = scan(
lambda x: inc_without_static_shape(x),
outputs_info=[None],
sequences=[x],
return_updates=False,
)
f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
......@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan(
out2 = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None],
sequences=[x],
return_updates=False,
)
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
......@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
(st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step,
sequences=[C_t, D_t],
outputs_info=[st0, et0, it0, None, None],
non_sequences=[beta, gamma, delta],
return_updates=False,
)
st.name = "S_t"
et.name = "E_t"
......@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter = 100
tol = 1e-7
(*_, A1_hat, norm, _n_steps), _ = scan(
(*_, A1_hat, norm, _n_steps) = scan(
step,
outputs_info=[A, B, C, B, norm, step_num],
non_sequences=[tol],
n_steps=max_iter,
return_updates=False,
)
A1_hat = A1_hat[-1]
......
......@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
(st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step,
sequences=[pt_C, pt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
return_updates=False,
)
st.name = "S_t"
et.name = "E_t"
......@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t.name = "y_t"
return x_t, y_t, pt.fill((10,), z_t)
scan_res, _ = scan(
scan_res = scan(
fn=input_step_fn,
sequences=[
{
......@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps=5,
name="yz_scan",
strict=True,
return_updates=False,
)
test_input_vals = [
......@@ -312,11 +314,12 @@ def test_scan_while():
return previous_power * 2, until(previous_power * 2 > max_value)
max_value = pt.scalar()
values, _ = scan(
values = scan(
power_of_2,
outputs_info=pt.constant(1.0),
non_sequences=max_value,
n_steps=1024,
return_updates=False,
)
test_input_vals = [
......@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x
result, _ = scan(
result = scan(
power_step,
non_sequences=[A],
outputs_info=[pt.ones_like(A), None, None],
n_steps=3,
return_updates=False,
)
test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py([A], result, test_input_vals)
......@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def test_grad_sitsot():
def get_sum_of_grad(inp):
scan_outputs, _updates = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA"
scan_outputs = scan(
fn=lambda x: x * 2,
outputs_info=[inp],
n_steps=5,
mode="NUMBA",
return_updates=False,
)
return grad(scan_outputs.sum(), inp).sum()
......@@ -362,8 +370,11 @@ def test_mitmots_basic():
def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq
out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
out = scan(
inner_fct,
sequences=seq,
outputs_info={"initial": init_x, "taps": [-2, -1]},
return_updates=False,
)
g_outs = grad(out.sum(), [seq, init_x])
......@@ -383,10 +394,11 @@ def test_mitmots_basic():
def test_inner_graph_optimized():
"""Test that inner graph of Scan is optimized"""
xs = vector("xs")
seq, _ = scan(
seq = scan(
fn=lambda x: log(1 + x),
sequences=[xs],
mode=get_mode("NUMBA"),
return_updates=False,
)
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
......@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2)
return mitsot3, sitsot2
outs, _ = scan(
outs = scan(
fn=step,
sequences=[seq1, seq2],
outputs_info=[
dict(initial=mitsot_init, taps=[-2, -1]),
dict(initial=sitsot_init, taps=[-1]),
],
return_updates=False,
)
rng = np.random.default_rng(474)
......@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y = ytm1 + 1 + ytm2 + a
return z, x, z + x + y, y
[zs, xs, ws, ys], _ = scan(
[zs, xs, ws, ys] = scan(
fn=step,
outputs_info=[
dict(initial=z0, taps=[-3, -1]),
......@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
],
non_sequences=[a],
n_steps=n_steps,
return_updates=False,
)
numba_fn, _ = compare_numba_and_py(
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
......@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class TestScanSITSOTBuffer:
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
x0 = pt.vector(shape=(op_size,), dtype="float64")
xs, _ = pytensor.scan(
xs = pytensor.scan(
fn=lambda xtm1: (xtm1 + 1),
outputs_info=[x0],
n_steps=n_steps - 1, # 1- makes it easier to align/misalign
return_updates=False,
)
if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used
......@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x = pt.vector("init_x", shape=(2,))
n_steps = pt.iscalar("n_steps")
output, _ = scan(
output = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[],
n_steps=n_steps_val if constant_n_steps else n_steps,
return_updates=False,
)
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)
......
......@@ -294,7 +294,7 @@ class TestScan:
def test_clone(self):
a = vector()
output, _ = scan(fn=lambda x: x**2, sequences=[a])
output = scan(fn=lambda x: x**2, sequences=[a], return_updates=False)
scan_op = output.owner.op
assert isinstance(scan_op, Scan)
......@@ -320,7 +320,7 @@ class TestScan:
state = scalar("state")
n_steps = iscalar("nsteps")
output, updates = scan(
output = scan(
f_pow2,
[],
state,
......@@ -328,10 +328,9 @@ class TestScan:
n_steps=n_steps,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
_my_f = function(
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
_my_f = function([state, n_steps], output, allow_input_downcast=True)
origdir = Path.cwd()
tmpdir = None
......@@ -368,11 +367,9 @@ class TestScan:
state = scalar("state")
n_steps = iscalar("nsteps")
output, updates = scan(f_pow2, [], state, [], n_steps=n_steps)
output = scan(f_pow2, [], state, [], n_steps=n_steps, return_updates=False)
f = function(
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
f = function([state, n_steps], output, allow_input_downcast=True)
scan_node = [
node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
......@@ -410,7 +407,9 @@ class TestScan:
return 2 * x_tm1
n_steps = iscalar("n_steps")
values, _ = scan(f_pow, outputs_info=(x_init,), n_steps=n_steps)
values = scan(
f_pow, outputs_info=(x_init,), n_steps=n_steps, return_updates=False
)
update_fn = function((x_init, n_steps), values, mode=mode)
......@@ -443,7 +442,9 @@ class TestScan:
return 2 * x_i
with config.change_flags(mode=mode):
values, _ = scan(inner_fn, outputs_info=(x_init,), sequences=x)
values = scan(
inner_fn, outputs_info=(x_init,), sequences=x, return_updates=False
)
values_fn = function((x_init, x), values)
assert isinstance(values.owner.inputs[0].owner.op, Scan)
......@@ -474,7 +475,7 @@ class TestScan:
return 2 * x_i
with config.change_flags(mode=mode):
values, _ = scan(inner_fn, sequences=x)
values = scan(inner_fn, sequences=x, return_updates=False)
values_fn = function((x,), values)
assert isinstance(values.owner.op, Scan)
......@@ -491,7 +492,9 @@ class TestScan:
# Compile the PyTensor function
n_steps = 2
inp = matrix()
broadcasted_inp, _ = scan(lambda x: x, non_sequences=[inp], n_steps=n_steps)
broadcasted_inp = scan(
lambda x: x, non_sequences=[inp], n_steps=n_steps, return_updates=False
)
out = broadcasted_inp.sum()
gr = grad(out, inp)
fun = function([inp], [broadcasted_inp, gr])
......@@ -519,7 +522,7 @@ class TestScan:
W_in = scalar("win")
W = scalar("w")
output, updates = scan(
output = scan(
f_rnn,
u,
x0,
......@@ -527,11 +530,10 @@ class TestScan:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f2 = function(
[u, x0, W_in, W], output, updates=updates, allow_input_downcast=True
)
f2 = function([u, x0, W_in, W], output, allow_input_downcast=True)
# get random initial values
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(4,))
......@@ -561,7 +563,7 @@ class TestScan:
def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W):
return u_t * tmp_W_in + x_tm1 * tmp_W
output, updates = scan(
output = scan(
f_rnn_shared,
u,
x0,
......@@ -569,8 +571,9 @@ class TestScan:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f3 = function([u, x0], output, updates=updates, allow_input_downcast=True)
f3 = function([u, x0], output, allow_input_downcast=True)
# get random initial values
v_u = rng.uniform(-5.0, 5.0, size=(4,))
......@@ -688,11 +691,14 @@ class TestScan:
# this test refers to a bug reported by Nicolas
# Boulanger-Lewandowski June 6th
x = dvector()
y, updates = scan(
lambda x: [x], sequences=dict(input=x, taps=[-1]), outputs_info=[None]
y = scan(
lambda x: [x],
sequences=dict(input=x, taps=[-1]),
outputs_info=[None],
return_updates=False,
)
inp = np.arange(5).astype("float64")
rval = function([x], y, updates=updates)(inp)
rval = function([x], y)(inp)
assert np.all(rval == inp[:-1])
def test_output_only(self):
......@@ -701,11 +707,18 @@ class TestScan:
u = vector("u")
outputs, updates = scan(
f_rnn, u, [], [], n_steps=None, truncate_gradient=-1, go_backwards=False
outputs = scan(
f_rnn,
u,
[],
[],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f2 = function([u], outputs, updates=updates, allow_input_downcast=True)
f2 = function([u], outputs, allow_input_downcast=True)
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(5,))
......@@ -722,7 +735,7 @@ class TestScan:
W_in = scalar("win")
W = scalar("w")
output, updates = scan(
output = scan(
f_rnn,
u,
x0,
......@@ -730,11 +743,10 @@ class TestScan:
n_steps=None,
truncate_gradient=-1,
go_backwards=True,
return_updates=False,
)
f2 = function(
[u, x0, W_in, W], output, updates=updates, allow_input_downcast=True
)
f2 = function([u, x0, W_in, W], output, allow_input_downcast=True)
# get random initial values
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(4,))
......@@ -797,8 +809,8 @@ class TestScan:
def test_hash(self):
x = vector()
y = vector()
scan1, _updates = scan(lambda _x: _x + 1, x)
scan2, _updates = scan(lambda _x: _x + 1, y)
scan1 = scan(lambda _x: _x + 1, x, return_updates=False)
scan2 = scan(lambda _x: _x + 1, y, return_updates=False)
assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op)
......@@ -809,9 +821,24 @@ class TestScan:
y = vector("y")
c = scalar("c")
scan_a, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_b, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c])
scan_c, _ = scan(lambda x, y, c: x + y + c, sequences=[y, x], non_sequences=[c])
scan_a = scan(
lambda x, y, c: x + y + c,
sequences=[x, y],
non_sequences=[c],
return_updates=False,
)
scan_b = scan(
lambda x, y, c: x + y + c,
sequences=[x, y],
non_sequences=[c],
return_updates=False,
)
scan_c = scan(
lambda x, y, c: x + y + c,
sequences=[y, x],
non_sequences=[c],
return_updates=False,
)
assert scan_b is not scan_a
assert scan_c is not scan_a
......@@ -1006,7 +1033,7 @@ class TestScan:
def lambda_fn(x_t):
return x_t + 1, until(x_t > 3)
o, _ = scan(lambda_fn, x)
o = scan(lambda_fn, x, return_updates=False)
f = function([x], o)
vx = np.zeros((50,), dtype=config.floatX)
vx[23] = 4
......@@ -1019,7 +1046,7 @@ class TestScan:
def lambda_fn(x_t):
return x_t + 1, until(x_t > 3)
o, _ = scan(lambda_fn, x)
o = scan(lambda_fn, x, return_updates=False)
f = function([x], o.shape[0], mode=mode_with_opt)
vx = np.zeros((50,), dtype=config.floatX)
......@@ -1029,11 +1056,12 @@ class TestScan:
def test_infer_shape_nsteps_smaller_seq_length(self):
x = vector("x")
[o1, o2], _ = scan(
[o1, o2] = scan(
lambda x, y: (x + 1, y + x),
sequences=x,
outputs_info=[None, x[0]],
n_steps=20,
return_updates=False,
)
f = function([x], [o1.shape[0], o2.shape[0]], mode=mode_with_opt)
......@@ -1071,17 +1099,18 @@ class TestScan:
mode = MonitorMode(post_func=detect_large_outputs)
# Symbolic description of the result
result, updates = scan(
result = scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=pt.ones_like(A),
non_sequences=A,
n_steps=k,
mode=mode,
return_updates=False,
)
final_result = result[-1]
f = function(inputs=[A, k], outputs=final_result, updates=updates)
f = function(inputs=[A, k], outputs=final_result)
f(np.asarray([2, 3, 0.1, 0, 1], dtype=config.floatX), 4)
# There should be 3 outputs greater than 10: prior_result[0] at step 3,
......@@ -1103,10 +1132,11 @@ class TestScan:
y.name = "y"
gy = grad(y, x)
gy.name = "gy"
hy, _updates = scan(
hy = scan(
lambda i, gy, x: grad(gy[i] * fc2, x),
sequences=pt.arange(gy.shape[0]),
non_sequences=[gy, x],
return_updates=False,
)
f = function([x, A], hy, allow_input_downcast=True)
......@@ -1123,8 +1153,13 @@ class TestScan:
def test_sequence_is_scan(self, mode):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
x0 = scalar("x0")
scan_1, _ = scan(lambda x: x + 1, outputs_info={"initial": x0}, n_steps=10)
scan_2, _ = scan(lambda x: x + 1, sequences=[scan_1])
scan_1 = scan(
lambda x: x + 1,
outputs_info={"initial": x0},
n_steps=10,
return_updates=False,
)
scan_2 = scan(lambda x: x + 1, sequences=[scan_1], return_updates=False)
with config.change_flags(mode=mode):
scan_2_fn = function([x0], scan_2)
......@@ -1185,7 +1220,7 @@ class TestScan:
def test_blockwise_scan(self):
x = pt.tensor("x", shape=())
out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)
out = scan(lambda x: x + 1, outputs_info=[x], n_steps=10, return_updates=False)
x_vec = pt.tensor("x_vec", shape=(None,))
out_vec = vectorize_graph(out, {x: x_vec})
......@@ -1203,13 +1238,14 @@ class TestScan:
a0 = shared(np.arange(2))
b0 = shared(np.arange(2))
(a, _b), _ = scan(
(a, _b) = scan(
fn,
outputs_info=[
{"initial": a0, "taps": [-2, -1]},
{"initial": b0, "taps": [-2, -1]},
],
n_steps=2,
return_updates=False,
)
grad(a[-1], a0)
......@@ -1241,8 +1277,11 @@ class TestScan:
state_next = state_old * 2 + state_current + seq
return state_next
out, _ = scan(
inner_fct, sequences=seq, outputs_info={"initial": x, "taps": [-2, -1]}
out = scan(
inner_fct,
sequences=seq,
outputs_info={"initial": x, "taps": [-2, -1]},
return_updates=False,
)
g_out = grad(out.sum(), [seq, x])
......@@ -1302,12 +1341,13 @@ class TestScan:
new_y = pt.switch(cond, y, sigmoid(x))
return new_cond, new_x, new_y
values, _ = scan(
values = scan(
inner_fn,
outputs_info=[c, x, y],
n_steps=10,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
gX, gY = grad(values[1].sum(), [x, y])
f = function([c, x, y], [gX, gY], allow_input_downcast=True)
......@@ -1762,11 +1802,12 @@ class TestScan:
outputs_info = [None, dict(initial=out_init, taps=[-3])]
scan_outputs, _ = scan(
scan_outputs = scan(
fn=inner_fct,
sequences=seq,
outputs_info=outputs_info,
non_sequences=non_seq,
return_updates=False,
)
# Attempt to take various gradients
......@@ -1834,7 +1875,9 @@ class TestScan:
dict(initial=out_init[3], taps=[-2, -1]),
]
scan_outputs, _ = scan(fn=inner_fct, outputs_info=outputs_info, n_steps=10)
scan_outputs = scan(
fn=inner_fct, outputs_info=outputs_info, n_steps=10, return_updates=False
)
grad(scan_outputs[0].sum(), out_init[1])
......@@ -1857,11 +1900,12 @@ class TestScan:
x = scalar("x")
_max_coefficients_supported = 1000
full_range = pt.arange(_max_coefficients_supported)
components, _updates = scan(
components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power),
outputs_info=None,
sequences=[c, full_range],
non_sequences=x,
return_updates=False,
)
P = components.sum()
dP = grad(P, x)
......@@ -1877,11 +1921,12 @@ class TestScan:
x = scalar("x")
_max_coefficients_supported = 1000
full_range = pt.arange(_max_coefficients_supported)
components, _updates = scan(
components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power),
outputs_info=None,
sequences=[c, full_range],
non_sequences=x,
return_updates=False,
)
P = components.sum()
dP = grad(P, x).sum()
......@@ -1968,8 +2013,13 @@ class TestScan:
_W = specify_shape(W, v_W.shape)
_W.name = "_W"
o, _ = scan(
rnn_fn, sequences=_u, outputs_info=_h0, non_sequences=_W, name="rnn_fn"
o = scan(
rnn_fn,
sequences=_u,
outputs_info=_h0,
non_sequences=_W,
name="rnn_fn",
return_updates=False,
)
o = o[-1]
eu = matrix("eu")
......@@ -1983,25 +2033,28 @@ class TestScan:
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
)
n2o_u, _ = scan(
n2o_u = scan(
lambda i, o, u, h0, W, eu: (grad(o[i], u) * eu).sum(),
sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eu],
name="jacobU",
return_updates=False,
)
n2o_h0, _ = scan(
n2o_h0 = scan(
lambda i, o, u, h0, W, eh0: (grad(o[i], h0) * eh0).sum(),
sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eh0],
name="jacobh",
return_updates=False,
)
n2o_W, _ = scan(
n2o_W = scan(
lambda i, o, u, h0, W, eW: (grad(o[i], W) * eW).sum(),
sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eW],
name="jacobW",
return_updates=False,
)
fn_test = function(
......@@ -2132,10 +2185,11 @@ class TestScan:
transfer = sigmoid
hidden_rec, _ = scan(
hidden_rec = scan(
lambda x, h_tm1: transfer(dot(h_tm1, W2) + x),
sequences=hidden,
outputs_info=[pt.zeros_like(hidden[0])],
return_updates=False,
)
hidden_rec.reshape(
......@@ -2168,12 +2222,13 @@ class TestScan:
def step(s, xtm2, xtm1, z):
return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2)
xs, _ = scan(
xs = scan(
step,
sequences=[seq],
outputs_info=[{"initial": x0, "taps": (-2, -1)}],
non_sequences=[z],
n_steps=2,
return_updates=False,
)
last_x = xs[-1]
......@@ -2254,11 +2309,12 @@ class TestScan:
raise ValueError(f"Invalid case: {case}")
seq = vector("seq")
xs, _ = scan(
xs = scan(
step,
sequences=[seq],
non_sequences=non_sequences,
strict=strict,
return_updates=False,
)
x0 = xs[0]
......@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode):
def scan_fn():
return myop(pt.as_tensor(1))
res, _ = scan(scan_fn, n_steps=4, mode=mode)
res = scan(scan_fn, n_steps=4, mode=mode, return_updates=False)
res_fn = function([], res, mode=mode)
......@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark):
py_res = f_py()
s_r = pt.as_tensor_variable(r, dtype=config.floatX)
s_y, updates = scan(
s_y = scan(
fn=lambda ri, rii, M: ri + M * rii,
sequences=[s_r[1:]],
non_sequences=[pt.as_tensor_variable(M, dtype=config.floatX)],
outputs_info=s_r[0],
mode=Mode(linker="cvm", optimizer="fast_run"),
return_updates=False,
)
assert not updates
f_cvm = function([], s_y, mode="FAST_RUN")
f_cvm.trust_input = True
......@@ -2357,9 +2413,7 @@ def test_compute_test_values():
y = shared(np.arange(3, dtype=config.floatX), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x, y])
assert not updates
z = scan(fn=lambda u, v: u + v, sequences=[x, y], return_updates=False)
z_grad = grad(z.sum(), x)
......@@ -2368,9 +2422,9 @@ def test_compute_test_values():
# Use `non_sequences` this time
y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x], non_sequences=[y])
assert not updates
z = scan(
fn=lambda u, v: u + v, sequences=[x], non_sequences=[y], return_updates=False
)
z_grad = grad(z.sum(), x)
......@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad():
def loss_ti(ti, sum_ti, mi, W):
return W.sum().sum().sum() + sum_ti
result_ti, _ = scan(
result_ti = scan(
fn=loss_ti,
outputs_info=outputs_ti,
sequences=pt.arange(W.shape[1], dtype="int32"),
non_sequences=[mi, W],
return_updates=False,
)
lossmi = result_ti[-1]
return sum_mi + lossmi
result_mi, _ = scan(
result_mi = scan(
fn=loss_mi,
outputs_info=outputs_mi,
sequences=pt.arange(W.shape[0], dtype="int32"),
non_sequences=[W],
return_updates=False,
)
loss = result_mi[-1]
......@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast():
name="w",
)
outputs, _ = scan(
outputs = scan(
lambda i, h, w: (dot(h[i], w), i),
outputs_info=[None, 0],
non_sequences=[h, w],
n_steps=3,
return_updates=False,
)
grad(outputs[0].sum(), w)
......@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast():
def test_constant_folding_n_steps():
# The following code used to crash at revision 2060b8f, in the constant
# folding optimization step.
res, _ = scan(
res = scan(
lambda x: x * 2,
outputs_info=pt.ones(()),
# The constant `n_steps` was causing the crash.
n_steps=10,
return_updates=False,
)
with config.change_flags(on_opt_error="raise"):
function([], res)()
......@@ -2478,10 +2536,11 @@ def test_outputs_taps_check():
def test_inconsistent_broadcast_error():
x = tensor3()
initial_x = pt.constant(np.zeros((1, 10)))
y, _updates = scan(
y = scan(
fn=lambda x, prev_x: x + prev_x,
sequences=x,
outputs_info=[dict(initial=initial_x)],
return_updates=False,
)
# Error, because the broadcast patterns are inconsistent.
with pytest.raises(TypeError):
......@@ -2509,10 +2568,11 @@ class TestGradUntil:
self.numpy_gradient = 2 * np.concatenate([self.seq[:7], z], axis=0)
def test_grad_until(self):
r, _ = scan(
r = scan(
lambda x, u: (x * x, until(x > u)),
sequences=self.x,
non_sequences=[self.threshold],
return_updates=False,
)
g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g])
......@@ -2528,10 +2588,11 @@ class TestGradUntil:
X = matrix(name="x")
arr = tile_array(self.seq)
r, _ = scan(
r = scan(
lambda x, u: (x * x, until(pt_all(x > u))),
sequences=X,
non_sequences=[self.threshold],
return_updates=False,
)
g = grad(r.sum(), X)
f = function([X, self.threshold], [r, g])
......@@ -2542,11 +2603,12 @@ class TestGradUntil:
def test_grad_until_and_truncate(self):
n = 3
r, _ = scan(
r = scan(
lambda x, u: (x * x, until(x > u)),
sequences=self.x,
non_sequences=[self.threshold],
truncate_gradient=n,
return_updates=False,
)
g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g])
......@@ -2558,11 +2620,12 @@ class TestGradUntil:
def test_grad_until_and_truncate_sequence_taps(self):
n = 3
r, _ = scan(
r = scan(
lambda x, y, u: (x * y, until(y > u)),
sequences=dict(input=self.x, taps=[-2, 0]),
non_sequences=[self.threshold],
truncate_gradient=n,
return_updates=False,
)
g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g])
......@@ -2581,8 +2644,12 @@ def test_mintap_onestep():
new_sum = prev_sum + seq_t
return new_sum
rs, _updates = scan(
fn=accum, sequences={"input": seq, "taps": [2]}, outputs_info=0, n_steps=1
rs = scan(
fn=accum,
sequences={"input": seq, "taps": [2]},
outputs_info=0,
n_steps=1,
return_updates=False,
)
f = function(inputs=[seq], outputs=rs)
......@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length():
def test_profile_info():
from pytensor.scan.utils import ScanProfileStats
z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=True)
z = scan(
fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile=True,
return_updates=False,
)
assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn
......@@ -2676,8 +2748,11 @@ def test_profile_info():
assert fn.profile.name == "scan_fn"
# Set the `ScanProfileStats` name
z, _updates = scan(
fn=lambda u: u + 1, sequences=[pt.arange(10)], profile="profile_name"
z = scan(
fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile="profile_name",
return_updates=False,
)
assert isinstance(z.owner.op, Scan)
......@@ -2688,7 +2763,12 @@ def test_profile_info():
# Use an existing profile object
profile = fn.profile
z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=profile)
z = scan(
fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile=profile,
return_updates=False,
)
assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn
......@@ -2819,7 +2899,7 @@ class TestExamples:
y_tm1 + dot(x_tm1, W_out),
]
outputs, updates = scan(
outputs = scan(
f_rnn_cmpl,
[u1, u2],
[None, None, x0, dict(initial=y0, taps=[-1, -3])],
......@@ -2827,11 +2907,10 @@ class TestExamples:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f4 = function(
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
# compute the values in numpy
v_x = np.zeros((3, 2), dtype=config.floatX)
......@@ -2857,8 +2936,12 @@ class TestExamples:
def scanStep(prev, seq, f1):
return prev + f1 * seq
scanned, _ = scan(
fn=scanStep, sequences=[seq], outputs_info=[to_scan], non_sequences=[f1]
scanned = scan(
fn=scanStep,
sequences=[seq],
outputs_info=[to_scan],
non_sequences=[f1],
return_updates=False,
)
function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True)
......@@ -2879,8 +2962,12 @@ class TestExamples:
expr = dot(h_tm1, W) + x_t
return expr
expr, _ = scan(
fn=one_step, sequences=[inpt], outputs_info=[initial], non_sequences=[W]
expr = scan(
fn=one_step,
sequences=[inpt],
outputs_info=[initial],
non_sequences=[W],
return_updates=False,
)
v1 = shared(np.ones(5, dtype=config.floatX))
......@@ -2917,11 +3004,12 @@ class TestExamples:
x = scalar()
seq = vector()
outputs_info = [x, pt.zeros_like(x)]
(out1, out2), _updates = scan(
(out1, out2) = scan(
lambda a, b, c: (a + b, b + c),
sequences=seq,
outputs_info=outputs_info,
mode=mode,
return_updates=False,
)
# Obtain a reference to the scan outputs before the subtensor and
......@@ -2956,8 +3044,11 @@ class TestExamples:
x = dcol()
seq = dcol()
outputs_info = [x, pt.zeros_like(x)]
(out1, out2), _updates = scan(
lambda a, b, c: (a + b, a + c), sequences=seq, outputs_info=outputs_info
(out1, out2) = scan(
lambda a, b, c: (a + b, a + c),
sequences=seq,
outputs_info=outputs_info,
return_updates=False,
)
# Obtain a reference to the scan outputs before the subtensor and
......@@ -3096,7 +3187,9 @@ class TestExamples:
seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
results = scan(
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
)
f = function([seq], results[1])
assert np.all(exp_out == f(inp))
......@@ -3119,7 +3212,9 @@ class TestExamples:
seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
results = scan(
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
)
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
updates = {sharedvar: results[0][-1:]}
......@@ -3164,7 +3259,7 @@ class TestExamples:
init = matrix()
outputs_info = [None, None, None, None, dict(initial=init, taps=[-3, -2, -1])]
out, _ = scan(inner_fn, outputs_info=outputs_info, n_steps=3)
out = scan(inner_fn, outputs_info=outputs_info, n_steps=3, return_updates=False)
fct = function([init], out)
# Compare obtained outputs with expected outputs
......@@ -3197,21 +3292,23 @@ class TestExamples:
def loss_inner(sum_inner, W):
return sum_inner + (W**2).sum()
result_inner, _ = scan(
result_inner = scan(
fn=loss_inner,
outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)),
non_sequences=[W],
n_steps=1,
return_updates=False,
)
return sum_outer + result_inner[-1]
# Also test return_list for that case.
result_outer, _ = scan(
result_outer = scan(
fn=loss_outer,
outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)),
non_sequences=[W],
n_steps=n_steps,
return_list=True,
return_updates=False,
)
cost = result_outer[0][-1]
......@@ -3230,7 +3327,9 @@ class TestExamples:
x0 = vector("X")
y0 = vector("y0")
z0 = vector("Z")
[x, y, z], _ = scan(inner_fn, outputs_info=[x0, y0, z0], n_steps=10)
[x, y, z] = scan(
inner_fn, outputs_info=[x0, y0, z0], n_steps=10, return_updates=False
)
cost = (x + y + z).sum()
grad(cost, x0) # defined
......@@ -3247,7 +3346,12 @@ class TestExamples:
m = matrix("m")
u0 = pt.zeros((7,))
[_u, m2], _ = scan(lambda _, u: [u, v], sequences=m, outputs_info=[u0, None])
[_u, m2] = scan(
lambda _, u: [u, v],
sequences=m,
outputs_info=[u0, None],
return_updates=False,
)
# This used to raise an exception with older versions because for a
# disconnected gradient a non disconnected type was returned
grad((m * m2).sum(), v)
......@@ -3257,8 +3361,11 @@ class TestExamples:
m = matrix("m")
u0 = pt.zeros((7,))
[_u, m2], _ = scan(
lambda x, u: [x + u, u + v], sequences=m, outputs_info=[u0, None]
[_u, m2] = scan(
lambda x, u: [x + u, u + v],
sequences=m,
outputs_info=[u0, None],
return_updates=False,
)
# This used to raise an exception with older versions because
# scan could not detect the connection between `m2` and `x`
......@@ -3278,7 +3385,7 @@ class TestExamples:
out2 = out1 + 1
return out1, out2
[_out1, out2], _ = scan(step, sequences=v)
[_out1, out2] = scan(step, sequences=v, return_updates=False)
gv = grad(out2.sum(), [v])
f = function([v], gv)
......@@ -3289,7 +3396,13 @@ class TestExamples:
def test_grad_bug_disconnected_input(self):
W = shared(np.zeros((3, 3)), name="W")
v = ivector(name="v")
y, _ = scan(lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=W)
y = scan(
lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=W,
return_updates=False,
)
# This used to raise an exception
f = function([v], grad(y.sum(), W))
......@@ -3299,10 +3412,8 @@ class TestExamples:
w = shared(np.array(0, dtype="float32"), name="w")
init = fscalar("init")
out, _ = scan(
fn=lambda prev: w,
outputs_info=init,
n_steps=2,
out = scan(
fn=lambda prev: w, outputs_info=init, n_steps=2, return_updates=False
)
grad(out[-1], w)
......@@ -3326,7 +3437,7 @@ class TestExamples:
def f_rnn_shared(u_tm2, x_tm1, x_tm2):
return u_tm2 * W_in + x_tm1 * W + x_tm2
outputs, updates = scan(
outputs = scan(
f_rnn_shared,
dict(input=u, taps=-2),
dict(initial=x0, taps=[-1, -2]),
......@@ -3334,9 +3445,10 @@ class TestExamples:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f7 = function([u, x0], outputs, updates=updates, allow_input_downcast=True)
f7 = function([u, x0], outputs, allow_input_downcast=True)
pytensor_out = f7(vu, vx0)
# compute output in numpy
......@@ -3372,7 +3484,7 @@ class TestExamples:
def f_rnn_shared(u_tm2, u_tp2, x_tm1, x_tm2):
return (u_tm2 + u_tp2) * W_in + x_tm1 * W + x_tm2
output, updates = scan(
output = scan(
f_rnn_shared,
dict(input=u, taps=[-2, 2]),
dict(initial=x0, taps=[-1, -2]),
......@@ -3380,9 +3492,10 @@ class TestExamples:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f8 = function([u, x0], output, updates=updates, allow_input_downcast=True)
f8 = function([u, x0], output, allow_input_downcast=True)
pytensor_out = f8(vu, vx0)
# compute output in numpy
numpy_out = np.zeros(2)
......@@ -3404,7 +3517,7 @@ class TestExamples:
state = scalar("state")
n_steps = iscalar("nsteps")
# Test return_list at the same time.
output, updates = scan(
output = scan(
f_pow2,
[],
state,
......@@ -3413,10 +3526,9 @@ class TestExamples:
truncate_gradient=-1,
return_list=True,
go_backwards=False,
return_updates=False,
)
my_f = function(
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
my_f = function([state, n_steps], output, allow_input_downcast=True)
rng = np.random.default_rng(utt.fetch_seed())
state = rng.uniform()
......@@ -3446,10 +3558,11 @@ class TestExamples:
pre_h = dot(x, W_x)
return pre_h
value, _scan_updates = scan(
value = scan(
_active,
sequences=X,
outputs_info=[pt.alloc(floatx(0.0), 1, out_size)],
return_updates=False,
)
cost = mean(value)
gW_x = grad(cost, W_x)
......@@ -3467,7 +3580,7 @@ class TestExamples:
condition = until(new_value > max_value)
return [new_value, new_step], condition
rs, _updates = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps)
rs = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps, return_updates=False)
f = function(inputs=[max_value, n_steps], outputs=rs)
......@@ -3487,33 +3600,37 @@ class TestExamples:
# Generate the components of the polynomial
full_range = pt.arange(max_coefficients_supported)
components, _updates = scan(
components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power),
sequences=[coefficients, full_range],
non_sequences=x,
return_updates=False,
)
polynomial1 = components.sum()
polynomial2, _updates = scan(
polynomial2 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=pt.constant(0, dtype="floatX"),
sequences=[coefficients, full_range],
non_sequences=x,
return_updates=False,
)
# python int
polynomial3, _updates = scan(
polynomial3 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=0,
sequences=[coefficients, full_range],
non_sequences=x,
return_updates=False,
)
# python float
polynomial4, _updates = scan(
polynomial4 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=0.0,
sequences=[coefficients, full_range],
non_sequences=x,
return_updates=False,
)
calculate_polynomial = function(
......@@ -3576,8 +3693,12 @@ class TestExamples:
# o = v + 1 # <-- this line works
return o
OS, _updates = scan(
fn=one_step, sequences=V, outputs_info=[None], non_sequences=[W]
OS = scan(
fn=one_step,
sequences=V,
outputs_info=[None],
non_sequences=[W],
return_updates=False,
)
O = OS.sum() + W.sum()
......@@ -3591,11 +3712,12 @@ class TestExamples:
)
def test_infershape_seq_shorter_nsteps(self):
x = vector("x")
[o1, o2], _ = scan(
[o1, o2] = scan(
lambda x, y: (x + 1, y + x),
sequences=x,
outputs_info=[None, x[0]],
n_steps=20,
return_updates=False,
)
f = function([x], [o1, o2], mode=mode_with_opt)
......@@ -3667,10 +3789,14 @@ class TestExamples:
condition = until(previous_val > 5)
return new_val, condition
out, _updates = scan(inner_fct, outputs_info=x, n_steps=10)
out, updates = scan(inner_fct, outputs_info=x, n_steps=10)
g_out = grad(out.sum(), x)
fct = function([x], [out, g_out])
fct = function(
[x],
[out, g_out],
updates=updates,
)
for i in range(-5, 5):
output, g_output = fct(i)
......@@ -3702,7 +3828,7 @@ class TestExamples:
)
return next_sitsot_val, next_mitsot_val, nitsot_out
out, _updates = scan(
out = scan(
fn=step,
sequences=seq,
outputs_info=[
......@@ -3711,6 +3837,7 @@ class TestExamples:
None,
],
n_steps=5,
return_updates=False,
)
f = function([seq, sitsot_init, mitsot_init], out[2].shape)
......@@ -3746,7 +3873,7 @@ class TestExamples:
dot(x_tm1, W_out),
]
outputs, updates = scan(
outputs = scan(
f_rnn_cmpl,
[u1, u2],
[x0, y0],
......@@ -3754,11 +3881,10 @@ class TestExamples:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f4 = function(
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
# compute the values in numpy
v_x = np.zeros((3, 2), dtype=config.floatX)
......@@ -3802,7 +3928,7 @@ class TestExamples:
dot(u1_t, W_in1),
]
outputs, updates = scan(
outputs = scan(
f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None],
......@@ -3810,11 +3936,10 @@ class TestExamples:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f = function(
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
f = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
ny0 = np.zeros((5, 2))
ny1 = np.zeros((5,))
......@@ -3904,13 +4029,14 @@ class TestExamples:
return [h_t, y_t]
# hidden and outputs of the entire sequence
[_h, y], _ = scan(
[_h, y] = scan(
fn=one_step,
sequences=dict(input=x),
# corresponds to the return type of one_step
outputs_info=[dict(initial=h0, taps=[-2, -1]), None],
non_sequences=[W_ih, W_hh, b_h, W_ho, b_o],
mode=mode,
return_updates=False,
)
# target values
......@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode):
outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly.
"""
s_in_y, _ = scan(
s_in_y = scan(
fn=lambda z: (z + 1, until(z > 2)),
outputs_info=[
{"taps": [-1], "initial": pt.as_tensor(0.0, dtype=np.float64)}
......@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode):
mode=mode,
n_steps=n - 1,
allow_gc=False,
return_updates=False,
)
return s_in_y.sum()
s_y, _updates = scan(
s_y = scan(
fn=fn,
outputs_info=[None],
sequences=[pt.as_tensor([3, 2, 1], dtype=np.int64)],
mode=mode,
allow_gc=False,
return_updates=False,
)
f_cvm = function([], s_y, mode=mode)
......@@ -4121,14 +4249,14 @@ def test_rng_outputs_info():
).owner.outputs
return next_x, next_rng
[xs, rng_final], updates = scan(
[xs, rng_final] = scan(
fn=step,
outputs_info=[x0, rng_x0],
n_steps=10,
return_updates=False,
)
assert isinstance(xs.type, TensorType)
assert isinstance(rng_final.type, RandomGeneratorType)
assert not updates
fn = function([rng_init], [xs, rng_final])
xs_eval, rng_final_eval = fn(np.random.default_rng(0))
......
......@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan:
"""Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
W = matrix(name="W")
v = ivector(name="v")
y1, _ = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
y1 = scan(
lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W],
return_updates=False,
)
y2, _ = scan(
y2 = scan(
lambda i, _, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W],
return_updates=False,
)
y3, _ = scan(
y3 = scan(
lambda i, W, _: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W, W[0]],
return_updates=False,
)
y4, _ = scan(
y4 = scan(
lambda i, _, _2, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W[0], W],
return_updates=False,
)
y5, _ = scan(
y5 = scan(
lambda i, _, W, _2: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W[0], W, W[0]],
return_updates=False,
)
y6, _ = scan(
y6 = scan(
lambda i, W, _, _2: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W, W[0], W[0]],
return_updates=False,
)
# TODO: y7 have problem during run time. I think it should
# raise an error during the scan construction.
......@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan:
W = matrix(name="W")
v = ivector(name="v")
vv = matrix(name="vv")
y1, _ = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W]
y1 = scan(
lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W],
return_updates=False,
)
y2, _ = scan(
lambda i, _, W: W[i], sequences=[v, v], outputs_info=None, non_sequences=W
y2 = scan(
lambda i, _, W: W[i],
sequences=[v, v],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y3, _ = scan(
y3 = scan(
lambda i, _, W: W[i],
sequences=[v, vv[0]],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y4, _ = scan(
y4 = scan(
lambda _, i, W: W[i],
sequences=[vv[0], v],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y5, _ = scan(
y5 = scan(
lambda _, i, _2, W: W[i],
sequences=[vv, v, vv[0]],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y6, _ = scan(
y6 = scan(
lambda _, _2, i, W: W[i],
sequences=[vv[0], vv, v],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y7, _ = scan(
y7 = scan(
lambda i, _, _2, W: W[i],
sequences=[v, vv[0], vv[0]],
outputs_info=None,
non_sequences=W,
return_updates=False,
)
y8, _ = scan(
y8 = scan(
lambda _, i, W, _2, _3: W[i],
sequences=[vv[0], v],
outputs_info=None,
non_sequences=[W, W[0], W[0]],
return_updates=False,
)
W_val = np.random.normal(size=(3, 3)).astype(config.floatX)
......@@ -195,7 +218,7 @@ class TestPushOutDot:
def lambda_fn(h, W1, W2):
return dot(h, W1 + W2)
o, _ = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5)
o = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5, return_updates=False)
f = function([h0, W1, W2], o, mode=self.mode)
......@@ -232,19 +255,24 @@ class TestPushOutDot:
return dot(W1, W2), until_condition
# Compile a function with the optimization
o, _ = scan(
lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5
o = scan(
lambda_fn,
sequences=[step_indices, W1],
non_sequences=[W2],
n_steps=5,
return_updates=False,
)
f = function([W1, W2, step_indices], o, mode=self.mode)
# Compule an pytensor function without the optimization
o, _ = scan(
o = scan(
lambda_fn,
sequences=[step_indices, W1],
non_sequences=[W2],
n_steps=5,
mode="FAST_COMPILE",
return_updates=False,
)
f_ref = function([W1, W2, step_indices], o, mode=self.mode)
......@@ -268,7 +296,13 @@ class TestPushOutDot:
def lambda_fn(h, W1, W2):
return dot(h, W1 + W2)
o, _ = scan(lambda_fn, outputs_info=h0, non_sequences=[W1, W2], n_steps=5)
o = scan(
lambda_fn,
outputs_info=h0,
non_sequences=[W1, W2],
n_steps=5,
return_updates=False,
)
f = function([h0, W1, W2], o, mode=self.mode)
......@@ -290,10 +324,11 @@ class TestPushOutDot:
def fn(i, i_tm1):
return i + 10, i_tm1
([i_t, i_tm1], _) = scan(
[i_t, i_tm1] = scan(
fn,
sequences=[inp],
outputs_info=[np.asarray([0.0, 0.0], config.floatX), None],
return_updates=False,
)
f = function([inp], [i_t, i_tm1])
val = np.arange(10).reshape(5, 2).astype(config.floatX)
......@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan:
@config.change_flags(on_opt_error="raise")
def test_pushout_seqs2(self):
x = matrix()
outputs, updates = scan(
outputs = scan(
lambda x: [x * x, pt.constant(0).copy().copy()],
n_steps=2,
sequences=[],
non_sequences=[],
outputs_info=[x, None],
return_updates=False,
)
# Compile an PyTensor function where any optimization error will lead to
# an exception being raised
function([x], outputs, updates=updates)
function([x], outputs)
@config.change_flags(on_opt_error="raise")
def test_pushout_nonseq(self):
......@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan:
outputs. This led the optimization to raise an exception.
"""
outputs, _ = scan(lambda x: (x * x, x), non_sequences=[2], n_steps=2)
outputs = scan(
lambda x: (x * x, x), non_sequences=[2], n_steps=2, return_updates=False
)
f = function(inputs=[], outputs=outputs)
outs = f()
......@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan:
test_ofg = OpFromGraph([], [y])
def inner_func(x):
out, _ = pytensor.scan(lambda: test_ofg(), n_steps=x)
out = pytensor.scan(lambda: test_ofg(), n_steps=x, return_updates=False)
return out
out, _ = pytensor.scan(inner_func, sequences=[pt.arange(1, 2)])
out = pytensor.scan(
inner_func, sequences=[pt.arange(1, 2)], return_updates=False
)
_ = pytensor.function([], test_ofg())
......@@ -612,10 +652,11 @@ class TestPushOutAddScan:
def test_sum_dot(self):
A = matrix("A")
B = matrix("B")
S, _ = scan(
S = scan(
lambda x1, x2, u: u + dot(x1, x2),
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
outputs_info=[pt.zeros_like(A)],
return_updates=False,
)
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
......@@ -636,13 +677,17 @@ class TestPushOutAddScan:
bv = pt.zeros((5,))
bh = pt.zeros((4,))
v = matrix("v")
(bv_t, bh_t), _ = scan(
lambda _: [bv, bh], sequences=v, outputs_info=[None, None]
(bv_t, bh_t) = scan(
lambda _: [bv, bh],
sequences=v,
outputs_info=[None, None],
return_updates=False,
)
chain, _ = scan(
chain = scan(
lambda x: dot(dot(x, W) + bh_t, W.T) + bv_t,
outputs_info=v,
n_steps=2,
return_updates=False,
)
# TODO FIXME: Make this a real test and assert something.
chain_fn = function([v], chain)
......@@ -710,26 +755,28 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once
# without
opt_mode = mode.including("scan")
h, _ = pytensor.scan(
h = pytensor.scan(
rnn_step1,
sequences=[x, ri, zi],
n_steps=seq_len,
outputs_info=init,
name="fpass1",
mode=opt_mode,
return_updates=False,
)
cost = h[-1].sum()
grad1 = grad(cost, [U, V, W])
f_opt = pytensor.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode)
no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = pytensor.scan(
h = pytensor.scan(
rnn_step1,
sequences=[x, ri, zi],
n_steps=seq_len,
outputs_info=init,
name="fpass1",
mode=no_opt_mode,
return_updates=False,
)
cost = h[-1].sum()
grad1 = grad(cost, [U, V, W])
......@@ -773,21 +820,23 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once without
opt_mode = mode.including("scan")
h, _ = pytensor.scan(
h = pytensor.scan(
inner_fct,
sequences=[input1, input2, input3],
outputs_info=init,
mode=opt_mode,
return_updates=False,
)
output = h[-1]
f_opt = pytensor.function([input1, input2, input3], output, mode=opt_mode)
no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = pytensor.scan(
h = pytensor.scan(
inner_fct,
sequences=[input1, input2, input3],
outputs_info=init,
mode=no_opt_mode,
return_updates=False,
)
output = h[-1]
f_no_opt = pytensor.function([input1, input2, input3], output, mode=no_opt_mode)
......@@ -892,13 +941,20 @@ class TestScanMerge:
"""
inps = vector()
state = scalar()
y1, _ = scan(lambda x, y: x * y, sequences=inps, outputs_info=state, n_steps=5)
y1 = scan(
lambda x, y: x * y,
sequences=inps,
outputs_info=state,
n_steps=5,
return_updates=False,
)
y2, _ = scan(
y2 = scan(
lambda x, y: (x + y, until(x > 0)),
sequences=inps,
outputs_info=state,
n_steps=5,
return_updates=False,
)
scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, Scan)
......@@ -958,8 +1014,8 @@ class TestScanMerge:
def sub(s1, s2, const):
return s1 - 1, until(s2 > const)
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1])
sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy = scan(sub, sequences=[y, -z], non_sequences=[c1], return_updates=False)
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
......@@ -972,8 +1028,8 @@ class TestScanMerge:
np.testing.assert_array_equal(res_sx, [1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2])
sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy = scan(sub, sequences=[y, z], non_sequences=[c2], return_updates=False)
f = pytensor.function(
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
......@@ -989,22 +1045,23 @@ class TestScanMerge:
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1])
sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy = scan(sub, sequences=[y, z], non_sequences=[c1], return_updates=False)
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
def nested_scan(c, x, z):
sx, _ = scan(add, sequences=[x, z], non_sequences=[c])
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c])
sx = scan(add, sequences=[x, z], non_sequences=[c], return_updates=False)
sy = scan(sub, sequences=[x, z], non_sequences=[c], return_updates=False)
return sx.sum() + sy.sum()
sz, _ = scan(
sz = scan(
nested_scan,
sequences=[stack([c1, c2])],
non_sequences=[x, z],
mode=self.mode,
return_updates=False,
)
f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
......@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer:
x = pt.vector("x")
scan_out, _ = pytensor.scan(
lambda x: (x + 1) / 2 + 1,
sequences=[x],
scan_out = pytensor.scan(
lambda x: (x + 1) / 2 + 1, sequences=[x], return_updates=False
)
fgraph = FunctionGraph(
......@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer:
assert equal_computations([scan_out], fgraph.outputs)
def test_inplace_basic(self):
scan_out, _ = pytensor.scan(
lambda x: x + 1,
outputs_info=[pt.zeros(1)],
n_steps=3,
scan_out = pytensor.scan(
lambda x: x + 1, outputs_info=[pt.zeros(1)], n_steps=3, return_updates=False
)
fgraph = FunctionGraph(
......@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer:
u0_t * W_in + x1_tm1 * W + u1_t + u2_t,
]
outputs, updates = scan(
outputs = scan(
f_rnn_shared,
[u0, u1, u2],
[dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)],
......@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer:
truncate_gradient=-1,
go_backwards=False,
mode=self.mode,
return_updates=False,
)
f9 = function(
[mu0, mu1, mu2, x0, x1],
outputs,
updates=updates,
mode=self.mode,
allow_input_downcast=True,
)
......@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer:
u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1,
]
outputs, updates = scan(
outputs = scan(
f_rnn_shared,
[u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])],
[dict(initial=x0), dict(initial=x1)],
......@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer:
truncate_gradient=-1,
go_backwards=False,
mode=self.mode,
return_updates=False,
)
f9 = function(
[mu0, mu1, mu2, x0, x1],
outputs,
updates=updates,
mode=self.mode,
allow_input_downcast=True,
)
......@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer:
vx1 = asarrayX(rng.uniform())
x0 = shared(vx0)
x1 = shared(vx1)
outputs, updates = scan(
lambda x, y: (x + asarrayX(1), y + asarrayX(1)), [], [x0, x1], n_steps=3
outputs = scan(
lambda x, y: (x + asarrayX(1), y + asarrayX(1)),
[],
[x0, x1],
n_steps=3,
return_updates=False,
)
x0 = asarrayX(np.zeros((4,)))
x0[0] = vx0
......@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer:
to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
outputs = clone_replace(outputs, replace=[(to_replace, x0)])
f9 = function([], outputs, updates=updates, mode=self.mode)
f9 = function([], outputs, mode=self.mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 not in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map
......@@ -1249,7 +1307,7 @@ class TestSaveMem:
y_tm1 + dot(x_tm1, W_out),
]
_outputs, updates = scan(
outs = scan(
f_rnn_cmpl,
[u1, u2],
[None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
......@@ -1257,12 +1315,12 @@ class TestSaveMem:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
outputs = [_outputs[0][-1], _outputs[1][-1], _outputs[2][-1]]
outputs = [outs[0][-1], outs[1][-1], outs[2][-1]]
f4 = function(
[u1, u2, x0, y0, W_in1],
outputs,
updates=updates,
allow_input_downcast=True,
mode=self.mode,
)
......@@ -1297,14 +1355,18 @@ class TestSaveMem:
u = vector("u")
idx = iscalar("idx")
jdx = iscalar("jdx")
[x1, x2, x3, x4, x5, x6, x7], updates = scan(
f_rnn, u, n_steps=None, truncate_gradient=-1, go_backwards=False
[x1, x2, x3, x4, x5, x6, x7] = scan(
f_rnn,
u,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f2 = function(
[u, idx, jdx],
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
updates=updates,
allow_input_downcast=True,
mode=self.mode.excluding("scan_push_out_seq"),
)
......@@ -1341,10 +1403,8 @@ class TestSaveMem:
def test_save_mem_reduced_number_of_steps_constant(self):
x0 = pt.scalar("x0")
xs, _ = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
xs = scan(
lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10, return_updates=False
)
fn = function([x0], xs[:5], mode=self.mode)
......@@ -1358,10 +1418,11 @@ class TestSaveMem:
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
x0 = pt.scalar("x0")
[xs, ys], _ = scan(
[xs, ys] = scan(
lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1),
outputs_info=[x0, x0],
n_steps=10,
return_updates=False,
)
# Because of ys[-1] we need all the steps!
......@@ -1399,7 +1460,7 @@ class TestSaveMem:
x20 = scalar("x20")
x30 = vector("x30")
x40 = scalar("x40")
[x1, x2, x3, x4, x5, _x6, _x7], updates = scan(
[x1, x2, x3, x4, x5, _x6, _x7] = scan(
step,
u,
[
......@@ -1414,12 +1475,12 @@ class TestSaveMem:
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
f = function(
[u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates,
allow_input_downcast=True,
mode=self.mode,
)
......@@ -1479,10 +1540,11 @@ class TestSaveMem:
def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = pt.ones(())
values, _ = scan(
values = scan(
lambda x: ([x], (), until(x)),
outputs_info=[var],
n_steps=2,
return_updates=False,
)
tmp_fn = function([var], values, mode=self.mode)
......@@ -1493,10 +1555,11 @@ class TestSaveMem:
def test_savemem_opt(self, benchmark):
y0 = shared(np.ones((2, 10)))
[_y1, y2], _updates = scan(
[_y1, y2] = scan(
lambda y: [y, y],
outputs_info=[dict(initial=y0, taps=[-2]), None],
n_steps=5,
return_updates=False,
)
# TODO FIXME: Make this a real test and assert something.
fn = function([], y2.sum(), mode=self.mode)
......@@ -1515,23 +1578,25 @@ class TestSaveMem:
return dot(h_tm1, w) + x_t_t
def outer_scan_step(x_t, w):
h, _ = scan(
h = scan(
inner_scan_step,
sequences=[x_t[1:]],
outputs_info=[x_t[0]],
non_sequences=[w],
strict=True,
name="the_inner_scan",
return_updates=False,
)
return h
def get_outputs(x, w):
features, _ = scan(
features = scan(
outer_scan_step,
sequences=[x],
non_sequences=[w],
strict=True,
name="the_outer_scan",
return_updates=False,
)
return_val = grad(features.sum(), w)
......@@ -1571,7 +1636,7 @@ class TestSaveMem:
state = vector("state")
n_steps = iscalar("nsteps")
output, updates = scan(
output = scan(
f_pow2,
[],
state,
......@@ -1579,13 +1644,13 @@ class TestSaveMem:
n_steps=n_steps,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
)
nw_shape = ivector("nw_shape")
# Note that the output is reshaped to 3 dimensional tensor, and
my_f = function(
[state, n_steps, nw_shape],
[reshape(output, nw_shape, ndim=3)[:-2], output[:-4]],
updates=updates,
allow_input_downcast=True,
)
nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
......@@ -1599,11 +1664,12 @@ class TestSaveMem:
n_steps = scalar("n_steps", dtype="int64")
x0 = vector("x0")
ys, _ = pytensor.scan(
ys = pytensor.scan(
# Fibonacci Sequence
lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
outputs_info=[{"initial": x0, "taps": [-2, -1]}],
n_steps=n_steps,
return_updates=False,
)
# Save memory is triggered by choosing only last value
y = ys[-1]
......@@ -1629,10 +1695,11 @@ class TestSaveMem:
def test_while_scan_map(self):
xs = vector("xs")
ys, _ = pytensor.scan(
ys = pytensor.scan(
lambda x: (x + 1, {}, until(x + 1 >= 10)),
outputs_info=[None],
sequences=[xs],
return_updates=False,
)
# Save memory is triggered by choosing only last value
y = ys[-1]
......@@ -1656,11 +1723,12 @@ class TestSaveMem:
n_steps = scalar("n_steps", dtype="int64")
# while loop
[ys, zs], _ = pytensor.scan(
[ys, zs] = pytensor.scan(
lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
sequences=[seq],
outputs_info=[x0, None],
n_steps=n_steps,
return_updates=False,
)
# Save memory is triggered by choosing only last value
y = ys[-1]
......@@ -1696,10 +1764,11 @@ class TestSaveMem:
val_test = np.zeros(val_shape, dtype=val.dtype)
init = pt.full((2,), val)
ys, _ = pytensor.scan(
ys = pytensor.scan(
fn=lambda *args: pt.add(*args),
outputs_info=[{"initial": init, "taps": (-2, -1)}],
n_steps=100,
return_updates=False,
)
out = ys[:-50] if keep_beginning else ys[-50:]
......@@ -1729,12 +1798,13 @@ def test_inner_replace_dot():
mode = get_default_mode().including("scan") # .excluding("BlasOpt")
o, _ = scan(
o = scan(
lambda hi, him1, W: (hi, dot(hi + him1, W)),
outputs_info=[pt.zeros([h.shape[1]]), None],
sequences=[h],
non_sequences=[W],
mode=mode,
return_updates=False,
)
f = function([W, h], o, mode=mode)
......@@ -1753,11 +1823,12 @@ def test_alloc_inputs1():
def lambda_fn(h, W1, W2):
return dot(h, W1 * W2)
o, _ = scan(
o = scan(
lambda_fn,
outputs_info=h0,
non_sequences=[W1, pt.zeros_like(W2)],
n_steps=5,
return_updates=False,
)
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
......@@ -1786,12 +1857,13 @@ def test_alloc_inputs2():
def lambda_fn(W1, h, W2):
return W1 * dot(h, W2)
o, _ = scan(
o = scan(
lambda_fn,
sequences=pt.zeros_like(W1),
outputs_info=h0,
non_sequences=[pt.zeros_like(W2)],
n_steps=5,
return_updates=False,
)
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
......@@ -1821,12 +1893,13 @@ def test_alloc_inputs3():
def lambda_fn(W1, h, W2):
return W1 * dot(h, W2)
o, _ = scan(
o = scan(
lambda_fn,
sequences=pt.zeros_like(W1),
outputs_info=h0,
non_sequences=[pt.zeros_like(W2)],
n_steps=5,
return_updates=False,
)
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
......@@ -1848,7 +1921,7 @@ def test_opt_order():
x = matrix("x")
A = matrix("A")
z, _updates = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2)
z = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2, return_updates=False)
f = function([x, A], z, mode="FAST_RUN")
topo = f.maker.fgraph.toposort()
......
......@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A = tensor("A", shape=(3, 3))
x0 = tensor("b", shape=(3, 4))
xs, _ = scan(
xs = scan(
lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
outputs_info=[x0],
non_sequences=[A],
n_steps=10,
return_updates=False,
)
fn_no_opt = function(
......
......@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def test_scan_gradient_core_type():
n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan(
out = scan(
lambda s: s,
sequences=[seq],
n_steps=n_steps,
return_updates=False,
)
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论