Commit 78293400 authored by ricardoV94, committed by Ricardo Vieira

Start deprecating shared updates API in Scan

Using DeprecationWarning to keep it visible only to devs for now
Parent 1d19c375
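
For orientation, a minimal sketch of the call-site change this commit enables (variable names are illustrative, not taken from the diff):

    import pytensor
    import pytensor.tensor as pt

    xs = pt.vector("xs")

    # Legacy signature: scan always returns an (outputs, updates) pair, and
    # after this commit doing so emits a DeprecationWarning.
    ys, updates = pytensor.scan(lambda x: x + 1, sequences=[xs])

    # New opt-in signature: only the outputs are returned; scan instead raises
    # a ValueError if the inner function produced shared-variable updates.
    ys = pytensor.scan(lambda x: x + 1, sequences=[xs], return_updates=False)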
@@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
         # It is possible that the inputs are disconnected from expr,
         # even if they are connected to cost.
         # This should not be an error.
-        hess, updates = pytensor.scan(
+        hess = pytensor.scan(
             lambda i, y, x: grad(
                 y[i],
                 x,
@@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
             ),
             sequences=pytensor.tensor.arange(expr.shape[0]),
             non_sequences=[expr, input],
+            return_updates=False,
         )
-        assert not updates, (
-            "Scan has returned a list of updates; this should not happen."
-        )
         hessians.append(hess)
     return as_list_or_tuple(using_list, using_tuple, hessians)
......
@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
     return isNone or isNaN or isInf or isStr
 
 
+def _manage_output_api_change(outputs, updates, return_updates):
+    if return_updates:
+        warnings.warn(
+            "Scan return signature will change. Updates dict will not be returned, only the first argument. "
+            "Pass `return_updates=False` to conform to the new API and avoid this warning",
+            DeprecationWarning,
+            # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
+            stacklevel=3,
+        )
+    else:
+        if updates:
+            raise ValueError(
+                f"return_updates=False but Scan produced updates {updates}. "
+                "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
+            )
+        return outputs
+    return outputs, updates
+
+
 def scan(
     fn,
     sequences=None,
@@ -182,6 +202,7 @@
     allow_gc=None,
     strict=False,
     return_list=False,
+    return_updates: bool = True,
 ):
     r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
@@ -900,7 +921,7 @@
         if not return_list and len(outputs) == 1:
             outputs = outputs[0]
 
-        return (outputs, updates)
+        return _manage_output_api_change(outputs, updates, return_updates)
 
     ##
     # Step 4. Compile the dummy function
@@ -919,6 +940,8 @@
     fake_outputs = clone_replace(
         outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
     )
+    # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
+    # to find implicit inputs in a way that reduces the size of the inner function
     known_inputs = [*args, *fake_nonseqs]
     extra_inputs = [
         x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -1074,7 +1097,7 @@
         if not isinstance(arg, SharedVariable | Constant)
     ]
-    inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))
+    inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))  # type: ignore[arg-type]
 
     if strict:
         non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1123,7 +1146,7 @@
     if condition is not None:
         inner_outs.append(condition)
 
-    new_outs = clone_replace(inner_outs, replace=inner_replacements)
+    new_outs = clone_replace(inner_outs, replace=inner_replacements)  # type: ignore[arg-type]
 
     ##
     # Step 7. Create the Scan Op
@@ -1211,12 +1234,14 @@
     offset += n_nit_sot
 
-    # Support for explicit untraced sit_sot
+    # Legacy support for explicit untraced sit_sot and those built with update dictionary
+    # Switch to n_untraced_sit_sot_outs after deprecation period
     n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
     untraced_sit_sot_outs = scan_outs[
         offset : offset + n_explicit_untraced_sit_sot_outs
     ]
 
+    # Legacy support: map shared outputs to their updates
     offset += n_explicit_untraced_sit_sot_outs
     for idx, update_rule in enumerate(scan_outs[offset:]):
         update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
@@ -1245,8 +1270,8 @@
         update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
 
     scan_out_list = [x for x in scan_out_list if x is not None]
     if not return_list and len(scan_out_list) == 1:
-        scan_out_list = scan_out_list[0]
+        scan_out_list = scan_out_list[0]  # type: ignore[assignment]
     elif len(scan_out_list) == 0:
-        scan_out_list = None
-    return scan_out_list, update_map
+        scan_out_list = None  # type: ignore[assignment]
+    return _manage_output_api_change(scan_out_list, update_map, return_updates)
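
The ValueError raised by the helper also names the migration path: recurrences previously expressed through shared-variable updates become explicit recurrent outputs seeded by outputs_info. A hedged before/after sketch, mirroring the new test added below:

    import numpy as np
    import pytensor
    from pytensor import shared

    x = shared(np.array(0.0))

    # Before: the recurrence lives in an updates dict that must be forwarded to
    # pytensor.function; scan traces no output here (first element is None).
    _, updates = pytensor.scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5)
    step = pytensor.function([], [], updates=updates)

    # After: the state is an explicit sit-sot output initialized from x; no
    # updates dict is produced, so return_updates=False is safe.
    xs = pytensor.scan(
        lambda acc: acc + 1, outputs_info=[x], n_steps=5, return_updates=False
    )
    final = xs[-1]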
@@ -13,6 +13,7 @@ def scan_checkpoints(
     n_steps=None,
     save_every_N=10,
     padding=True,
+    return_updates=True,
 ):
     """Scan function that uses less memory, but is more restrictive.
@@ -157,24 +158,28 @@
         ] * len(new_nitsots)
 
         # Call the user-provided function with the proper arguments
-        results, updates = scan(
+        results_and_updates = scan(
             fn=fn,
             sequences=i_sequences[:-1],
             outputs_info=i_outputs_infos,
             non_sequences=i_non_sequences,
             name=name + "_inner",
             n_steps=i_sequences[-1],
+            return_updates=return_updates,
         )
+        if return_updates:
+            results, updates = results_and_updates
+        else:
+            results = results_and_updates
+            updates = {}
+        if not isinstance(results, list):
+            results = [results]
 
         # Keep only the last timestep of every output but keep all the updates
-        if not isinstance(results, list):
-            return results[-1], updates
-        else:
-            return [r[-1] for r in results], updates
+        return [r[-1] for r in results], updates
 
-    results, updates = scan(
+    return scan(
         fn=outer_step,
         sequences=o_sequences,
         outputs_info=outputs_info,
@@ -182,6 +187,5 @@
         name=name + "_outer",
         n_steps=o_n_steps,
         allow_gc=True,
+        return_updates=return_updates,
     )
-
-    return results, updates
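
With the flag threaded through scan_checkpoints, callers can opt into the single-return form there too. A sketch modeled on the test setup further down (import path assumed from the module being patched):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.scan.checkpoints import scan_checkpoints

    A = pt.vector("A")
    k = pt.iscalar("k")
    seq = pt.arange(k, dtype="float32") + 1

    # Checkpointed cumulative-product recurrence; only the outputs come back.
    result = scan_checkpoints(
        fn=lambda s, prior_result, A: prior_result * A / s,
        sequences=[seq],
        outputs_info=pt.ones_like(A),
        non_sequences=A,
        n_steps=k,
        save_every_N=100,
        return_updates=False,
    )
    f = pytensor.function([A, k], result[-1])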
@@ -16,6 +16,7 @@ def map(
     go_backwards=False,
     mode=None,
     name=None,
+    return_updates=True,
 ):
     """Construct a `Scan` `Op` that functions like `map`.
@@ -50,6 +51,7 @@
         go_backwards=go_backwards,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
@@ -61,6 +63,7 @@ def reduce(
     go_backwards=False,
     mode=None,
     name=None,
+    return_updates=True,
 ):
     """Construct a `Scan` `Op` that functions like `reduce`.
@@ -97,14 +100,29 @@
         truncate_gradient=-1,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
-    if isinstance(rval[0], list | tuple):
-        return [x[-1] for x in rval[0]], rval[1]
-    else:
-        return rval[0][-1], rval[1]
+    if return_updates:
+        if isinstance(rval[0], list | tuple):
+            return [x[-1] for x in rval[0]], rval[1]
+        else:
+            return rval[0][-1], rval[1]
+    else:
+        if isinstance(rval, list | tuple):
+            return [x[-1] for x in rval]
+        else:
+            return rval[-1]
 
 
-def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
+def foldl(
+    fn,
+    sequences,
+    outputs_info,
+    non_sequences=None,
+    mode=None,
+    name=None,
+    return_updates=True,
+):
     """Construct a `Scan` `Op` that functions like Haskell's `foldl`.
 
     Parameters
@@ -135,10 +153,19 @@ def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
         go_backwards=False,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
 
 
-def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
+def foldr(
+    fn,
+    sequences,
+    outputs_info,
+    non_sequences=None,
+    mode=None,
+    name=None,
+    return_updates=True,
+):
     """Construct a `Scan` `Op` that functions like Haskell's `foldr`.
 
     Parameters
@@ -169,4 +196,5 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None)
         go_backwards=True,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
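
The view helpers simply forward the flag to scan, so map, reduce, and the folds share the same convention. A brief sketch (import path assumed from this module):

    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.scan.views import map as pt_map, reduce as pt_reduce

    v = pt.vector("v")
    s = pt.scalar("s")

    # reduce: only the final accumulated value is returned.
    total = pt_reduce(lambda x, acc: acc + x, v, s, return_updates=False)

    # map: only the mapped sequence is returned.
    abs_v = pt_map(lambda x: abs(x), v, return_updates=False)

    f = function([v, s], [total, abs_v], allow_input_downcast=True)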
@@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
 def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
-    [_, parts], _ = scan(
+    [_, parts] = scan(
         inner_func,
         non_sequences=[array, array_flipped],
         outputs_info=[0, None],
         n_steps=repeats,
+        return_updates=False,
     )
     parts = moveaxis(parts, 0, axis)
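
Because return_updates=False returns only the outputs, a multi-output scan unpacks directly, which is what the pad change above relies on. A hedged miniature of the same pattern:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    xs = pt.dvector("xs")

    # Two traced outputs per step: a running sum (sit-sot, seeded via
    # outputs_info) and the squared element (nit-sot). With no updates dict in
    # the return value, the pair unpacks directly.
    sums, squares = pytensor.scan(
        lambda x, acc: (acc + x, x**2),
        sequences=[xs],
        outputs_info=[np.float64(0.0), None],
        return_updates=False,
    )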
......
@@ -27,7 +27,7 @@ from pytensor.compile.monitormode import MonitorMode
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.replace import vectorize_graph
@@ -67,6 +67,7 @@ from pytensor.tensor.type import (
     vector,
 )
 from tests import unittest_tools as utt
+from tests.unittest_tools import assert_equal_computations
 
 
 if config.mode == "FAST_COMPILE":
@@ -4139,3 +4140,43 @@ def test_rng_outputs_info():
         xs_ref.append(rng_ref.normal(xs_ref[-1]))
     assert random_generator_type.values_eq(rng_ref, rng_final_eval)
     np.testing.assert_allclose(xs_eval, xs_ref[1:])
+
+
+@pytest.mark.filterwarnings("error")
+def test_return_updates_api_change():
+    err_msg = "return_updates=False but Scan produced updates"
+    warn_msg = "Scan return signature will change. Updates dict will not be returned"
+
+    x = shared(np.array(0, dtype="float64"))
+
+    with pytest.warns(DeprecationWarning, match=warn_msg):
+        traced1, updates1 = scan(
+            lambda: {x: x + 1},
+            outputs_info=[],
+            n_steps=5,
+        )
+    assert traced1 is None
+    assert len(updates1) == 1 and x in updates1
+
+    with pytest.warns(DeprecationWarning, match=warn_msg):
+        traced2, updates2 = scan(
+            lambda x: x + 1,
+            outputs_info=[x],
+            n_steps=5,
+        )
+    assert isinstance(traced2, Variable)
+    assert isinstance(updates2, dict) and not updates2
+
+    traced3 = scan(
+        lambda x: x + 1,
+        outputs_info=[x],
+        n_steps=5,
+        return_updates=False,
+    )
+    assert isinstance(traced3, Variable)
+
+    assert_equal_computations(list(updates1.values()), [traced2[-1]])
+    assert_equal_computations([traced2], [traced3])
+
+    with pytest.raises(ValueError, match=err_msg):
+        scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
@@ -9,44 +9,53 @@ from pytensor.tensor.basic import arange, ones_like
 from pytensor.tensor.type import iscalar, vector
 
 
+@pytest.mark.parametrize("return_updates", [True, False])
 class TestScanCheckpoint:
-    def setup_method(self):
+    def setup_method(self, return_updates):
         self.k = iscalar("k")
         self.A = vector("A")
         seq = arange(self.k, dtype="float32") + 1
-        result, _ = scan(
+        result_raw = scan(
             fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
             sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
+            return_updates=return_updates,
         )
-        result_check, _ = scan_checkpoints(
+        result_check_raw = scan_checkpoints(
             fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
             sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
             save_every_N=100,
+            return_updates=return_updates,
         )
+        if return_updates:
+            result, _ = result_raw
+            result_check, _ = result_check_raw
+        else:
+            result = result_raw
+            result_check = result_check_raw
         self.result = result[-1]
         self.result_check = result_check[-1]
         self.grad_A = grad(self.result.sum(), self.A)
         self.grad_A_check = grad(self.result_check.sum(), self.A)
 
-    def test_forward_pass(self):
+    def test_forward_pass(self, return_updates):
         # Test forward computation of A**k.
         f = function(inputs=[self.A, self.k], outputs=[self.result, self.result_check])
         out, out_check = f(range(10), 101)
         assert np.allclose(out, out_check)
 
-    def test_backward_pass(self):
+    def test_backward_pass(self, return_updates):
         # Test gradient computation of A**k.
         f = function(inputs=[self.A, self.k], outputs=[self.grad_A, self.grad_A_check])
         out, out_check = f(range(10), 101)
         assert np.allclose(out, out_check)
 
-    def test_taps_error(self):
+    def test_taps_error(self, return_updates):
         # Test that an error rises if we use taps in outputs_info.
         with pytest.raises(RuntimeError):
             scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]})
 import numpy as np
+import pytest
 
 import pytensor.tensor as pt
 from pytensor import config, function, grad, shared
@@ -11,24 +12,41 @@ from tests import unittest_tools as utt
 from tests.scan.test_basic import clone_optimized_graph, grab_scan_node
 
 
-def test_reduce():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_reduce(return_updates):
     v = vector("v")
     s = scalar("s")
-    result, updates = pt_reduce(lambda x, y: x + y, v, s)
+    result_raw = pt_reduce(lambda x, y: x + y, v, s, return_updates=return_updates)
+    if return_updates:
+        result, updates = result_raw
+        assert not updates
+    else:
+        result = result_raw
 
-    f = function([v, s], result, updates=updates, allow_input_downcast=True)
+    f = function([v, s], result, allow_input_downcast=True)
     rng = np.random.default_rng(utt.fetch_seed())
     v_v = rng.uniform(-5.0, 5.0, size=(5,))
     assert abs(np.sum(v_v) - f(v_v, 0.0)) < 1e-3
 
 
-def test_map():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_map(return_updates):
     v = vector("v")
-    abs_expr, abs_updates = pt_map(
-        lambda x: abs(x), v, [], truncate_gradient=-1, go_backwards=False
-    )
+    abs_expr_raw = pt_map(
+        lambda x: abs(x),
+        v,
+        [],
+        truncate_gradient=-1,
+        go_backwards=False,
+        return_updates=return_updates,
+    )
+    if return_updates:
+        abs_expr, abs_updates = abs_expr_raw
+        assert not abs_updates
+    else:
+        abs_expr = abs_expr_raw
 
-    f = function([v], abs_expr, updates=abs_updates, allow_input_downcast=True)
+    f = function([v], abs_expr, allow_input_downcast=True)
     rng = np.random.default_rng(utt.fetch_seed())
     vals = rng.uniform(-5.0, 5.0, size=(10,))
@@ -39,10 +57,11 @@ def test_map():
 def test_reduce_memory_consumption():
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = pt_reduce(
+    o = pt_reduce(
         lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=False,
     )
     mode = FAST_RUN
     mode = mode.excluding("inplace")
@@ -69,13 +88,20 @@ def test_reduce_memory_consumption():
     utt.assert_allclose(f2(), np.ones((10,)))
 
 
-def test_foldl_memory_consumption():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_foldl_memory_consumption(return_updates):
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = foldl(
+    o_raw = foldl(
         lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=return_updates,
     )
+    if return_updates:
+        o, updates = o_raw
+        assert not updates
+    else:
+        o = o_raw
     mode = FAST_RUN
     mode = mode.excluding("inplace")
@@ -102,13 +128,20 @@ def test_foldl_memory_consumption():
     utt.assert_allclose(f2(), np.ones((10,)))
 
 
-def test_foldr_memory_consumption():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_foldr_memory_consumption(return_updates):
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = foldr(
+    o_raw = foldr(
        lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=return_updates,
     )
+    if return_updates:
+        o, updates = o_raw
+        assert not updates
+    else:
+        o = o_raw
     mode = FAST_RUN
     mode = mode.excluding("inplace")
......