提交 abedb7fb authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Start using new API in tests that don't involve shared updates

上级 78293400
...@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax") ...@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-2, None, None)])
def test_scan_sit_sot(view): def test_scan_sit_sot(view):
x0 = pt.scalar("x0", dtype="float64") x0 = pt.scalar("x0", dtype="float64")
xs, _ = scan( xs = scan(
lambda xtm1: xtm1 + 1, lambda xtm1: xtm1 + 1,
outputs_info=[x0], outputs_info=[x0],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view: if view:
xs = xs[view] xs = xs[view]
...@@ -37,10 +38,11 @@ def test_scan_sit_sot(view): ...@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)]) @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
def test_scan_mit_sot(view): def test_scan_mit_sot(view):
x0 = pt.vector("x0", dtype="float64", shape=(3,)) x0 = pt.vector("x0", dtype="float64", shape=(3,))
xs, _ = scan( xs = scan(
lambda xtm3, xtm1: xtm3 + xtm1 + 1, lambda xtm3, xtm1: xtm3 + xtm1 + 1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}], outputs_info=[{"initial": x0, "taps": [-3, -1]}],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view: if view:
xs = xs[view] xs = xs[view]
...@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y): ...@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def step(xtm3, xtm1, ytm4, ytm2): def step(xtm3, xtm1, ytm4, ytm2):
return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2 return xtm3 + ytm4 + 1, xtm1 + ytm2 + 2
[xs, ys], _ = scan( [xs, ys] = scan(
fn=step, fn=step,
outputs_info=[ outputs_info=[
{"initial": x0, "taps": [-3, -1]}, {"initial": x0, "taps": [-3, -1]},
{"initial": y0, "taps": [-4, -2]}, {"initial": y0, "taps": [-4, -2]},
], ],
n_steps=10, n_steps=10,
return_updates=False,
) )
if view_x: if view_x:
xs = xs[view_x] xs = xs[view_x]
...@@ -80,10 +83,8 @@ def test_scan_nit_sot(view): ...@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs = pt.vector("x0", dtype="float64", shape=(10,)) xs = pt.vector("x0", dtype="float64", shape=(10,))
ys, _ = scan( ys = scan(
lambda x: pt.exp(x), lambda x: pt.exp(x), outputs_info=[None], sequences=[xs], return_updates=False
outputs_info=[None],
sequences=[xs],
) )
if view: if view:
ys = ys[view] ys = ys[view]
...@@ -106,11 +107,12 @@ def test_scan_mit_mot(): ...@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho = pt.scalar("rho", dtype="float64") rho = pt.scalar("rho", dtype="float64")
x0 = pt.vector("xs", shape=(2,)) x0 = pt.vector("xs", shape=(2,))
y0 = pt.vector("ys", shape=(3,)) y0 = pt.vector("ys", shape=(3,))
[outs, _], _ = scan( [outs, _] = scan(
step, step,
outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}], outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
non_sequences=[rho], non_sequences=[rho],
n_steps=10, n_steps=10,
return_updates=False,
) )
grads = pt.grad(outs.sum(), wrt=[x0, y0, rho]) grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
compare_jax_and_py( compare_jax_and_py(
...@@ -191,10 +193,11 @@ def test_scan_rng_update(): ...@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail(raises=NotImplementedError) @pytest.mark.xfail(raises=NotImplementedError)
def test_scan_while(): def test_scan_while():
xs, _ = scan( xs = scan(
lambda x: (x + 1, until(x < 10)), lambda x: (x + 1, until(x < 10)),
outputs_info=[pt.zeros(())], outputs_info=[pt.zeros(())],
n_steps=100, n_steps=100,
return_updates=False,
) )
compare_jax_and_py([], [xs], []) compare_jax_and_py([], [xs], [])
...@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq(): ...@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res.name = "y_t" res.name = "y_t"
return res return res
y_scan_pt, _ = scan( y_scan_pt = scan(
fn=input_step_fn, fn=input_step_fn,
outputs_info=[ outputs_info=[
{ {
...@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq(): ...@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences=[a_pt], non_sequences=[a_pt],
n_steps=10, n_steps=10,
name="y_scan", name="y_scan",
return_updates=False,
) )
y_scan_pt.name = "y" y_scan_pt.name = "y"
y_scan_pt.owner.inputs[0].name = "y_all" y_scan_pt.owner.inputs[0].name = "y_all"
...@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func): ...@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k = 3 k = 3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda X, A: A @ X, lambda X, A: A @ X,
non_sequences=[A], non_sequences=[A],
outputs_info=[x0], outputs_info=[x0],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
x0_val = ( x0_val = (
...@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq(): ...@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A = pt.matrix("A", shape=(k, k)) A = pt.matrix("A", shape=(k, k))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda X, A: A @ X, lambda X, A: A @ X,
non_sequences=[A], non_sequences=[A],
sequences=[x], sequences=[x],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
...@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot(): ...@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B = pt.matrix("B", shape=(3, 3)) B = pt.matrix("B", shape=(3, 3))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1, lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1,
outputs_info=[{"initial": x0, "taps": [-3, -1]}], outputs_info=[{"initial": x0, "taps": [-3, -1]}],
non_sequences=[A, B], non_sequences=[A, B],
n_steps=10, n_steps=10,
return_updates=False,
) )
x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3)
...@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry(): ...@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return A @ x, x.sum() return A @ x, x.sum()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs, _ = scan( xs = scan(
step, step,
outputs_info=[x0, None], outputs_info=[x0, None],
non_sequences=[A], non_sequences=[A],
n_steps=10, n_steps=10,
mode=get_mode("JAX"), mode=get_mode("JAX"),
return_updates=False,
) )
x0_val = np.arange(3, dtype=config.floatX) x0_val = np.arange(3, dtype=config.floatX)
...@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites(): ...@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426 # See issue #426
A = matrix("A") A = matrix("A")
B = matrix("B") B = matrix("B")
out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2) out = scan(
lambda a, b: a @ b,
outputs_info=[A],
non_sequences=[B],
n_steps=2,
return_updates=False,
)
compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX") compare_jax_and_py([A, B], [out], [np.eye(3), np.eye(3)], jax_mode="JAX")
...@@ -353,8 +367,11 @@ def test_dynamic_sequence_length(): ...@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x = pt.tensor("x", shape=(None, 3)) x = pt.tensor("x", shape=(None, 3))
out, _ = scan( out = scan(
lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x] lambda x: inc_without_static_shape(x),
outputs_info=[None],
sequences=[x],
return_updates=False,
) )
f = function([x], out, mode=get_mode("JAX").excluding("scan")) f = function([x], out, mode=get_mode("JAX").excluding("scan"))
assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1 assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
...@@ -364,10 +381,11 @@ def test_dynamic_sequence_length(): ...@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3))) np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
# With known static shape we should always manage, regardless of the internal implementation # With known static shape we should always manage, regardless of the internal implementation
out2, _ = scan( out2 = scan(
lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape), lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
outputs_info=[None], outputs_info=[None],
sequences=[x], sequences=[x],
return_updates=False,
) )
f2 = function([x], out2, mode=get_mode("JAX").excluding("scan")) f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]])) np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
...@@ -418,11 +436,12 @@ def SEIR_model_logp(): ...@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1 = it0 + ct0 - dt0 it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1 return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan( (st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step, fn=seir_one_step,
sequences=[C_t, D_t], sequences=[C_t, D_t],
outputs_info=[st0, et0, it0, None, None], outputs_info=[st0, et0, it0, None, None],
non_sequences=[beta, gamma, delta], non_sequences=[beta, gamma, delta],
return_updates=False,
) )
st.name = "S_t" st.name = "S_t"
et.name = "E_t" et.name = "E_t"
...@@ -511,11 +530,12 @@ def cyclical_reduction(): ...@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter = 100 max_iter = 100
tol = 1e-7 tol = 1e-7
(*_, A1_hat, norm, _n_steps), _ = scan( (*_, A1_hat, norm, _n_steps) = scan(
step, step,
outputs_info=[A, B, C, B, norm, step_num], outputs_info=[A, B, C, B, norm, step_num],
non_sequences=[tol], non_sequences=[tol],
n_steps=max_iter, n_steps=max_iter,
return_updates=False,
) )
A1_hat = A1_hat[-1] A1_hat = A1_hat[-1]
......
...@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark): ...@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1 = it0 + ct0 - dt0 it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1 return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan( (st, et, it, logp_c_all, logp_d_all) = scan(
fn=seir_one_step, fn=seir_one_step,
sequences=[pt_C, pt_D], sequences=[pt_C, pt_D],
outputs_info=[st0, et0, it0, logp_c, logp_d], outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta], non_sequences=[beta, gamma, delta],
return_updates=False,
) )
st.name = "S_t" st.name = "S_t"
et.name = "E_t" et.name = "E_t"
...@@ -268,7 +269,7 @@ def test_scan_tap_output(): ...@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t.name = "y_t" y_t.name = "y_t"
return x_t, y_t, pt.fill((10,), z_t) return x_t, y_t, pt.fill((10,), z_t)
scan_res, _ = scan( scan_res = scan(
fn=input_step_fn, fn=input_step_fn,
sequences=[ sequences=[
{ {
...@@ -297,6 +298,7 @@ def test_scan_tap_output(): ...@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps=5, n_steps=5,
name="yz_scan", name="yz_scan",
strict=True, strict=True,
return_updates=False,
) )
test_input_vals = [ test_input_vals = [
...@@ -312,11 +314,12 @@ def test_scan_while(): ...@@ -312,11 +314,12 @@ def test_scan_while():
return previous_power * 2, until(previous_power * 2 > max_value) return previous_power * 2, until(previous_power * 2 > max_value)
max_value = pt.scalar() max_value = pt.scalar()
values, _ = scan( values = scan(
power_of_2, power_of_2,
outputs_info=pt.constant(1.0), outputs_info=pt.constant(1.0),
non_sequences=max_value, non_sequences=max_value,
n_steps=1024, n_steps=1024,
return_updates=False,
) )
test_input_vals = [ test_input_vals = [
...@@ -331,11 +334,12 @@ def test_scan_multiple_none_output(): ...@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def power_step(prior_result, x): def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x return prior_result * x, prior_result * x * x, prior_result * x * x * x
result, _ = scan( result = scan(
power_step, power_step,
non_sequences=[A], non_sequences=[A],
outputs_info=[pt.ones_like(A), None, None], outputs_info=[pt.ones_like(A), None, None],
n_steps=3, n_steps=3,
return_updates=False,
) )
test_input_vals = (np.array([1.0, 2.0]),) test_input_vals = (np.array([1.0, 2.0]),)
compare_numba_and_py([A], result, test_input_vals) compare_numba_and_py([A], result, test_input_vals)
...@@ -343,8 +347,12 @@ def test_scan_multiple_none_output(): ...@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def test_grad_sitsot(): def test_grad_sitsot():
def get_sum_of_grad(inp): def get_sum_of_grad(inp):
scan_outputs, _updates = scan( scan_outputs = scan(
fn=lambda x: x * 2, outputs_info=[inp], n_steps=5, mode="NUMBA" fn=lambda x: x * 2,
outputs_info=[inp],
n_steps=5,
mode="NUMBA",
return_updates=False,
) )
return grad(scan_outputs.sum(), inp).sum() return grad(scan_outputs.sum(), inp).sum()
...@@ -362,8 +370,11 @@ def test_mitmots_basic(): ...@@ -362,8 +370,11 @@ def test_mitmots_basic():
def inner_fct(seq, state_old, state_current): def inner_fct(seq, state_old, state_current):
return state_old * 2 + state_current + seq return state_old * 2 + state_current + seq
out, _ = scan( out = scan(
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]} inner_fct,
sequences=seq,
outputs_info={"initial": init_x, "taps": [-2, -1]},
return_updates=False,
) )
g_outs = grad(out.sum(), [seq, init_x]) g_outs = grad(out.sum(), [seq, init_x])
...@@ -383,10 +394,11 @@ def test_mitmots_basic(): ...@@ -383,10 +394,11 @@ def test_mitmots_basic():
def test_inner_graph_optimized(): def test_inner_graph_optimized():
"""Test that inner graph of Scan is optimized""" """Test that inner graph of Scan is optimized"""
xs = vector("xs") xs = vector("xs")
seq, _ = scan( seq = scan(
fn=lambda x: log(1 + x), fn=lambda x: log(1 + x),
sequences=[xs], sequences=[xs],
mode=get_mode("NUMBA"), mode=get_mode("NUMBA"),
return_updates=False,
) )
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise # Disable scan pushout, in which case the whole scan is replaced by an Elemwise
...@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark): ...@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2) sitsot2 = (sitsot1 + mitsot3) / np.sqrt(2)
return mitsot3, sitsot2 return mitsot3, sitsot2
outs, _ = scan( outs = scan(
fn=step, fn=step,
sequences=[seq1, seq2], sequences=[seq1, seq2],
outputs_info=[ outputs_info=[
dict(initial=mitsot_init, taps=[-2, -1]), dict(initial=mitsot_init, taps=[-2, -1]),
dict(initial=sitsot_init, taps=[-1]), dict(initial=sitsot_init, taps=[-1]),
], ],
return_updates=False,
) )
rng = np.random.default_rng(474) rng = np.random.default_rng(474)
...@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y = ytm1 + 1 + ytm2 + a y = ytm1 + 1 + ytm2 + a
return z, x, z + x + y, y return z, x, z + x + y, y
[zs, xs, ws, ys], _ = scan( [zs, xs, ws, ys] = scan(
fn=step, fn=step,
outputs_info=[ outputs_info=[
dict(initial=z0, taps=[-3, -1]), dict(initial=z0, taps=[-3, -1]),
...@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant): ...@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
], ],
non_sequences=[a], non_sequences=[a],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
numba_fn, _ = compare_numba_and_py( numba_fn, _ = compare_numba_and_py(
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0], [n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
...@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant): ...@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class TestScanSITSOTBuffer: class TestScanSITSOTBuffer:
def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None): def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
x0 = pt.vector(shape=(op_size,), dtype="float64") x0 = pt.vector(shape=(op_size,), dtype="float64")
xs, _ = pytensor.scan( xs = pytensor.scan(
fn=lambda xtm1: (xtm1 + 1), fn=lambda xtm1: (xtm1 + 1),
outputs_info=[x0], outputs_info=[x0],
n_steps=n_steps - 1, # 1- makes it easier to align/misalign n_steps=n_steps - 1, # 1- makes it easier to align/misalign
return_updates=False,
) )
if buffer_size == "unit": if buffer_size == "unit":
xs_kept = xs[-1] # Only last state is used xs_kept = xs[-1] # Only last state is used
...@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer: ...@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x = pt.vector("init_x", shape=(2,)) init_x = pt.vector("init_x", shape=(2,))
n_steps = pt.iscalar("n_steps") n_steps = pt.iscalar("n_steps")
output, _ = scan( output = scan(
f_pow2, f_pow2,
sequences=[], sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}], outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
non_sequences=[], non_sequences=[],
n_steps=n_steps_val if constant_n_steps else n_steps, n_steps=n_steps_val if constant_n_steps else n_steps,
return_updates=False,
) )
init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype) init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)
......
...@@ -294,7 +294,7 @@ class TestScan: ...@@ -294,7 +294,7 @@ class TestScan:
def test_clone(self): def test_clone(self):
a = vector() a = vector()
output, _ = scan(fn=lambda x: x**2, sequences=[a]) output = scan(fn=lambda x: x**2, sequences=[a], return_updates=False)
scan_op = output.owner.op scan_op = output.owner.op
assert isinstance(scan_op, Scan) assert isinstance(scan_op, Scan)
...@@ -320,7 +320,7 @@ class TestScan: ...@@ -320,7 +320,7 @@ class TestScan:
state = scalar("state") state = scalar("state")
n_steps = iscalar("nsteps") n_steps = iscalar("nsteps")
output, updates = scan( output = scan(
f_pow2, f_pow2,
[], [],
state, state,
...@@ -328,10 +328,9 @@ class TestScan: ...@@ -328,10 +328,9 @@ class TestScan:
n_steps=n_steps, n_steps=n_steps,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
_my_f = function( _my_f = function([state, n_steps], output, allow_input_downcast=True)
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
origdir = Path.cwd() origdir = Path.cwd()
tmpdir = None tmpdir = None
...@@ -368,11 +367,9 @@ class TestScan: ...@@ -368,11 +367,9 @@ class TestScan:
state = scalar("state") state = scalar("state")
n_steps = iscalar("nsteps") n_steps = iscalar("nsteps")
output, updates = scan(f_pow2, [], state, [], n_steps=n_steps) output = scan(f_pow2, [], state, [], n_steps=n_steps, return_updates=False)
f = function( f = function([state, n_steps], output, allow_input_downcast=True)
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
scan_node = [ scan_node = [
node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan) node for node in f.maker.fgraph.toposort() if isinstance(node.op, Scan)
...@@ -410,7 +407,9 @@ class TestScan: ...@@ -410,7 +407,9 @@ class TestScan:
return 2 * x_tm1 return 2 * x_tm1
n_steps = iscalar("n_steps") n_steps = iscalar("n_steps")
values, _ = scan(f_pow, outputs_info=(x_init,), n_steps=n_steps) values = scan(
f_pow, outputs_info=(x_init,), n_steps=n_steps, return_updates=False
)
update_fn = function((x_init, n_steps), values, mode=mode) update_fn = function((x_init, n_steps), values, mode=mode)
...@@ -443,7 +442,9 @@ class TestScan: ...@@ -443,7 +442,9 @@ class TestScan:
return 2 * x_i return 2 * x_i
with config.change_flags(mode=mode): with config.change_flags(mode=mode):
values, _ = scan(inner_fn, outputs_info=(x_init,), sequences=x) values = scan(
inner_fn, outputs_info=(x_init,), sequences=x, return_updates=False
)
values_fn = function((x_init, x), values) values_fn = function((x_init, x), values)
assert isinstance(values.owner.inputs[0].owner.op, Scan) assert isinstance(values.owner.inputs[0].owner.op, Scan)
...@@ -474,7 +475,7 @@ class TestScan: ...@@ -474,7 +475,7 @@ class TestScan:
return 2 * x_i return 2 * x_i
with config.change_flags(mode=mode): with config.change_flags(mode=mode):
values, _ = scan(inner_fn, sequences=x) values = scan(inner_fn, sequences=x, return_updates=False)
values_fn = function((x,), values) values_fn = function((x,), values)
assert isinstance(values.owner.op, Scan) assert isinstance(values.owner.op, Scan)
...@@ -491,7 +492,9 @@ class TestScan: ...@@ -491,7 +492,9 @@ class TestScan:
# Compile the PyTensor function # Compile the PyTensor function
n_steps = 2 n_steps = 2
inp = matrix() inp = matrix()
broadcasted_inp, _ = scan(lambda x: x, non_sequences=[inp], n_steps=n_steps) broadcasted_inp = scan(
lambda x: x, non_sequences=[inp], n_steps=n_steps, return_updates=False
)
out = broadcasted_inp.sum() out = broadcasted_inp.sum()
gr = grad(out, inp) gr = grad(out, inp)
fun = function([inp], [broadcasted_inp, gr]) fun = function([inp], [broadcasted_inp, gr])
...@@ -519,7 +522,7 @@ class TestScan: ...@@ -519,7 +522,7 @@ class TestScan:
W_in = scalar("win") W_in = scalar("win")
W = scalar("w") W = scalar("w")
output, updates = scan( output = scan(
f_rnn, f_rnn,
u, u,
x0, x0,
...@@ -527,11 +530,10 @@ class TestScan: ...@@ -527,11 +530,10 @@ class TestScan:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f2 = function( f2 = function([u, x0, W_in, W], output, allow_input_downcast=True)
[u, x0, W_in, W], output, updates=updates, allow_input_downcast=True
)
# get random initial values # get random initial values
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(4,)) v_u = rng.uniform(-5.0, 5.0, size=(4,))
...@@ -561,7 +563,7 @@ class TestScan: ...@@ -561,7 +563,7 @@ class TestScan:
def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W): def f_rnn_shared(u_t, x_tm1, tmp_W_in, tmp_W):
return u_t * tmp_W_in + x_tm1 * tmp_W return u_t * tmp_W_in + x_tm1 * tmp_W
output, updates = scan( output = scan(
f_rnn_shared, f_rnn_shared,
u, u,
x0, x0,
...@@ -569,8 +571,9 @@ class TestScan: ...@@ -569,8 +571,9 @@ class TestScan:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f3 = function([u, x0], output, updates=updates, allow_input_downcast=True) f3 = function([u, x0], output, allow_input_downcast=True)
# get random initial values # get random initial values
v_u = rng.uniform(-5.0, 5.0, size=(4,)) v_u = rng.uniform(-5.0, 5.0, size=(4,))
...@@ -688,11 +691,14 @@ class TestScan: ...@@ -688,11 +691,14 @@ class TestScan:
# this test refers to a bug reported by Nicolas # this test refers to a bug reported by Nicolas
# Boulanger-Lewandowski June 6th # Boulanger-Lewandowski June 6th
x = dvector() x = dvector()
y, updates = scan( y = scan(
lambda x: [x], sequences=dict(input=x, taps=[-1]), outputs_info=[None] lambda x: [x],
sequences=dict(input=x, taps=[-1]),
outputs_info=[None],
return_updates=False,
) )
inp = np.arange(5).astype("float64") inp = np.arange(5).astype("float64")
rval = function([x], y, updates=updates)(inp) rval = function([x], y)(inp)
assert np.all(rval == inp[:-1]) assert np.all(rval == inp[:-1])
def test_output_only(self): def test_output_only(self):
...@@ -701,11 +707,18 @@ class TestScan: ...@@ -701,11 +707,18 @@ class TestScan:
u = vector("u") u = vector("u")
outputs, updates = scan( outputs = scan(
f_rnn, u, [], [], n_steps=None, truncate_gradient=-1, go_backwards=False f_rnn,
u,
[],
[],
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
) )
f2 = function([u], outputs, updates=updates, allow_input_downcast=True) f2 = function([u], outputs, allow_input_downcast=True)
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(5,)) v_u = rng.uniform(-5.0, 5.0, size=(5,))
...@@ -722,7 +735,7 @@ class TestScan: ...@@ -722,7 +735,7 @@ class TestScan:
W_in = scalar("win") W_in = scalar("win")
W = scalar("w") W = scalar("w")
output, updates = scan( output = scan(
f_rnn, f_rnn,
u, u,
x0, x0,
...@@ -730,11 +743,10 @@ class TestScan: ...@@ -730,11 +743,10 @@ class TestScan:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=True, go_backwards=True,
return_updates=False,
) )
f2 = function( f2 = function([u, x0, W_in, W], output, allow_input_downcast=True)
[u, x0, W_in, W], output, updates=updates, allow_input_downcast=True
)
# get random initial values # get random initial values
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(4,)) v_u = rng.uniform(-5.0, 5.0, size=(4,))
...@@ -797,8 +809,8 @@ class TestScan: ...@@ -797,8 +809,8 @@ class TestScan:
def test_hash(self): def test_hash(self):
x = vector() x = vector()
y = vector() y = vector()
scan1, _updates = scan(lambda _x: _x + 1, x) scan1 = scan(lambda _x: _x + 1, x, return_updates=False)
scan2, _updates = scan(lambda _x: _x + 1, y) scan2 = scan(lambda _x: _x + 1, y, return_updates=False)
assert scan1.owner.op == scan2.owner.op assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op) assert hash(scan1.owner.op) == hash(scan2.owner.op)
...@@ -809,9 +821,24 @@ class TestScan: ...@@ -809,9 +821,24 @@ class TestScan:
y = vector("y") y = vector("y")
c = scalar("c") c = scalar("c")
scan_a, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c]) scan_a = scan(
scan_b, _ = scan(lambda x, y, c: x + y + c, sequences=[x, y], non_sequences=[c]) lambda x, y, c: x + y + c,
scan_c, _ = scan(lambda x, y, c: x + y + c, sequences=[y, x], non_sequences=[c]) sequences=[x, y],
non_sequences=[c],
return_updates=False,
)
scan_b = scan(
lambda x, y, c: x + y + c,
sequences=[x, y],
non_sequences=[c],
return_updates=False,
)
scan_c = scan(
lambda x, y, c: x + y + c,
sequences=[y, x],
non_sequences=[c],
return_updates=False,
)
assert scan_b is not scan_a assert scan_b is not scan_a
assert scan_c is not scan_a assert scan_c is not scan_a
...@@ -1006,7 +1033,7 @@ class TestScan: ...@@ -1006,7 +1033,7 @@ class TestScan:
def lambda_fn(x_t): def lambda_fn(x_t):
return x_t + 1, until(x_t > 3) return x_t + 1, until(x_t > 3)
o, _ = scan(lambda_fn, x) o = scan(lambda_fn, x, return_updates=False)
f = function([x], o) f = function([x], o)
vx = np.zeros((50,), dtype=config.floatX) vx = np.zeros((50,), dtype=config.floatX)
vx[23] = 4 vx[23] = 4
...@@ -1019,7 +1046,7 @@ class TestScan: ...@@ -1019,7 +1046,7 @@ class TestScan:
def lambda_fn(x_t): def lambda_fn(x_t):
return x_t + 1, until(x_t > 3) return x_t + 1, until(x_t > 3)
o, _ = scan(lambda_fn, x) o = scan(lambda_fn, x, return_updates=False)
f = function([x], o.shape[0], mode=mode_with_opt) f = function([x], o.shape[0], mode=mode_with_opt)
vx = np.zeros((50,), dtype=config.floatX) vx = np.zeros((50,), dtype=config.floatX)
...@@ -1029,11 +1056,12 @@ class TestScan: ...@@ -1029,11 +1056,12 @@ class TestScan:
def test_infer_shape_nsteps_smaller_seq_length(self): def test_infer_shape_nsteps_smaller_seq_length(self):
x = vector("x") x = vector("x")
[o1, o2], _ = scan( [o1, o2] = scan(
lambda x, y: (x + 1, y + x), lambda x, y: (x + 1, y + x),
sequences=x, sequences=x,
outputs_info=[None, x[0]], outputs_info=[None, x[0]],
n_steps=20, n_steps=20,
return_updates=False,
) )
f = function([x], [o1.shape[0], o2.shape[0]], mode=mode_with_opt) f = function([x], [o1.shape[0], o2.shape[0]], mode=mode_with_opt)
...@@ -1071,17 +1099,18 @@ class TestScan: ...@@ -1071,17 +1099,18 @@ class TestScan:
mode = MonitorMode(post_func=detect_large_outputs) mode = MonitorMode(post_func=detect_large_outputs)
# Symbolic description of the result # Symbolic description of the result
result, updates = scan( result = scan(
fn=lambda prior_result, A: prior_result * A, fn=lambda prior_result, A: prior_result * A,
outputs_info=pt.ones_like(A), outputs_info=pt.ones_like(A),
non_sequences=A, non_sequences=A,
n_steps=k, n_steps=k,
mode=mode, mode=mode,
return_updates=False,
) )
final_result = result[-1] final_result = result[-1]
f = function(inputs=[A, k], outputs=final_result, updates=updates) f = function(inputs=[A, k], outputs=final_result)
f(np.asarray([2, 3, 0.1, 0, 1], dtype=config.floatX), 4) f(np.asarray([2, 3, 0.1, 0, 1], dtype=config.floatX), 4)
# There should be 3 outputs greater than 10: prior_result[0] at step 3, # There should be 3 outputs greater than 10: prior_result[0] at step 3,
...@@ -1103,10 +1132,11 @@ class TestScan: ...@@ -1103,10 +1132,11 @@ class TestScan:
y.name = "y" y.name = "y"
gy = grad(y, x) gy = grad(y, x)
gy.name = "gy" gy.name = "gy"
hy, _updates = scan( hy = scan(
lambda i, gy, x: grad(gy[i] * fc2, x), lambda i, gy, x: grad(gy[i] * fc2, x),
sequences=pt.arange(gy.shape[0]), sequences=pt.arange(gy.shape[0]),
non_sequences=[gy, x], non_sequences=[gy, x],
return_updates=False,
) )
f = function([x, A], hy, allow_input_downcast=True) f = function([x, A], hy, allow_input_downcast=True)
...@@ -1123,8 +1153,13 @@ class TestScan: ...@@ -1123,8 +1153,13 @@ class TestScan:
def test_sequence_is_scan(self, mode): def test_sequence_is_scan(self, mode):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`.""" """Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
x0 = scalar("x0") x0 = scalar("x0")
scan_1, _ = scan(lambda x: x + 1, outputs_info={"initial": x0}, n_steps=10) scan_1 = scan(
scan_2, _ = scan(lambda x: x + 1, sequences=[scan_1]) lambda x: x + 1,
outputs_info={"initial": x0},
n_steps=10,
return_updates=False,
)
scan_2 = scan(lambda x: x + 1, sequences=[scan_1], return_updates=False)
with config.change_flags(mode=mode): with config.change_flags(mode=mode):
scan_2_fn = function([x0], scan_2) scan_2_fn = function([x0], scan_2)
...@@ -1185,7 +1220,7 @@ class TestScan: ...@@ -1185,7 +1220,7 @@ class TestScan:
def test_blockwise_scan(self): def test_blockwise_scan(self):
x = pt.tensor("x", shape=()) x = pt.tensor("x", shape=())
out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10) out = scan(lambda x: x + 1, outputs_info=[x], n_steps=10, return_updates=False)
x_vec = pt.tensor("x_vec", shape=(None,)) x_vec = pt.tensor("x_vec", shape=(None,))
out_vec = vectorize_graph(out, {x: x_vec}) out_vec = vectorize_graph(out, {x: x_vec})
...@@ -1203,13 +1238,14 @@ class TestScan: ...@@ -1203,13 +1238,14 @@ class TestScan:
a0 = shared(np.arange(2)) a0 = shared(np.arange(2))
b0 = shared(np.arange(2)) b0 = shared(np.arange(2))
(a, _b), _ = scan( (a, _b) = scan(
fn, fn,
outputs_info=[ outputs_info=[
{"initial": a0, "taps": [-2, -1]}, {"initial": a0, "taps": [-2, -1]},
{"initial": b0, "taps": [-2, -1]}, {"initial": b0, "taps": [-2, -1]},
], ],
n_steps=2, n_steps=2,
return_updates=False,
) )
grad(a[-1], a0) grad(a[-1], a0)
...@@ -1241,8 +1277,11 @@ class TestScan: ...@@ -1241,8 +1277,11 @@ class TestScan:
state_next = state_old * 2 + state_current + seq state_next = state_old * 2 + state_current + seq
return state_next return state_next
out, _ = scan( out = scan(
inner_fct, sequences=seq, outputs_info={"initial": x, "taps": [-2, -1]} inner_fct,
sequences=seq,
outputs_info={"initial": x, "taps": [-2, -1]},
return_updates=False,
) )
g_out = grad(out.sum(), [seq, x]) g_out = grad(out.sum(), [seq, x])
...@@ -1302,12 +1341,13 @@ class TestScan: ...@@ -1302,12 +1341,13 @@ class TestScan:
new_y = pt.switch(cond, y, sigmoid(x)) new_y = pt.switch(cond, y, sigmoid(x))
return new_cond, new_x, new_y return new_cond, new_x, new_y
values, _ = scan( values = scan(
inner_fn, inner_fn,
outputs_info=[c, x, y], outputs_info=[c, x, y],
n_steps=10, n_steps=10,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
gX, gY = grad(values[1].sum(), [x, y]) gX, gY = grad(values[1].sum(), [x, y])
f = function([c, x, y], [gX, gY], allow_input_downcast=True) f = function([c, x, y], [gX, gY], allow_input_downcast=True)
...@@ -1762,11 +1802,12 @@ class TestScan: ...@@ -1762,11 +1802,12 @@ class TestScan:
outputs_info = [None, dict(initial=out_init, taps=[-3])] outputs_info = [None, dict(initial=out_init, taps=[-3])]
scan_outputs, _ = scan( scan_outputs = scan(
fn=inner_fct, fn=inner_fct,
sequences=seq, sequences=seq,
outputs_info=outputs_info, outputs_info=outputs_info,
non_sequences=non_seq, non_sequences=non_seq,
return_updates=False,
) )
# Attempt to take various gradients # Attempt to take various gradients
...@@ -1834,7 +1875,9 @@ class TestScan: ...@@ -1834,7 +1875,9 @@ class TestScan:
dict(initial=out_init[3], taps=[-2, -1]), dict(initial=out_init[3], taps=[-2, -1]),
] ]
scan_outputs, _ = scan(fn=inner_fct, outputs_info=outputs_info, n_steps=10) scan_outputs = scan(
fn=inner_fct, outputs_info=outputs_info, n_steps=10, return_updates=False
)
grad(scan_outputs[0].sum(), out_init[1]) grad(scan_outputs[0].sum(), out_init[1])
...@@ -1857,11 +1900,12 @@ class TestScan: ...@@ -1857,11 +1900,12 @@ class TestScan:
x = scalar("x") x = scalar("x")
_max_coefficients_supported = 1000 _max_coefficients_supported = 1000
full_range = pt.arange(_max_coefficients_supported) full_range = pt.arange(_max_coefficients_supported)
components, _updates = scan( components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power), fn=lambda coeff, power, free_var: coeff * (free_var**power),
outputs_info=None, outputs_info=None,
sequences=[c, full_range], sequences=[c, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
P = components.sum() P = components.sum()
dP = grad(P, x) dP = grad(P, x)
...@@ -1877,11 +1921,12 @@ class TestScan: ...@@ -1877,11 +1921,12 @@ class TestScan:
x = scalar("x") x = scalar("x")
_max_coefficients_supported = 1000 _max_coefficients_supported = 1000
full_range = pt.arange(_max_coefficients_supported) full_range = pt.arange(_max_coefficients_supported)
components, _updates = scan( components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power), fn=lambda coeff, power, free_var: coeff * (free_var**power),
outputs_info=None, outputs_info=None,
sequences=[c, full_range], sequences=[c, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
P = components.sum() P = components.sum()
dP = grad(P, x).sum() dP = grad(P, x).sum()
...@@ -1968,8 +2013,13 @@ class TestScan: ...@@ -1968,8 +2013,13 @@ class TestScan:
_W = specify_shape(W, v_W.shape) _W = specify_shape(W, v_W.shape)
_W.name = "_W" _W.name = "_W"
o, _ = scan( o = scan(
rnn_fn, sequences=_u, outputs_info=_h0, non_sequences=_W, name="rnn_fn" rnn_fn,
sequences=_u,
outputs_info=_h0,
non_sequences=_W,
name="rnn_fn",
return_updates=False,
) )
o = o[-1] o = o[-1]
eu = matrix("eu") eu = matrix("eu")
...@@ -1983,25 +2033,28 @@ class TestScan: ...@@ -1983,25 +2033,28 @@ class TestScan:
[u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore" [u, h0, W, eu, eh0, eW], [nwo_u, nwo_h0, nwo_W], on_unused_input="ignore"
) )
n2o_u, _ = scan( n2o_u = scan(
lambda i, o, u, h0, W, eu: (grad(o[i], u) * eu).sum(), lambda i, o, u, h0, W, eu: (grad(o[i], u) * eu).sum(),
sequences=pt.arange(o.shape[0]), sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eu], non_sequences=[o, u, h0, W, eu],
name="jacobU", name="jacobU",
return_updates=False,
) )
n2o_h0, _ = scan( n2o_h0 = scan(
lambda i, o, u, h0, W, eh0: (grad(o[i], h0) * eh0).sum(), lambda i, o, u, h0, W, eh0: (grad(o[i], h0) * eh0).sum(),
sequences=pt.arange(o.shape[0]), sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eh0], non_sequences=[o, u, h0, W, eh0],
name="jacobh", name="jacobh",
return_updates=False,
) )
n2o_W, _ = scan( n2o_W = scan(
lambda i, o, u, h0, W, eW: (grad(o[i], W) * eW).sum(), lambda i, o, u, h0, W, eW: (grad(o[i], W) * eW).sum(),
sequences=pt.arange(o.shape[0]), sequences=pt.arange(o.shape[0]),
non_sequences=[o, u, h0, W, eW], non_sequences=[o, u, h0, W, eW],
name="jacobW", name="jacobW",
return_updates=False,
) )
fn_test = function( fn_test = function(
...@@ -2132,10 +2185,11 @@ class TestScan: ...@@ -2132,10 +2185,11 @@ class TestScan:
transfer = sigmoid transfer = sigmoid
hidden_rec, _ = scan( hidden_rec = scan(
lambda x, h_tm1: transfer(dot(h_tm1, W2) + x), lambda x, h_tm1: transfer(dot(h_tm1, W2) + x),
sequences=hidden, sequences=hidden,
outputs_info=[pt.zeros_like(hidden[0])], outputs_info=[pt.zeros_like(hidden[0])],
return_updates=False,
) )
hidden_rec.reshape( hidden_rec.reshape(
...@@ -2168,12 +2222,13 @@ class TestScan: ...@@ -2168,12 +2222,13 @@ class TestScan:
def step(s, xtm2, xtm1, z): def step(s, xtm2, xtm1, z):
return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2) return s * ((xtm2 * 0 + xtm1) ** 2) * (z / 2)
xs, _ = scan( xs = scan(
step, step,
sequences=[seq], sequences=[seq],
outputs_info=[{"initial": x0, "taps": (-2, -1)}], outputs_info=[{"initial": x0, "taps": (-2, -1)}],
non_sequences=[z], non_sequences=[z],
n_steps=2, n_steps=2,
return_updates=False,
) )
last_x = xs[-1] last_x = xs[-1]
...@@ -2254,11 +2309,12 @@ class TestScan: ...@@ -2254,11 +2309,12 @@ class TestScan:
raise ValueError(f"Invalid case: {case}") raise ValueError(f"Invalid case: {case}")
seq = vector("seq") seq = vector("seq")
xs, _ = scan( xs = scan(
step, step,
sequences=[seq], sequences=[seq],
non_sequences=non_sequences, non_sequences=non_sequences,
strict=strict, strict=strict,
return_updates=False,
) )
x0 = xs[0] x0 = xs[0]
...@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode): ...@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode):
def scan_fn(): def scan_fn():
return myop(pt.as_tensor(1)) return myop(pt.as_tensor(1))
res, _ = scan(scan_fn, n_steps=4, mode=mode) res = scan(scan_fn, n_steps=4, mode=mode, return_updates=False)
res_fn = function([], res, mode=mode) res_fn = function([], res, mode=mode)
...@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark): ...@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark):
py_res = f_py() py_res = f_py()
s_r = pt.as_tensor_variable(r, dtype=config.floatX) s_r = pt.as_tensor_variable(r, dtype=config.floatX)
s_y, updates = scan( s_y = scan(
fn=lambda ri, rii, M: ri + M * rii, fn=lambda ri, rii, M: ri + M * rii,
sequences=[s_r[1:]], sequences=[s_r[1:]],
non_sequences=[pt.as_tensor_variable(M, dtype=config.floatX)], non_sequences=[pt.as_tensor_variable(M, dtype=config.floatX)],
outputs_info=s_r[0], outputs_info=s_r[0],
mode=Mode(linker="cvm", optimizer="fast_run"), mode=Mode(linker="cvm", optimizer="fast_run"),
return_updates=False,
) )
assert not updates
f_cvm = function([], s_y, mode="FAST_RUN") f_cvm = function([], s_y, mode="FAST_RUN")
f_cvm.trust_input = True f_cvm.trust_input = True
...@@ -2357,9 +2413,7 @@ def test_compute_test_values(): ...@@ -2357,9 +2413,7 @@ def test_compute_test_values():
y = shared(np.arange(3, dtype=config.floatX), name="y") y = shared(np.arange(3, dtype=config.floatX), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x, y]) z = scan(fn=lambda u, v: u + v, sequences=[x, y], return_updates=False)
assert not updates
z_grad = grad(z.sum(), x) z_grad = grad(z.sum(), x)
...@@ -2368,9 +2422,9 @@ def test_compute_test_values(): ...@@ -2368,9 +2422,9 @@ def test_compute_test_values():
# Use `non_sequences` this time # Use `non_sequences` this time
y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y") y = shared(np.arange(9, dtype=config.floatX).reshape(3, 3), name="y")
z, updates = scan(fn=lambda u, v: u + v, sequences=[x], non_sequences=[y]) z = scan(
fn=lambda u, v: u + v, sequences=[x], non_sequences=[y], return_updates=False
assert not updates )
z_grad = grad(z.sum(), x) z_grad = grad(z.sum(), x)
...@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad(): ...@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad():
def loss_ti(ti, sum_ti, mi, W): def loss_ti(ti, sum_ti, mi, W):
return W.sum().sum().sum() + sum_ti return W.sum().sum().sum() + sum_ti
result_ti, _ = scan( result_ti = scan(
fn=loss_ti, fn=loss_ti,
outputs_info=outputs_ti, outputs_info=outputs_ti,
sequences=pt.arange(W.shape[1], dtype="int32"), sequences=pt.arange(W.shape[1], dtype="int32"),
non_sequences=[mi, W], non_sequences=[mi, W],
return_updates=False,
) )
lossmi = result_ti[-1] lossmi = result_ti[-1]
return sum_mi + lossmi return sum_mi + lossmi
result_mi, _ = scan( result_mi = scan(
fn=loss_mi, fn=loss_mi,
outputs_info=outputs_mi, outputs_info=outputs_mi,
sequences=pt.arange(W.shape[0], dtype="int32"), sequences=pt.arange(W.shape[0], dtype="int32"),
non_sequences=[W], non_sequences=[W],
return_updates=False,
) )
loss = result_mi[-1] loss = result_mi[-1]
...@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast(): ...@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast():
name="w", name="w",
) )
outputs, _ = scan( outputs = scan(
lambda i, h, w: (dot(h[i], w), i), lambda i, h, w: (dot(h[i], w), i),
outputs_info=[None, 0], outputs_info=[None, 0],
non_sequences=[h, w], non_sequences=[h, w],
n_steps=3, n_steps=3,
return_updates=False,
) )
grad(outputs[0].sum(), w) grad(outputs[0].sum(), w)
...@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast(): ...@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast():
def test_constant_folding_n_steps(): def test_constant_folding_n_steps():
# The following code used to crash at revision 2060b8f, in the constant # The following code used to crash at revision 2060b8f, in the constant
# folding optimization step. # folding optimization step.
res, _ = scan( res = scan(
lambda x: x * 2, lambda x: x * 2,
outputs_info=pt.ones(()), outputs_info=pt.ones(()),
# The constant `n_steps` was causing the crash. # The constant `n_steps` was causing the crash.
n_steps=10, n_steps=10,
return_updates=False,
) )
with config.change_flags(on_opt_error="raise"): with config.change_flags(on_opt_error="raise"):
function([], res)() function([], res)()
...@@ -2478,10 +2536,11 @@ def test_outputs_taps_check(): ...@@ -2478,10 +2536,11 @@ def test_outputs_taps_check():
def test_inconsistent_broadcast_error(): def test_inconsistent_broadcast_error():
x = tensor3() x = tensor3()
initial_x = pt.constant(np.zeros((1, 10))) initial_x = pt.constant(np.zeros((1, 10)))
y, _updates = scan( y = scan(
fn=lambda x, prev_x: x + prev_x, fn=lambda x, prev_x: x + prev_x,
sequences=x, sequences=x,
outputs_info=[dict(initial=initial_x)], outputs_info=[dict(initial=initial_x)],
return_updates=False,
) )
# Error, because the broadcast patterns are inconsistent. # Error, because the broadcast patterns are inconsistent.
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -2509,10 +2568,11 @@ class TestGradUntil: ...@@ -2509,10 +2568,11 @@ class TestGradUntil:
self.numpy_gradient = 2 * np.concatenate([self.seq[:7], z], axis=0) self.numpy_gradient = 2 * np.concatenate([self.seq[:7], z], axis=0)
def test_grad_until(self): def test_grad_until(self):
r, _ = scan( r = scan(
lambda x, u: (x * x, until(x > u)), lambda x, u: (x * x, until(x > u)),
sequences=self.x, sequences=self.x,
non_sequences=[self.threshold], non_sequences=[self.threshold],
return_updates=False,
) )
g = grad(r.sum(), self.x) g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g]) f = function([self.x, self.threshold], [r, g])
...@@ -2528,10 +2588,11 @@ class TestGradUntil: ...@@ -2528,10 +2588,11 @@ class TestGradUntil:
X = matrix(name="x") X = matrix(name="x")
arr = tile_array(self.seq) arr = tile_array(self.seq)
r, _ = scan( r = scan(
lambda x, u: (x * x, until(pt_all(x > u))), lambda x, u: (x * x, until(pt_all(x > u))),
sequences=X, sequences=X,
non_sequences=[self.threshold], non_sequences=[self.threshold],
return_updates=False,
) )
g = grad(r.sum(), X) g = grad(r.sum(), X)
f = function([X, self.threshold], [r, g]) f = function([X, self.threshold], [r, g])
...@@ -2542,11 +2603,12 @@ class TestGradUntil: ...@@ -2542,11 +2603,12 @@ class TestGradUntil:
def test_grad_until_and_truncate(self): def test_grad_until_and_truncate(self):
n = 3 n = 3
r, _ = scan( r = scan(
lambda x, u: (x * x, until(x > u)), lambda x, u: (x * x, until(x > u)),
sequences=self.x, sequences=self.x,
non_sequences=[self.threshold], non_sequences=[self.threshold],
truncate_gradient=n, truncate_gradient=n,
return_updates=False,
) )
g = grad(r.sum(), self.x) g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g]) f = function([self.x, self.threshold], [r, g])
...@@ -2558,11 +2620,12 @@ class TestGradUntil: ...@@ -2558,11 +2620,12 @@ class TestGradUntil:
def test_grad_until_and_truncate_sequence_taps(self): def test_grad_until_and_truncate_sequence_taps(self):
n = 3 n = 3
r, _ = scan( r = scan(
lambda x, y, u: (x * y, until(y > u)), lambda x, y, u: (x * y, until(y > u)),
sequences=dict(input=self.x, taps=[-2, 0]), sequences=dict(input=self.x, taps=[-2, 0]),
non_sequences=[self.threshold], non_sequences=[self.threshold],
truncate_gradient=n, truncate_gradient=n,
return_updates=False,
) )
g = grad(r.sum(), self.x) g = grad(r.sum(), self.x)
f = function([self.x, self.threshold], [r, g]) f = function([self.x, self.threshold], [r, g])
...@@ -2581,8 +2644,12 @@ def test_mintap_onestep(): ...@@ -2581,8 +2644,12 @@ def test_mintap_onestep():
new_sum = prev_sum + seq_t new_sum = prev_sum + seq_t
return new_sum return new_sum
rs, _updates = scan( rs = scan(
fn=accum, sequences={"input": seq, "taps": [2]}, outputs_info=0, n_steps=1 fn=accum,
sequences={"input": seq, "taps": [2]},
outputs_info=0,
n_steps=1,
return_updates=False,
) )
f = function(inputs=[seq], outputs=rs) f = function(inputs=[seq], outputs=rs)
...@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length(): ...@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length():
def test_profile_info(): def test_profile_info():
from pytensor.scan.utils import ScanProfileStats from pytensor.scan.utils import ScanProfileStats
z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=True) z = scan(
fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile=True,
return_updates=False,
)
assert isinstance(z.owner.op, Scan) assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn fn = z.owner.op.fn
...@@ -2676,8 +2748,11 @@ def test_profile_info(): ...@@ -2676,8 +2748,11 @@ def test_profile_info():
assert fn.profile.name == "scan_fn" assert fn.profile.name == "scan_fn"
# Set the `ScanProfileStats` name # Set the `ScanProfileStats` name
z, _updates = scan( z = scan(
fn=lambda u: u + 1, sequences=[pt.arange(10)], profile="profile_name" fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile="profile_name",
return_updates=False,
) )
assert isinstance(z.owner.op, Scan) assert isinstance(z.owner.op, Scan)
...@@ -2688,7 +2763,12 @@ def test_profile_info(): ...@@ -2688,7 +2763,12 @@ def test_profile_info():
# Use an existing profile object # Use an existing profile object
profile = fn.profile profile = fn.profile
z, _updates = scan(fn=lambda u: u + 1, sequences=[pt.arange(10)], profile=profile) z = scan(
fn=lambda u: u + 1,
sequences=[pt.arange(10)],
profile=profile,
return_updates=False,
)
assert isinstance(z.owner.op, Scan) assert isinstance(z.owner.op, Scan)
fn = z.owner.op.fn fn = z.owner.op.fn
...@@ -2819,7 +2899,7 @@ class TestExamples: ...@@ -2819,7 +2899,7 @@ class TestExamples:
y_tm1 + dot(x_tm1, W_out), y_tm1 + dot(x_tm1, W_out),
] ]
outputs, updates = scan( outputs = scan(
f_rnn_cmpl, f_rnn_cmpl,
[u1, u2], [u1, u2],
[None, None, x0, dict(initial=y0, taps=[-1, -3])], [None, None, x0, dict(initial=y0, taps=[-1, -3])],
...@@ -2827,11 +2907,10 @@ class TestExamples: ...@@ -2827,11 +2907,10 @@ class TestExamples:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f4 = function( f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
# compute the values in numpy # compute the values in numpy
v_x = np.zeros((3, 2), dtype=config.floatX) v_x = np.zeros((3, 2), dtype=config.floatX)
...@@ -2857,8 +2936,12 @@ class TestExamples: ...@@ -2857,8 +2936,12 @@ class TestExamples:
def scanStep(prev, seq, f1): def scanStep(prev, seq, f1):
return prev + f1 * seq return prev + f1 * seq
scanned, _ = scan( scanned = scan(
fn=scanStep, sequences=[seq], outputs_info=[to_scan], non_sequences=[f1] fn=scanStep,
sequences=[seq],
outputs_info=[to_scan],
non_sequences=[f1],
return_updates=False,
) )
function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True) function(inputs=[to_scan, seq, f1], outputs=scanned, allow_input_downcast=True)
...@@ -2879,8 +2962,12 @@ class TestExamples: ...@@ -2879,8 +2962,12 @@ class TestExamples:
expr = dot(h_tm1, W) + x_t expr = dot(h_tm1, W) + x_t
return expr return expr
expr, _ = scan( expr = scan(
fn=one_step, sequences=[inpt], outputs_info=[initial], non_sequences=[W] fn=one_step,
sequences=[inpt],
outputs_info=[initial],
non_sequences=[W],
return_updates=False,
) )
v1 = shared(np.ones(5, dtype=config.floatX)) v1 = shared(np.ones(5, dtype=config.floatX))
...@@ -2917,11 +3004,12 @@ class TestExamples: ...@@ -2917,11 +3004,12 @@ class TestExamples:
x = scalar() x = scalar()
seq = vector() seq = vector()
outputs_info = [x, pt.zeros_like(x)] outputs_info = [x, pt.zeros_like(x)]
(out1, out2), _updates = scan( (out1, out2) = scan(
lambda a, b, c: (a + b, b + c), lambda a, b, c: (a + b, b + c),
sequences=seq, sequences=seq,
outputs_info=outputs_info, outputs_info=outputs_info,
mode=mode, mode=mode,
return_updates=False,
) )
# Obtain a reference to the scan outputs before the subtensor and # Obtain a reference to the scan outputs before the subtensor and
...@@ -2956,8 +3044,11 @@ class TestExamples: ...@@ -2956,8 +3044,11 @@ class TestExamples:
x = dcol() x = dcol()
seq = dcol() seq = dcol()
outputs_info = [x, pt.zeros_like(x)] outputs_info = [x, pt.zeros_like(x)]
(out1, out2), _updates = scan( (out1, out2) = scan(
lambda a, b, c: (a + b, a + c), sequences=seq, outputs_info=outputs_info lambda a, b, c: (a + b, a + c),
sequences=seq,
outputs_info=outputs_info,
return_updates=False,
) )
# Obtain a reference to the scan outputs before the subtensor and # Obtain a reference to the scan outputs before the subtensor and
...@@ -3096,7 +3187,9 @@ class TestExamples: ...@@ -3096,7 +3187,9 @@ class TestExamples:
seq = matrix() seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None] outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) results = scan(
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
)
f = function([seq], results[1]) f = function([seq], results[1])
assert np.all(exp_out == f(inp)) assert np.all(exp_out == f(inp))
...@@ -3119,7 +3212,9 @@ class TestExamples: ...@@ -3119,7 +3212,9 @@ class TestExamples:
seq = matrix() seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [{"initial": initial_value, "taps": [-4]}, None] outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) results = scan(
fn=onestep, sequences=seq, outputs_info=outputs_info, return_updates=False
)
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX)) sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
updates = {sharedvar: results[0][-1:]} updates = {sharedvar: results[0][-1:]}
...@@ -3164,7 +3259,7 @@ class TestExamples: ...@@ -3164,7 +3259,7 @@ class TestExamples:
init = matrix() init = matrix()
outputs_info = [None, None, None, None, dict(initial=init, taps=[-3, -2, -1])] outputs_info = [None, None, None, None, dict(initial=init, taps=[-3, -2, -1])]
out, _ = scan(inner_fn, outputs_info=outputs_info, n_steps=3) out = scan(inner_fn, outputs_info=outputs_info, n_steps=3, return_updates=False)
fct = function([init], out) fct = function([init], out)
# Compare obtained outputs with expected outputs # Compare obtained outputs with expected outputs
...@@ -3197,21 +3292,23 @@ class TestExamples: ...@@ -3197,21 +3292,23 @@ class TestExamples:
def loss_inner(sum_inner, W): def loss_inner(sum_inner, W):
return sum_inner + (W**2).sum() return sum_inner + (W**2).sum()
result_inner, _ = scan( result_inner = scan(
fn=loss_inner, fn=loss_inner,
outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)), outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)),
non_sequences=[W], non_sequences=[W],
n_steps=1, n_steps=1,
return_updates=False,
) )
return sum_outer + result_inner[-1] return sum_outer + result_inner[-1]
# Also test return_list for that case. # Also test return_list for that case.
result_outer, _ = scan( result_outer = scan(
fn=loss_outer, fn=loss_outer,
outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)), outputs_info=pt.as_tensor_variable(np.asarray(0, dtype=np.float32)),
non_sequences=[W], non_sequences=[W],
n_steps=n_steps, n_steps=n_steps,
return_list=True, return_list=True,
return_updates=False,
) )
cost = result_outer[0][-1] cost = result_outer[0][-1]
...@@ -3230,7 +3327,9 @@ class TestExamples: ...@@ -3230,7 +3327,9 @@ class TestExamples:
x0 = vector("X") x0 = vector("X")
y0 = vector("y0") y0 = vector("y0")
z0 = vector("Z") z0 = vector("Z")
[x, y, z], _ = scan(inner_fn, outputs_info=[x0, y0, z0], n_steps=10) [x, y, z] = scan(
inner_fn, outputs_info=[x0, y0, z0], n_steps=10, return_updates=False
)
cost = (x + y + z).sum() cost = (x + y + z).sum()
grad(cost, x0) # defined grad(cost, x0) # defined
...@@ -3247,7 +3346,12 @@ class TestExamples: ...@@ -3247,7 +3346,12 @@ class TestExamples:
m = matrix("m") m = matrix("m")
u0 = pt.zeros((7,)) u0 = pt.zeros((7,))
[_u, m2], _ = scan(lambda _, u: [u, v], sequences=m, outputs_info=[u0, None]) [_u, m2] = scan(
lambda _, u: [u, v],
sequences=m,
outputs_info=[u0, None],
return_updates=False,
)
# This used to raise an exception with older versions because for a # This used to raise an exception with older versions because for a
# disconnected gradient a non disconnected type was returned # disconnected gradient a non disconnected type was returned
grad((m * m2).sum(), v) grad((m * m2).sum(), v)
...@@ -3257,8 +3361,11 @@ class TestExamples: ...@@ -3257,8 +3361,11 @@ class TestExamples:
m = matrix("m") m = matrix("m")
u0 = pt.zeros((7,)) u0 = pt.zeros((7,))
[_u, m2], _ = scan( [_u, m2] = scan(
lambda x, u: [x + u, u + v], sequences=m, outputs_info=[u0, None] lambda x, u: [x + u, u + v],
sequences=m,
outputs_info=[u0, None],
return_updates=False,
) )
# This used to raise an exception with older versions because # This used to raise an exception with older versions because
# scan could not detect the connection between `m2` and `x` # scan could not detect the connection between `m2` and `x`
...@@ -3278,7 +3385,7 @@ class TestExamples: ...@@ -3278,7 +3385,7 @@ class TestExamples:
out2 = out1 + 1 out2 = out1 + 1
return out1, out2 return out1, out2
[_out1, out2], _ = scan(step, sequences=v) [_out1, out2] = scan(step, sequences=v, return_updates=False)
gv = grad(out2.sum(), [v]) gv = grad(out2.sum(), [v])
f = function([v], gv) f = function([v], gv)
...@@ -3289,7 +3396,13 @@ class TestExamples: ...@@ -3289,7 +3396,13 @@ class TestExamples:
def test_grad_bug_disconnected_input(self): def test_grad_bug_disconnected_input(self):
W = shared(np.zeros((3, 3)), name="W") W = shared(np.zeros((3, 3)), name="W")
v = ivector(name="v") v = ivector(name="v")
y, _ = scan(lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=W) y = scan(
lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=W,
return_updates=False,
)
# This used to raise an exception # This used to raise an exception
f = function([v], grad(y.sum(), W)) f = function([v], grad(y.sum(), W))
...@@ -3299,10 +3412,8 @@ class TestExamples: ...@@ -3299,10 +3412,8 @@ class TestExamples:
w = shared(np.array(0, dtype="float32"), name="w") w = shared(np.array(0, dtype="float32"), name="w")
init = fscalar("init") init = fscalar("init")
out, _ = scan( out = scan(
fn=lambda prev: w, fn=lambda prev: w, outputs_info=init, n_steps=2, return_updates=False
outputs_info=init,
n_steps=2,
) )
grad(out[-1], w) grad(out[-1], w)
...@@ -3326,7 +3437,7 @@ class TestExamples: ...@@ -3326,7 +3437,7 @@ class TestExamples:
def f_rnn_shared(u_tm2, x_tm1, x_tm2): def f_rnn_shared(u_tm2, x_tm1, x_tm2):
return u_tm2 * W_in + x_tm1 * W + x_tm2 return u_tm2 * W_in + x_tm1 * W + x_tm2
outputs, updates = scan( outputs = scan(
f_rnn_shared, f_rnn_shared,
dict(input=u, taps=-2), dict(input=u, taps=-2),
dict(initial=x0, taps=[-1, -2]), dict(initial=x0, taps=[-1, -2]),
...@@ -3334,9 +3445,10 @@ class TestExamples: ...@@ -3334,9 +3445,10 @@ class TestExamples:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f7 = function([u, x0], outputs, updates=updates, allow_input_downcast=True) f7 = function([u, x0], outputs, allow_input_downcast=True)
pytensor_out = f7(vu, vx0) pytensor_out = f7(vu, vx0)
# compute output in numpy # compute output in numpy
...@@ -3372,7 +3484,7 @@ class TestExamples: ...@@ -3372,7 +3484,7 @@ class TestExamples:
def f_rnn_shared(u_tm2, u_tp2, x_tm1, x_tm2): def f_rnn_shared(u_tm2, u_tp2, x_tm1, x_tm2):
return (u_tm2 + u_tp2) * W_in + x_tm1 * W + x_tm2 return (u_tm2 + u_tp2) * W_in + x_tm1 * W + x_tm2
output, updates = scan( output = scan(
f_rnn_shared, f_rnn_shared,
dict(input=u, taps=[-2, 2]), dict(input=u, taps=[-2, 2]),
dict(initial=x0, taps=[-1, -2]), dict(initial=x0, taps=[-1, -2]),
...@@ -3380,9 +3492,10 @@ class TestExamples: ...@@ -3380,9 +3492,10 @@ class TestExamples:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f8 = function([u, x0], output, updates=updates, allow_input_downcast=True) f8 = function([u, x0], output, allow_input_downcast=True)
pytensor_out = f8(vu, vx0) pytensor_out = f8(vu, vx0)
# compute output in numpy # compute output in numpy
numpy_out = np.zeros(2) numpy_out = np.zeros(2)
...@@ -3404,7 +3517,7 @@ class TestExamples: ...@@ -3404,7 +3517,7 @@ class TestExamples:
state = scalar("state") state = scalar("state")
n_steps = iscalar("nsteps") n_steps = iscalar("nsteps")
# Test return_list at the same time. # Test return_list at the same time.
output, updates = scan( output = scan(
f_pow2, f_pow2,
[], [],
state, state,
...@@ -3413,10 +3526,9 @@ class TestExamples: ...@@ -3413,10 +3526,9 @@ class TestExamples:
truncate_gradient=-1, truncate_gradient=-1,
return_list=True, return_list=True,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
my_f = function( my_f = function([state, n_steps], output, allow_input_downcast=True)
[state, n_steps], output, updates=updates, allow_input_downcast=True
)
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
state = rng.uniform() state = rng.uniform()
...@@ -3446,10 +3558,11 @@ class TestExamples: ...@@ -3446,10 +3558,11 @@ class TestExamples:
pre_h = dot(x, W_x) pre_h = dot(x, W_x)
return pre_h return pre_h
value, _scan_updates = scan( value = scan(
_active, _active,
sequences=X, sequences=X,
outputs_info=[pt.alloc(floatx(0.0), 1, out_size)], outputs_info=[pt.alloc(floatx(0.0), 1, out_size)],
return_updates=False,
) )
cost = mean(value) cost = mean(value)
gW_x = grad(cost, W_x) gW_x = grad(cost, W_x)
...@@ -3467,7 +3580,7 @@ class TestExamples: ...@@ -3467,7 +3580,7 @@ class TestExamples:
condition = until(new_value > max_value) condition = until(new_value > max_value)
return [new_value, new_step], condition return [new_value, new_step], condition
rs, _updates = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps) rs = scan(fn=accum, outputs_info=[0, 0], n_steps=n_steps, return_updates=False)
f = function(inputs=[max_value, n_steps], outputs=rs) f = function(inputs=[max_value, n_steps], outputs=rs)
...@@ -3487,33 +3600,37 @@ class TestExamples: ...@@ -3487,33 +3600,37 @@ class TestExamples:
# Generate the components of the polynomial # Generate the components of the polynomial
full_range = pt.arange(max_coefficients_supported) full_range = pt.arange(max_coefficients_supported)
components, _updates = scan( components = scan(
fn=lambda coeff, power, free_var: coeff * (free_var**power), fn=lambda coeff, power, free_var: coeff * (free_var**power),
sequences=[coefficients, full_range], sequences=[coefficients, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
polynomial1 = components.sum() polynomial1 = components.sum()
polynomial2, _updates = scan( polynomial2 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=pt.constant(0, dtype="floatX"), outputs_info=pt.constant(0, dtype="floatX"),
sequences=[coefficients, full_range], sequences=[coefficients, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
# python int # python int
polynomial3, _updates = scan( polynomial3 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=0, outputs_info=0,
sequences=[coefficients, full_range], sequences=[coefficients, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
# python float # python float
polynomial4, _updates = scan( polynomial4 = scan(
fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power), fn=lambda coeff, power, prev, free_var: prev + coeff * (free_var**power),
outputs_info=0.0, outputs_info=0.0,
sequences=[coefficients, full_range], sequences=[coefficients, full_range],
non_sequences=x, non_sequences=x,
return_updates=False,
) )
calculate_polynomial = function( calculate_polynomial = function(
...@@ -3576,8 +3693,12 @@ class TestExamples: ...@@ -3576,8 +3693,12 @@ class TestExamples:
# o = v + 1 # <-- this line works # o = v + 1 # <-- this line works
return o return o
OS, _updates = scan( OS = scan(
fn=one_step, sequences=V, outputs_info=[None], non_sequences=[W] fn=one_step,
sequences=V,
outputs_info=[None],
non_sequences=[W],
return_updates=False,
) )
O = OS.sum() + W.sum() O = OS.sum() + W.sum()
...@@ -3591,11 +3712,12 @@ class TestExamples: ...@@ -3591,11 +3712,12 @@ class TestExamples:
) )
def test_infershape_seq_shorter_nsteps(self): def test_infershape_seq_shorter_nsteps(self):
x = vector("x") x = vector("x")
[o1, o2], _ = scan( [o1, o2] = scan(
lambda x, y: (x + 1, y + x), lambda x, y: (x + 1, y + x),
sequences=x, sequences=x,
outputs_info=[None, x[0]], outputs_info=[None, x[0]],
n_steps=20, n_steps=20,
return_updates=False,
) )
f = function([x], [o1, o2], mode=mode_with_opt) f = function([x], [o1, o2], mode=mode_with_opt)
...@@ -3667,10 +3789,14 @@ class TestExamples: ...@@ -3667,10 +3789,14 @@ class TestExamples:
condition = until(previous_val > 5) condition = until(previous_val > 5)
return new_val, condition return new_val, condition
out, _updates = scan(inner_fct, outputs_info=x, n_steps=10) out, updates = scan(inner_fct, outputs_info=x, n_steps=10)
g_out = grad(out.sum(), x) g_out = grad(out.sum(), x)
fct = function([x], [out, g_out]) fct = function(
[x],
[out, g_out],
updates=updates,
)
for i in range(-5, 5): for i in range(-5, 5):
output, g_output = fct(i) output, g_output = fct(i)
...@@ -3702,7 +3828,7 @@ class TestExamples: ...@@ -3702,7 +3828,7 @@ class TestExamples:
) )
return next_sitsot_val, next_mitsot_val, nitsot_out return next_sitsot_val, next_mitsot_val, nitsot_out
out, _updates = scan( out = scan(
fn=step, fn=step,
sequences=seq, sequences=seq,
outputs_info=[ outputs_info=[
...@@ -3711,6 +3837,7 @@ class TestExamples: ...@@ -3711,6 +3837,7 @@ class TestExamples:
None, None,
], ],
n_steps=5, n_steps=5,
return_updates=False,
) )
f = function([seq, sitsot_init, mitsot_init], out[2].shape) f = function([seq, sitsot_init, mitsot_init], out[2].shape)
...@@ -3746,7 +3873,7 @@ class TestExamples: ...@@ -3746,7 +3873,7 @@ class TestExamples:
dot(x_tm1, W_out), dot(x_tm1, W_out),
] ]
outputs, updates = scan( outputs = scan(
f_rnn_cmpl, f_rnn_cmpl,
[u1, u2], [u1, u2],
[x0, y0], [x0, y0],
...@@ -3754,11 +3881,10 @@ class TestExamples: ...@@ -3754,11 +3881,10 @@ class TestExamples:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f4 = function( f4 = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
# compute the values in numpy # compute the values in numpy
v_x = np.zeros((3, 2), dtype=config.floatX) v_x = np.zeros((3, 2), dtype=config.floatX)
...@@ -3802,7 +3928,7 @@ class TestExamples: ...@@ -3802,7 +3928,7 @@ class TestExamples:
dot(u1_t, W_in1), dot(u1_t, W_in1),
] ]
outputs, updates = scan( outputs = scan(
f_rnn_cmpl, f_rnn_cmpl,
[u1, dict(input=u2, taps=[-1, 0, 1])], [u1, dict(input=u2, taps=[-1, 0, 1])],
[x0, dict(initial=y0, taps=[-1, -3]), None], [x0, dict(initial=y0, taps=[-1, -3]), None],
...@@ -3810,11 +3936,10 @@ class TestExamples: ...@@ -3810,11 +3936,10 @@ class TestExamples:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f = function( f = function([u1, u2, x0, y0, W_in1], outputs, allow_input_downcast=True)
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
ny0 = np.zeros((5, 2)) ny0 = np.zeros((5, 2))
ny1 = np.zeros((5,)) ny1 = np.zeros((5,))
...@@ -3904,13 +4029,14 @@ class TestExamples: ...@@ -3904,13 +4029,14 @@ class TestExamples:
return [h_t, y_t] return [h_t, y_t]
# hidden and outputs of the entire sequence # hidden and outputs of the entire sequence
[_h, y], _ = scan( [_h, y] = scan(
fn=one_step, fn=one_step,
sequences=dict(input=x), sequences=dict(input=x),
# corresponds to the return type of one_step # corresponds to the return type of one_step
outputs_info=[dict(initial=h0, taps=[-2, -1]), None], outputs_info=[dict(initial=h0, taps=[-2, -1]), None],
non_sequences=[W_ih, W_hh, b_h, W_ho, b_o], non_sequences=[W_ih, W_hh, b_h, W_ho, b_o],
mode=mode, mode=mode,
return_updates=False,
) )
# target values # target values
...@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode): ...@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode):
outer-output arrays are initialized using the outer-input arrays, the outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly. shape difference needs to be handled correctly.
""" """
s_in_y, _ = scan( s_in_y = scan(
fn=lambda z: (z + 1, until(z > 2)), fn=lambda z: (z + 1, until(z > 2)),
outputs_info=[ outputs_info=[
{"taps": [-1], "initial": pt.as_tensor(0.0, dtype=np.float64)} {"taps": [-1], "initial": pt.as_tensor(0.0, dtype=np.float64)}
...@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode): ...@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode):
mode=mode, mode=mode,
n_steps=n - 1, n_steps=n - 1,
allow_gc=False, allow_gc=False,
return_updates=False,
) )
return s_in_y.sum() return s_in_y.sum()
s_y, _updates = scan( s_y = scan(
fn=fn, fn=fn,
outputs_info=[None], outputs_info=[None],
sequences=[pt.as_tensor([3, 2, 1], dtype=np.int64)], sequences=[pt.as_tensor([3, 2, 1], dtype=np.int64)],
mode=mode, mode=mode,
allow_gc=False, allow_gc=False,
return_updates=False,
) )
f_cvm = function([], s_y, mode=mode) f_cvm = function([], s_y, mode=mode)
...@@ -4121,14 +4249,14 @@ def test_rng_outputs_info(): ...@@ -4121,14 +4249,14 @@ def test_rng_outputs_info():
).owner.outputs ).owner.outputs
return next_x, next_rng return next_x, next_rng
[xs, rng_final], updates = scan( [xs, rng_final] = scan(
fn=step, fn=step,
outputs_info=[x0, rng_x0], outputs_info=[x0, rng_x0],
n_steps=10, n_steps=10,
return_updates=False,
) )
assert isinstance(xs.type, TensorType) assert isinstance(xs.type, TensorType)
assert isinstance(rng_final.type, RandomGeneratorType) assert isinstance(rng_final.type, RandomGeneratorType)
assert not updates
fn = function([rng_init], [xs, rng_final]) fn = function([rng_init], [xs, rng_final])
xs_eval, rng_final_eval = fn(np.random.default_rng(0)) xs_eval, rng_final_eval = fn(np.random.default_rng(0))
......
...@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan: ...@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan:
"""Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences.""" """Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
W = matrix(name="W") W = matrix(name="W")
v = ivector(name="v") v = ivector(name="v")
y1, _ = scan( y1 = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W] lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W],
return_updates=False,
) )
y2, _ = scan( y2 = scan(
lambda i, _, W: W[i], lambda i, _, W: W[i],
sequences=v, sequences=v,
outputs_info=None, outputs_info=None,
non_sequences=[W[0], W], non_sequences=[W[0], W],
return_updates=False,
) )
y3, _ = scan( y3 = scan(
lambda i, W, _: W[i], lambda i, W, _: W[i],
sequences=v, sequences=v,
outputs_info=None, outputs_info=None,
non_sequences=[W, W[0]], non_sequences=[W, W[0]],
return_updates=False,
) )
y4, _ = scan( y4 = scan(
lambda i, _, _2, W: W[i], lambda i, _, _2, W: W[i],
sequences=v, sequences=v,
outputs_info=None, outputs_info=None,
non_sequences=[W[0], W[0], W], non_sequences=[W[0], W[0], W],
return_updates=False,
) )
y5, _ = scan( y5 = scan(
lambda i, _, W, _2: W[i], lambda i, _, W, _2: W[i],
sequences=v, sequences=v,
outputs_info=None, outputs_info=None,
non_sequences=[W[0], W, W[0]], non_sequences=[W[0], W, W[0]],
return_updates=False,
) )
y6, _ = scan( y6 = scan(
lambda i, W, _, _2: W[i], lambda i, W, _, _2: W[i],
sequences=v, sequences=v,
outputs_info=None, outputs_info=None,
non_sequences=[W, W[0], W[0]], non_sequences=[W, W[0], W[0]],
return_updates=False,
) )
# TODO: y7 have problem during run time. I think it should # TODO: y7 have problem during run time. I think it should
# raise an error during the scan construction. # raise an error during the scan construction.
...@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan: ...@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan:
W = matrix(name="W") W = matrix(name="W")
v = ivector(name="v") v = ivector(name="v")
vv = matrix(name="vv") vv = matrix(name="vv")
y1, _ = scan( y1 = scan(
lambda i, W: W[i], sequences=v, outputs_info=None, non_sequences=[W] lambda i, W: W[i],
sequences=v,
outputs_info=None,
non_sequences=[W],
return_updates=False,
) )
y2, _ = scan( y2 = scan(
lambda i, _, W: W[i], sequences=[v, v], outputs_info=None, non_sequences=W lambda i, _, W: W[i],
sequences=[v, v],
outputs_info=None,
non_sequences=W,
return_updates=False,
) )
y3, _ = scan( y3 = scan(
lambda i, _, W: W[i], lambda i, _, W: W[i],
sequences=[v, vv[0]], sequences=[v, vv[0]],
outputs_info=None, outputs_info=None,
non_sequences=W, non_sequences=W,
return_updates=False,
) )
y4, _ = scan( y4 = scan(
lambda _, i, W: W[i], lambda _, i, W: W[i],
sequences=[vv[0], v], sequences=[vv[0], v],
outputs_info=None, outputs_info=None,
non_sequences=W, non_sequences=W,
return_updates=False,
) )
y5, _ = scan( y5 = scan(
lambda _, i, _2, W: W[i], lambda _, i, _2, W: W[i],
sequences=[vv, v, vv[0]], sequences=[vv, v, vv[0]],
outputs_info=None, outputs_info=None,
non_sequences=W, non_sequences=W,
return_updates=False,
) )
y6, _ = scan( y6 = scan(
lambda _, _2, i, W: W[i], lambda _, _2, i, W: W[i],
sequences=[vv[0], vv, v], sequences=[vv[0], vv, v],
outputs_info=None, outputs_info=None,
non_sequences=W, non_sequences=W,
return_updates=False,
) )
y7, _ = scan( y7 = scan(
lambda i, _, _2, W: W[i], lambda i, _, _2, W: W[i],
sequences=[v, vv[0], vv[0]], sequences=[v, vv[0], vv[0]],
outputs_info=None, outputs_info=None,
non_sequences=W, non_sequences=W,
return_updates=False,
) )
y8, _ = scan( y8 = scan(
lambda _, i, W, _2, _3: W[i], lambda _, i, W, _2, _3: W[i],
sequences=[vv[0], v], sequences=[vv[0], v],
outputs_info=None, outputs_info=None,
non_sequences=[W, W[0], W[0]], non_sequences=[W, W[0], W[0]],
return_updates=False,
) )
W_val = np.random.normal(size=(3, 3)).astype(config.floatX) W_val = np.random.normal(size=(3, 3)).astype(config.floatX)
...@@ -195,7 +218,7 @@ class TestPushOutDot: ...@@ -195,7 +218,7 @@ class TestPushOutDot:
def lambda_fn(h, W1, W2): def lambda_fn(h, W1, W2):
return dot(h, W1 + W2) return dot(h, W1 + W2)
o, _ = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5) o = scan(lambda_fn, non_sequences=[h0, W1, W2], n_steps=5, return_updates=False)
f = function([h0, W1, W2], o, mode=self.mode) f = function([h0, W1, W2], o, mode=self.mode)
...@@ -232,19 +255,24 @@ class TestPushOutDot: ...@@ -232,19 +255,24 @@ class TestPushOutDot:
return dot(W1, W2), until_condition return dot(W1, W2), until_condition
# Compile a function with the optimization # Compile a function with the optimization
o, _ = scan( o = scan(
lambda_fn, sequences=[step_indices, W1], non_sequences=[W2], n_steps=5 lambda_fn,
sequences=[step_indices, W1],
non_sequences=[W2],
n_steps=5,
return_updates=False,
) )
f = function([W1, W2, step_indices], o, mode=self.mode) f = function([W1, W2, step_indices], o, mode=self.mode)
# Compule an pytensor function without the optimization # Compule an pytensor function without the optimization
o, _ = scan( o = scan(
lambda_fn, lambda_fn,
sequences=[step_indices, W1], sequences=[step_indices, W1],
non_sequences=[W2], non_sequences=[W2],
n_steps=5, n_steps=5,
mode="FAST_COMPILE", mode="FAST_COMPILE",
return_updates=False,
) )
f_ref = function([W1, W2, step_indices], o, mode=self.mode) f_ref = function([W1, W2, step_indices], o, mode=self.mode)
...@@ -268,7 +296,13 @@ class TestPushOutDot: ...@@ -268,7 +296,13 @@ class TestPushOutDot:
def lambda_fn(h, W1, W2): def lambda_fn(h, W1, W2):
return dot(h, W1 + W2) return dot(h, W1 + W2)
o, _ = scan(lambda_fn, outputs_info=h0, non_sequences=[W1, W2], n_steps=5) o = scan(
lambda_fn,
outputs_info=h0,
non_sequences=[W1, W2],
n_steps=5,
return_updates=False,
)
f = function([h0, W1, W2], o, mode=self.mode) f = function([h0, W1, W2], o, mode=self.mode)
...@@ -290,10 +324,11 @@ class TestPushOutDot: ...@@ -290,10 +324,11 @@ class TestPushOutDot:
def fn(i, i_tm1): def fn(i, i_tm1):
return i + 10, i_tm1 return i + 10, i_tm1
([i_t, i_tm1], _) = scan( [i_t, i_tm1] = scan(
fn, fn,
sequences=[inp], sequences=[inp],
outputs_info=[np.asarray([0.0, 0.0], config.floatX), None], outputs_info=[np.asarray([0.0, 0.0], config.floatX), None],
return_updates=False,
) )
f = function([inp], [i_t, i_tm1]) f = function([inp], [i_t, i_tm1])
val = np.arange(10).reshape(5, 2).astype(config.floatX) val = np.arange(10).reshape(5, 2).astype(config.floatX)
...@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan: ...@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan:
@config.change_flags(on_opt_error="raise") @config.change_flags(on_opt_error="raise")
def test_pushout_seqs2(self): def test_pushout_seqs2(self):
x = matrix() x = matrix()
outputs, updates = scan( outputs = scan(
lambda x: [x * x, pt.constant(0).copy().copy()], lambda x: [x * x, pt.constant(0).copy().copy()],
n_steps=2, n_steps=2,
sequences=[], sequences=[],
non_sequences=[], non_sequences=[],
outputs_info=[x, None], outputs_info=[x, None],
return_updates=False,
) )
# Compile an PyTensor function where any optimization error will lead to # Compile an PyTensor function where any optimization error will lead to
# an exception being raised # an exception being raised
function([x], outputs, updates=updates) function([x], outputs)
@config.change_flags(on_opt_error="raise") @config.change_flags(on_opt_error="raise")
def test_pushout_nonseq(self): def test_pushout_nonseq(self):
...@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan: ...@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan:
outputs. This led the optimization to raise an exception. outputs. This led the optimization to raise an exception.
""" """
outputs, _ = scan(lambda x: (x * x, x), non_sequences=[2], n_steps=2) outputs = scan(
lambda x: (x * x, x), non_sequences=[2], n_steps=2, return_updates=False
)
f = function(inputs=[], outputs=outputs) f = function(inputs=[], outputs=outputs)
outs = f() outs = f()
...@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan: ...@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan:
test_ofg = OpFromGraph([], [y]) test_ofg = OpFromGraph([], [y])
def inner_func(x): def inner_func(x):
out, _ = pytensor.scan(lambda: test_ofg(), n_steps=x) out = pytensor.scan(lambda: test_ofg(), n_steps=x, return_updates=False)
return out return out
out, _ = pytensor.scan(inner_func, sequences=[pt.arange(1, 2)]) out = pytensor.scan(
inner_func, sequences=[pt.arange(1, 2)], return_updates=False
)
_ = pytensor.function([], test_ofg()) _ = pytensor.function([], test_ofg())
...@@ -612,10 +652,11 @@ class TestPushOutAddScan: ...@@ -612,10 +652,11 @@ class TestPushOutAddScan:
def test_sum_dot(self): def test_sum_dot(self):
A = matrix("A") A = matrix("A")
B = matrix("B") B = matrix("B")
S, _ = scan( S = scan(
lambda x1, x2, u: u + dot(x1, x2), lambda x1, x2, u: u + dot(x1, x2),
sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)], sequences=[A.dimshuffle(0, 1, "x"), B.dimshuffle(0, "x", 1)],
outputs_info=[pt.zeros_like(A)], outputs_info=[pt.zeros_like(A)],
return_updates=False,
) )
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that. # FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]` # They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
...@@ -636,13 +677,17 @@ class TestPushOutAddScan: ...@@ -636,13 +677,17 @@ class TestPushOutAddScan:
bv = pt.zeros((5,)) bv = pt.zeros((5,))
bh = pt.zeros((4,)) bh = pt.zeros((4,))
v = matrix("v") v = matrix("v")
(bv_t, bh_t), _ = scan( (bv_t, bh_t) = scan(
lambda _: [bv, bh], sequences=v, outputs_info=[None, None] lambda _: [bv, bh],
sequences=v,
outputs_info=[None, None],
return_updates=False,
) )
chain, _ = scan( chain = scan(
lambda x: dot(dot(x, W) + bh_t, W.T) + bv_t, lambda x: dot(dot(x, W) + bh_t, W.T) + bv_t,
outputs_info=v, outputs_info=v,
n_steps=2, n_steps=2,
return_updates=False,
) )
# TODO FIXME: Make this a real test and assert something. # TODO FIXME: Make this a real test and assert something.
chain_fn = function([v], chain) chain_fn = function([v], chain)
...@@ -710,26 +755,28 @@ class TestPushOutAddScan: ...@@ -710,26 +755,28 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once # Compile the function twice, once with the optimization and once
# without # without
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
h, _ = pytensor.scan( h = pytensor.scan(
rnn_step1, rnn_step1,
sequences=[x, ri, zi], sequences=[x, ri, zi],
n_steps=seq_len, n_steps=seq_len,
outputs_info=init, outputs_info=init,
name="fpass1", name="fpass1",
mode=opt_mode, mode=opt_mode,
return_updates=False,
) )
cost = h[-1].sum() cost = h[-1].sum()
grad1 = grad(cost, [U, V, W]) grad1 = grad(cost, [U, V, W])
f_opt = pytensor.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode) f_opt = pytensor.function(inputs=[x, ri, zi], outputs=grad1, mode=opt_mode)
no_opt_mode = mode.excluding("scan_pushout_add") no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = pytensor.scan( h = pytensor.scan(
rnn_step1, rnn_step1,
sequences=[x, ri, zi], sequences=[x, ri, zi],
n_steps=seq_len, n_steps=seq_len,
outputs_info=init, outputs_info=init,
name="fpass1", name="fpass1",
mode=no_opt_mode, mode=no_opt_mode,
return_updates=False,
) )
cost = h[-1].sum() cost = h[-1].sum()
grad1 = grad(cost, [U, V, W]) grad1 = grad(cost, [U, V, W])
...@@ -773,21 +820,23 @@ class TestPushOutAddScan: ...@@ -773,21 +820,23 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once without # Compile the function twice, once with the optimization and once without
opt_mode = mode.including("scan") opt_mode = mode.including("scan")
h, _ = pytensor.scan( h = pytensor.scan(
inner_fct, inner_fct,
sequences=[input1, input2, input3], sequences=[input1, input2, input3],
outputs_info=init, outputs_info=init,
mode=opt_mode, mode=opt_mode,
return_updates=False,
) )
output = h[-1] output = h[-1]
f_opt = pytensor.function([input1, input2, input3], output, mode=opt_mode) f_opt = pytensor.function([input1, input2, input3], output, mode=opt_mode)
no_opt_mode = mode.excluding("scan_pushout_add") no_opt_mode = mode.excluding("scan_pushout_add")
h, _ = pytensor.scan( h = pytensor.scan(
inner_fct, inner_fct,
sequences=[input1, input2, input3], sequences=[input1, input2, input3],
outputs_info=init, outputs_info=init,
mode=no_opt_mode, mode=no_opt_mode,
return_updates=False,
) )
output = h[-1] output = h[-1]
f_no_opt = pytensor.function([input1, input2, input3], output, mode=no_opt_mode) f_no_opt = pytensor.function([input1, input2, input3], output, mode=no_opt_mode)
...@@ -892,13 +941,20 @@ class TestScanMerge: ...@@ -892,13 +941,20 @@ class TestScanMerge:
""" """
inps = vector() inps = vector()
state = scalar() state = scalar()
y1, _ = scan(lambda x, y: x * y, sequences=inps, outputs_info=state, n_steps=5) y1 = scan(
lambda x, y: x * y,
sequences=inps,
outputs_info=state,
n_steps=5,
return_updates=False,
)
y2, _ = scan( y2 = scan(
lambda x, y: (x + y, until(x > 0)), lambda x, y: (x + y, until(x > 0)),
sequences=inps, sequences=inps,
outputs_info=state, outputs_info=state,
n_steps=5, n_steps=5,
return_updates=False,
) )
scan_node1 = y1.owner.inputs[0].owner scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, Scan) assert isinstance(scan_node1.op, Scan)
...@@ -958,8 +1014,8 @@ class TestScanMerge: ...@@ -958,8 +1014,8 @@ class TestScanMerge:
def sub(s1, s2, const): def sub(s1, s2, const):
return s1 - 1, until(s2 > const) return s1 - 1, until(s2 > const)
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1]) sy = scan(sub, sequences=[y, -z], non_sequences=[c1], return_updates=False)
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 2 assert self.count_scans(f) == 2
...@@ -972,8 +1028,8 @@ class TestScanMerge: ...@@ -972,8 +1028,8 @@ class TestScanMerge:
np.testing.assert_array_equal(res_sx, [1, 1]) np.testing.assert_array_equal(res_sx, [1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1]) np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2]) sy = scan(sub, sequences=[y, z], non_sequences=[c2], return_updates=False)
f = pytensor.function( f = pytensor.function(
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
...@@ -989,22 +1045,23 @@ class TestScanMerge: ...@@ -989,22 +1045,23 @@ class TestScanMerge:
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1]) np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1]) np.testing.assert_array_equal(res_sy, [-1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1]) sx = scan(add, sequences=[x, z], non_sequences=[c1], return_updates=False)
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1]) sy = scan(sub, sequences=[y, z], non_sequences=[c1], return_updates=False)
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode) f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 1 assert self.count_scans(f) == 1
def nested_scan(c, x, z): def nested_scan(c, x, z):
sx, _ = scan(add, sequences=[x, z], non_sequences=[c]) sx = scan(add, sequences=[x, z], non_sequences=[c], return_updates=False)
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c]) sy = scan(sub, sequences=[x, z], non_sequences=[c], return_updates=False)
return sx.sum() + sy.sum() return sx.sum() + sy.sum()
sz, _ = scan( sz = scan(
nested_scan, nested_scan,
sequences=[stack([c1, c2])], sequences=[stack([c1, c2])],
non_sequences=[x, z], non_sequences=[x, z],
mode=self.mode, mode=self.mode,
return_updates=False,
) )
f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode) f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
...@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer: ...@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer:
x = pt.vector("x") x = pt.vector("x")
scan_out, _ = pytensor.scan( scan_out = pytensor.scan(
lambda x: (x + 1) / 2 + 1, lambda x: (x + 1) / 2 + 1, sequences=[x], return_updates=False
sequences=[x],
) )
fgraph = FunctionGraph( fgraph = FunctionGraph(
...@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer: ...@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer:
assert equal_computations([scan_out], fgraph.outputs) assert equal_computations([scan_out], fgraph.outputs)
def test_inplace_basic(self): def test_inplace_basic(self):
scan_out, _ = pytensor.scan( scan_out = pytensor.scan(
lambda x: x + 1, lambda x: x + 1, outputs_info=[pt.zeros(1)], n_steps=3, return_updates=False
outputs_info=[pt.zeros(1)],
n_steps=3,
) )
fgraph = FunctionGraph( fgraph = FunctionGraph(
...@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer: ...@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer:
u0_t * W_in + x1_tm1 * W + u1_t + u2_t, u0_t * W_in + x1_tm1 * W + u1_t + u2_t,
] ]
outputs, updates = scan( outputs = scan(
f_rnn_shared, f_rnn_shared,
[u0, u1, u2], [u0, u1, u2],
[dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)], [dict(initial=x0, inplace=u2), dict(initial=x1, inplace=u1)],
...@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer: ...@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer:
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
mode=self.mode, mode=self.mode,
return_updates=False,
) )
f9 = function( f9 = function(
[mu0, mu1, mu2, x0, x1], [mu0, mu1, mu2, x0, x1],
outputs, outputs,
updates=updates,
mode=self.mode, mode=self.mode,
allow_input_downcast=True, allow_input_downcast=True,
) )
...@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer: ...@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer:
u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1, u0_t * W_in + x1_tm1 * W + u2_tm1 + u2_t + u2_tp1,
] ]
outputs, updates = scan( outputs = scan(
f_rnn_shared, f_rnn_shared,
[u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])], [u0, dict(input=u1, taps=[0, 1]), dict(input=u2, taps=[-1, 0, +1])],
[dict(initial=x0), dict(initial=x1)], [dict(initial=x0), dict(initial=x1)],
...@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer: ...@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer:
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
mode=self.mode, mode=self.mode,
return_updates=False,
) )
f9 = function( f9 = function(
[mu0, mu1, mu2, x0, x1], [mu0, mu1, mu2, x0, x1],
outputs, outputs,
updates=updates,
mode=self.mode, mode=self.mode,
allow_input_downcast=True, allow_input_downcast=True,
) )
...@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer: ...@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer:
vx1 = asarrayX(rng.uniform()) vx1 = asarrayX(rng.uniform())
x0 = shared(vx0) x0 = shared(vx0)
x1 = shared(vx1) x1 = shared(vx1)
outputs, updates = scan( outputs = scan(
lambda x, y: (x + asarrayX(1), y + asarrayX(1)), [], [x0, x1], n_steps=3 lambda x, y: (x + asarrayX(1), y + asarrayX(1)),
[],
[x0, x1],
n_steps=3,
return_updates=False,
) )
x0 = asarrayX(np.zeros((4,))) x0 = asarrayX(np.zeros((4,)))
x0[0] = vx0 x0[0] = vx0
...@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer: ...@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer:
to_replace = outputs[0].owner.inputs[0].owner.inputs[1] to_replace = outputs[0].owner.inputs[0].owner.inputs[1]
outputs = clone_replace(outputs, replace=[(to_replace, x0)]) outputs = clone_replace(outputs, replace=[(to_replace, x0)])
f9 = function([], outputs, updates=updates, mode=self.mode) f9 = function([], outputs, mode=self.mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)] scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 not in scan_node[0].op.destroy_map assert 0 not in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map assert 1 in scan_node[0].op.destroy_map
...@@ -1249,7 +1307,7 @@ class TestSaveMem: ...@@ -1249,7 +1307,7 @@ class TestSaveMem:
y_tm1 + dot(x_tm1, W_out), y_tm1 + dot(x_tm1, W_out),
] ]
_outputs, updates = scan( outs = scan(
f_rnn_cmpl, f_rnn_cmpl,
[u1, u2], [u1, u2],
[None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])], [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
...@@ -1257,12 +1315,12 @@ class TestSaveMem: ...@@ -1257,12 +1315,12 @@ class TestSaveMem:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
outputs = [_outputs[0][-1], _outputs[1][-1], _outputs[2][-1]] outputs = [outs[0][-1], outs[1][-1], outs[2][-1]]
f4 = function( f4 = function(
[u1, u2, x0, y0, W_in1], [u1, u2, x0, y0, W_in1],
outputs, outputs,
updates=updates,
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode, mode=self.mode,
) )
...@@ -1297,14 +1355,18 @@ class TestSaveMem: ...@@ -1297,14 +1355,18 @@ class TestSaveMem:
u = vector("u") u = vector("u")
idx = iscalar("idx") idx = iscalar("idx")
jdx = iscalar("jdx") jdx = iscalar("jdx")
[x1, x2, x3, x4, x5, x6, x7], updates = scan( [x1, x2, x3, x4, x5, x6, x7] = scan(
f_rnn, u, n_steps=None, truncate_gradient=-1, go_backwards=False f_rnn,
u,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
return_updates=False,
) )
f2 = function( f2 = function(
[u, idx, jdx], [u, idx, jdx],
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]], [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
updates=updates,
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode.excluding("scan_push_out_seq"), mode=self.mode.excluding("scan_push_out_seq"),
) )
...@@ -1341,10 +1403,8 @@ class TestSaveMem: ...@@ -1341,10 +1403,8 @@ class TestSaveMem:
def test_save_mem_reduced_number_of_steps_constant(self): def test_save_mem_reduced_number_of_steps_constant(self):
x0 = pt.scalar("x0") x0 = pt.scalar("x0")
xs, _ = scan( xs = scan(
lambda xtm1: xtm1 + 1, lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=10, return_updates=False
outputs_info=[x0],
n_steps=10,
) )
fn = function([x0], xs[:5], mode=self.mode) fn = function([x0], xs[:5], mode=self.mode)
...@@ -1358,10 +1418,11 @@ class TestSaveMem: ...@@ -1358,10 +1418,11 @@ class TestSaveMem:
def test_save_mem_cannot_reduce_constant_number_of_steps(self): def test_save_mem_cannot_reduce_constant_number_of_steps(self):
x0 = pt.scalar("x0") x0 = pt.scalar("x0")
[xs, ys], _ = scan( [xs, ys] = scan(
lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1), lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1),
outputs_info=[x0, x0], outputs_info=[x0, x0],
n_steps=10, n_steps=10,
return_updates=False,
) )
# Because of ys[-1] we need all the steps! # Because of ys[-1] we need all the steps!
...@@ -1399,7 +1460,7 @@ class TestSaveMem: ...@@ -1399,7 +1460,7 @@ class TestSaveMem:
x20 = scalar("x20") x20 = scalar("x20")
x30 = vector("x30") x30 = vector("x30")
x40 = scalar("x40") x40 = scalar("x40")
[x1, x2, x3, x4, x5, _x6, _x7], updates = scan( [x1, x2, x3, x4, x5, _x6, _x7] = scan(
step, step,
u, u,
[ [
...@@ -1414,12 +1475,12 @@ class TestSaveMem: ...@@ -1414,12 +1475,12 @@ class TestSaveMem:
n_steps=None, n_steps=None,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
f = function( f = function(
[u, x10, x20, x30, x40], [u, x10, x20, x30, x40],
[x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]], [x1[-7], x2[-3:-1], x3[-6:], x4[-1], x5[-1]],
updates=updates,
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode, mode=self.mode,
) )
...@@ -1479,10 +1540,11 @@ class TestSaveMem: ...@@ -1479,10 +1540,11 @@ class TestSaveMem:
def test_savemem_does_not_duplicate_number_of_scan_nodes(self): def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = pt.ones(()) var = pt.ones(())
values, _ = scan( values = scan(
lambda x: ([x], (), until(x)), lambda x: ([x], (), until(x)),
outputs_info=[var], outputs_info=[var],
n_steps=2, n_steps=2,
return_updates=False,
) )
tmp_fn = function([var], values, mode=self.mode) tmp_fn = function([var], values, mode=self.mode)
...@@ -1493,10 +1555,11 @@ class TestSaveMem: ...@@ -1493,10 +1555,11 @@ class TestSaveMem:
def test_savemem_opt(self, benchmark): def test_savemem_opt(self, benchmark):
y0 = shared(np.ones((2, 10))) y0 = shared(np.ones((2, 10)))
[_y1, y2], _updates = scan( [_y1, y2] = scan(
lambda y: [y, y], lambda y: [y, y],
outputs_info=[dict(initial=y0, taps=[-2]), None], outputs_info=[dict(initial=y0, taps=[-2]), None],
n_steps=5, n_steps=5,
return_updates=False,
) )
# TODO FIXME: Make this a real test and assert something. # TODO FIXME: Make this a real test and assert something.
fn = function([], y2.sum(), mode=self.mode) fn = function([], y2.sum(), mode=self.mode)
...@@ -1515,23 +1578,25 @@ class TestSaveMem: ...@@ -1515,23 +1578,25 @@ class TestSaveMem:
return dot(h_tm1, w) + x_t_t return dot(h_tm1, w) + x_t_t
def outer_scan_step(x_t, w): def outer_scan_step(x_t, w):
h, _ = scan( h = scan(
inner_scan_step, inner_scan_step,
sequences=[x_t[1:]], sequences=[x_t[1:]],
outputs_info=[x_t[0]], outputs_info=[x_t[0]],
non_sequences=[w], non_sequences=[w],
strict=True, strict=True,
name="the_inner_scan", name="the_inner_scan",
return_updates=False,
) )
return h return h
def get_outputs(x, w): def get_outputs(x, w):
features, _ = scan( features = scan(
outer_scan_step, outer_scan_step,
sequences=[x], sequences=[x],
non_sequences=[w], non_sequences=[w],
strict=True, strict=True,
name="the_outer_scan", name="the_outer_scan",
return_updates=False,
) )
return_val = grad(features.sum(), w) return_val = grad(features.sum(), w)
...@@ -1571,7 +1636,7 @@ class TestSaveMem: ...@@ -1571,7 +1636,7 @@ class TestSaveMem:
state = vector("state") state = vector("state")
n_steps = iscalar("nsteps") n_steps = iscalar("nsteps")
output, updates = scan( output = scan(
f_pow2, f_pow2,
[], [],
state, state,
...@@ -1579,13 +1644,13 @@ class TestSaveMem: ...@@ -1579,13 +1644,13 @@ class TestSaveMem:
n_steps=n_steps, n_steps=n_steps,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False, go_backwards=False,
return_updates=False,
) )
nw_shape = ivector("nw_shape") nw_shape = ivector("nw_shape")
# Note that the output is reshaped to 3 dimensional tensor, and # Note that the output is reshaped to 3 dimensional tensor, and
my_f = function( my_f = function(
[state, n_steps, nw_shape], [state, n_steps, nw_shape],
[reshape(output, nw_shape, ndim=3)[:-2], output[:-4]], [reshape(output, nw_shape, ndim=3)[:-2], output[:-4]],
updates=updates,
allow_input_downcast=True, allow_input_downcast=True,
) )
nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)] nodes = [x for x in my_f.maker.fgraph.toposort() if isinstance(x.op, Scan)]
...@@ -1599,11 +1664,12 @@ class TestSaveMem: ...@@ -1599,11 +1664,12 @@ class TestSaveMem:
n_steps = scalar("n_steps", dtype="int64") n_steps = scalar("n_steps", dtype="int64")
x0 = vector("x0") x0 = vector("x0")
ys, _ = pytensor.scan( ys = pytensor.scan(
# Fibonacci Sequence # Fibonacci Sequence
lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)), lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
outputs_info=[{"initial": x0, "taps": [-2, -1]}], outputs_info=[{"initial": x0, "taps": [-2, -1]}],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
# Save memory is triggered by choosing only last value # Save memory is triggered by choosing only last value
y = ys[-1] y = ys[-1]
...@@ -1629,10 +1695,11 @@ class TestSaveMem: ...@@ -1629,10 +1695,11 @@ class TestSaveMem:
def test_while_scan_map(self): def test_while_scan_map(self):
xs = vector("xs") xs = vector("xs")
ys, _ = pytensor.scan( ys = pytensor.scan(
lambda x: (x + 1, {}, until(x + 1 >= 10)), lambda x: (x + 1, {}, until(x + 1 >= 10)),
outputs_info=[None], outputs_info=[None],
sequences=[xs], sequences=[xs],
return_updates=False,
) )
# Save memory is triggered by choosing only last value # Save memory is triggered by choosing only last value
y = ys[-1] y = ys[-1]
...@@ -1656,11 +1723,12 @@ class TestSaveMem: ...@@ -1656,11 +1723,12 @@ class TestSaveMem:
n_steps = scalar("n_steps", dtype="int64") n_steps = scalar("n_steps", dtype="int64")
# while loop # while loop
[ys, zs], _ = pytensor.scan( [ys, zs] = pytensor.scan(
lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)), lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
sequences=[seq], sequences=[seq],
outputs_info=[x0, None], outputs_info=[x0, None],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
# Save memory is triggered by choosing only last value # Save memory is triggered by choosing only last value
y = ys[-1] y = ys[-1]
...@@ -1696,10 +1764,11 @@ class TestSaveMem: ...@@ -1696,10 +1764,11 @@ class TestSaveMem:
val_test = np.zeros(val_shape, dtype=val.dtype) val_test = np.zeros(val_shape, dtype=val.dtype)
init = pt.full((2,), val) init = pt.full((2,), val)
ys, _ = pytensor.scan( ys = pytensor.scan(
fn=lambda *args: pt.add(*args), fn=lambda *args: pt.add(*args),
outputs_info=[{"initial": init, "taps": (-2, -1)}], outputs_info=[{"initial": init, "taps": (-2, -1)}],
n_steps=100, n_steps=100,
return_updates=False,
) )
out = ys[:-50] if keep_beginning else ys[-50:] out = ys[:-50] if keep_beginning else ys[-50:]
...@@ -1729,12 +1798,13 @@ def test_inner_replace_dot(): ...@@ -1729,12 +1798,13 @@ def test_inner_replace_dot():
mode = get_default_mode().including("scan") # .excluding("BlasOpt") mode = get_default_mode().including("scan") # .excluding("BlasOpt")
o, _ = scan( o = scan(
lambda hi, him1, W: (hi, dot(hi + him1, W)), lambda hi, him1, W: (hi, dot(hi + him1, W)),
outputs_info=[pt.zeros([h.shape[1]]), None], outputs_info=[pt.zeros([h.shape[1]]), None],
sequences=[h], sequences=[h],
non_sequences=[W], non_sequences=[W],
mode=mode, mode=mode,
return_updates=False,
) )
f = function([W, h], o, mode=mode) f = function([W, h], o, mode=mode)
...@@ -1753,11 +1823,12 @@ def test_alloc_inputs1(): ...@@ -1753,11 +1823,12 @@ def test_alloc_inputs1():
def lambda_fn(h, W1, W2): def lambda_fn(h, W1, W2):
return dot(h, W1 * W2) return dot(h, W1 * W2)
o, _ = scan( o = scan(
lambda_fn, lambda_fn,
outputs_info=h0, outputs_info=h0,
non_sequences=[W1, pt.zeros_like(W2)], non_sequences=[W1, pt.zeros_like(W2)],
n_steps=5, n_steps=5,
return_updates=False,
) )
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
...@@ -1786,12 +1857,13 @@ def test_alloc_inputs2(): ...@@ -1786,12 +1857,13 @@ def test_alloc_inputs2():
def lambda_fn(W1, h, W2): def lambda_fn(W1, h, W2):
return W1 * dot(h, W2) return W1 * dot(h, W2)
o, _ = scan( o = scan(
lambda_fn, lambda_fn,
sequences=pt.zeros_like(W1), sequences=pt.zeros_like(W1),
outputs_info=h0, outputs_info=h0,
non_sequences=[pt.zeros_like(W2)], non_sequences=[pt.zeros_like(W2)],
n_steps=5, n_steps=5,
return_updates=False,
) )
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
...@@ -1821,12 +1893,13 @@ def test_alloc_inputs3(): ...@@ -1821,12 +1893,13 @@ def test_alloc_inputs3():
def lambda_fn(W1, h, W2): def lambda_fn(W1, h, W2):
return W1 * dot(h, W2) return W1 * dot(h, W2)
o, _ = scan( o = scan(
lambda_fn, lambda_fn,
sequences=pt.zeros_like(W1), sequences=pt.zeros_like(W1),
outputs_info=h0, outputs_info=h0,
non_sequences=[pt.zeros_like(W2)], non_sequences=[pt.zeros_like(W2)],
n_steps=5, n_steps=5,
return_updates=False,
) )
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode. # TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
...@@ -1848,7 +1921,7 @@ def test_opt_order(): ...@@ -1848,7 +1921,7 @@ def test_opt_order():
x = matrix("x") x = matrix("x")
A = matrix("A") A = matrix("A")
z, _updates = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2) z = scan(dot, sequences=[], non_sequences=[x, A], n_steps=2, return_updates=False)
f = function([x, A], z, mode="FAST_RUN") f = function([x, A], z, mode="FAST_RUN")
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
......
...@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): ...@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A = tensor("A", shape=(3, 3)) A = tensor("A", shape=(3, 3))
x0 = tensor("b", shape=(3, 4)) x0 = tensor("b", shape=(3, 4))
xs, _ = scan( xs = scan(
lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed), lambda xtm1, A: solve(A, xtm1, assume_a=assume_a, transposed=transposed),
outputs_info=[x0], outputs_info=[x0],
non_sequences=[A], non_sequences=[A],
n_steps=10, n_steps=10,
return_updates=False,
) )
fn_no_opt = function( fn_no_opt = function(
......
...@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type(): ...@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def test_scan_gradient_core_type(): def test_scan_gradient_core_type():
n_steps = 3 n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64") seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan( out = scan(
lambda s: s, lambda s: s,
sequences=[seq], sequences=[seq],
n_steps=n_steps, n_steps=n_steps,
return_updates=False,
) )
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64") vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论