Commit 78293400 authored by ricardoV94, committed by Ricardo Vieira

Start deprecating shared updates API in Scan

Using DeprecationWarning to keep it visible only for devs for now
Parent 1d19c375
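For context, a minimal sketch of the call-site change this commit introduces (variable names are illustrative, not from the diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    xs = pt.vector("xs")

    # Old API: scan always returns an (outputs, updates) pair, even when the
    # step function produces no shared-variable updates. With this commit the
    # two-element return now emits a DeprecationWarning.
    ys, updates = pytensor.scan(lambda x: 2 * x, sequences=[xs])

    # New API: opt in with return_updates=False and receive only the outputs.
    ys = pytensor.scan(lambda x: 2 * x, sequences=[xs], return_updates=False)

    f = pytensor.function([xs], ys)
    print(f(np.arange(3, dtype=xs.dtype)))  # [0. 2. 4.]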
@@ -2188,7 +2188,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
         # It is possible that the inputs are disconnected from expr,
         # even if they are connected to cost.
         # This should not be an error.
-        hess, updates = pytensor.scan(
+        hess = pytensor.scan(
             lambda i, y, x: grad(
                 y[i],
                 x,
@@ -2197,9 +2197,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
             ),
             sequences=pytensor.tensor.arange(expr.shape[0]),
             non_sequences=[expr, input],
-        )
-        assert not updates, (
-            "Scan has returned a list of updates; this should not happen."
+            return_updates=False,
         )
         hessians.append(hess)
     return as_list_or_tuple(using_list, using_tuple, hessians)
...
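The hessian change above shows the intended migration for pure step functions: instead of unpacking and asserting an empty updates dict, pass return_updates=False. A sketch of the same pattern (names are illustrative):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # The step function is pure, so there is no updates dict to discard:
    squares = pytensor.scan(
        lambda i, v: v[i] ** 2,
        sequences=[pt.arange(x.shape[0])],
        non_sequences=[x],
        return_updates=False,
    )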
@@ -168,6 +168,26 @@ def isNaN_or_Inf_or_None(x):
     return isNone or isNaN or isInf or isStr
 
 
+def _manage_output_api_change(outputs, updates, return_updates):
+    if return_updates:
+        warnings.warn(
+            "Scan return signature will change. Updates dict will not be returned, only the first argument. "
+            "Pass `return_updates=False` to conform to the new API and avoid this warning",
+            DeprecationWarning,
+            # Only meant for developers for now. Switch to FutureWarning to warn users, before removing.
+            stacklevel=3,
+        )
+    else:
+        if updates:
+            raise ValueError(
+                f"return_updates=False but Scan produced updates {updates}. "
+                "Make sure to use outputs_info to handle all recurrent states, and not rely on shared variable updates."
+            )
+        return outputs
+    return outputs, updates
+
+
 def scan(
     fn,
     sequences=None,
@@ -182,6 +202,7 @@ def scan(
     allow_gc=None,
     strict=False,
     return_list=False,
+    return_updates: bool = True,
 ):
     r"""This function constructs and applies a `Scan` `Op` to the provided arguments.
@@ -900,7 +921,7 @@ def scan(
         if not return_list and len(outputs) == 1:
             outputs = outputs[0]
 
-        return (outputs, updates)
+        return _manage_output_api_change(outputs, updates, return_updates)
 
     ##
     # Step 4. Compile the dummy function
@@ -919,6 +940,8 @@ def scan(
     fake_outputs = clone_replace(
         outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
     )
+    # TODO: Once we don't treat shared variables specially we should use `truncated_graph_inputs`
+    # to find implicit inputs in a way that reduces the size of the inner function
     known_inputs = [*args, *fake_nonseqs]
     extra_inputs = [
         x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
@@ -1074,7 +1097,7 @@ def scan(
         if not isinstance(arg, SharedVariable | Constant)
     ]
 
-    inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))
+    inner_replacements.update(dict(zip(other_scan_args, other_inner_args, strict=True)))  # type: ignore[arg-type]
 
     if strict:
         non_seqs_set = set(non_sequences if non_sequences is not None else [])
@@ -1123,7 +1146,7 @@ def scan(
     if condition is not None:
         inner_outs.append(condition)
 
-    new_outs = clone_replace(inner_outs, replace=inner_replacements)
+    new_outs = clone_replace(inner_outs, replace=inner_replacements)  # type: ignore[arg-type]
 
     ##
     # Step 7. Create the Scan Op
@@ -1211,12 +1234,14 @@ def scan(
 
     offset += n_nit_sot
-    # Support for explicit untraced sit_sot
+    # Legacy support for explicit untraced sit_sot and those built with update dictionary
+    # Switch to n_untraced_sit_sot_outs after deprecation period
     n_explicit_untraced_sit_sot_outs = len(untraced_sit_sot_rightOrder)
     untraced_sit_sot_outs = scan_outs[
         offset : offset + n_explicit_untraced_sit_sot_outs
     ]
 
+    # Legacy support: map shared outputs to their updates
     offset += n_explicit_untraced_sit_sot_outs
     for idx, update_rule in enumerate(scan_outs[offset:]):
         update_map[untraced_sit_sot_scan_inputs[idx]] = update_rule
@@ -1245,8 +1270,8 @@ def scan(
             update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
     scan_out_list = [x for x in scan_out_list if x is not None]
     if not return_list and len(scan_out_list) == 1:
-        scan_out_list = scan_out_list[0]
+        scan_out_list = scan_out_list[0]  # type: ignore[assignment]
     elif len(scan_out_list) == 0:
-        scan_out_list = None
+        scan_out_list = None  # type: ignore[assignment]
 
-    return scan_out_list, update_map
+    return _manage_output_api_change(scan_out_list, update_map, return_updates)
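The ValueError branch matters for callers that still drive recurrences through shared-variable updates. A sketch of the migration the error message asks for (the counter mirrors the new test further down):

    import numpy as np
    import pytensor
    from pytensor import shared

    x = shared(np.array(0.0))

    # Old style: the state lives in the updates dict; outputs is None here,
    # and this call now emits the DeprecationWarning.
    out, updates = pytensor.scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5)

    # New style: thread the state through outputs_info; the same step function
    # under return_updates=False would otherwise raise the ValueError above.
    xs = pytensor.scan(
        lambda x_tm1: x_tm1 + 1,
        outputs_info=[x],
        n_steps=5,
        return_updates=False,
    )
    final = xs[-1]  # equivalent to updates[x] in the old style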
@@ -13,6 +13,7 @@ def scan_checkpoints(
     n_steps=None,
     save_every_N=10,
     padding=True,
+    return_updates=True,
 ):
     """Scan function that uses less memory, but is more restrictive.
@@ -157,24 +158,28 @@ def scan_checkpoints(
         ] * len(new_nitsots)
 
         # Call the user-provided function with the proper arguments
-        results, updates = scan(
+        results_and_updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
+           return_updates=return_updates,
        )
+       if return_updates:
+           results, updates = results_and_updates
+       else:
+           results = results_and_updates
+           updates = {}
        if not isinstance(results, list):
            results = [results]
 
        # Keep only the last timestep of every output but keep all the updates
-       if not isinstance(results, list):
-           return results[-1], updates
-       else:
-           return [r[-1] for r in results], updates
+       return [r[-1] for r in results], updates
 
-    results, updates = scan(
+    return scan(
        fn=outer_step,
        sequences=o_sequences,
        outputs_info=outputs_info,
@@ -182,6 +187,5 @@ def scan_checkpoints(
        name=name + "_outer",
        n_steps=o_n_steps,
        allow_gc=True,
+       return_updates=return_updates,
    )
-
-    return results, updates
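Callers of scan_checkpoints get the same opt-in. A sketch, assuming the function is importable from pytensor.scan.checkpoints (the parameters are the ones exercised by the tests below):

    import pytensor.tensor as pt
    from pytensor.scan.checkpoints import scan_checkpoints  # assumed import path

    k = pt.iscalar("k")
    A = pt.vector("A")
    result = scan_checkpoints(
        fn=lambda s, prior_result, A: prior_result * A / s,
        outputs_info=pt.ones_like(A),
        sequences=[pt.arange(k, dtype="float32") + 1],
        non_sequences=A,
        n_steps=k,
        save_every_N=10,
        return_updates=False,  # single return value instead of (result, updates)
    )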
@@ -16,6 +16,7 @@ def map(
     go_backwards=False,
     mode=None,
     name=None,
+    return_updates=True,
 ):
     """Construct a `Scan` `Op` that functions like `map`.
@@ -50,6 +51,7 @@ def map(
         go_backwards=go_backwards,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
@@ -61,6 +63,7 @@ def reduce(
     go_backwards=False,
     mode=None,
     name=None,
+    return_updates=True,
 ):
     """Construct a `Scan` `Op` that functions like `reduce`.
@@ -97,14 +100,29 @@ def reduce(
         truncate_gradient=-1,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
 
-    if isinstance(rval[0], list | tuple):
-        return [x[-1] for x in rval[0]], rval[1]
-    else:
-        return rval[0][-1], rval[1]
+    if return_updates:
+        if isinstance(rval[0], list | tuple):
+            return [x[-1] for x in rval[0]], rval[1]
+        else:
+            return rval[0][-1], rval[1]
+    else:
+        if isinstance(rval, list | tuple):
+            return [x[-1] for x in rval]
+        else:
+            return rval[-1]
 
 
-def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
+def foldl(
+    fn,
+    sequences,
+    outputs_info,
+    non_sequences=None,
+    mode=None,
+    name=None,
+    return_updates=True,
+):
     """Construct a `Scan` `Op` that functions like Haskell's `foldl`.
 
     Parameters
@@ -135,10 +153,19 @@ def foldl(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
         go_backwards=False,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
 
 
-def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
+def foldr(
+    fn,
+    sequences,
+    outputs_info,
+    non_sequences=None,
+    mode=None,
+    name=None,
+    return_updates=True,
+):
     """Construct a `Scan` `Op` that functions like Haskell's `foldr`.
 
     Parameters
@@ -169,4 +196,5 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None):
         go_backwards=True,
         mode=mode,
         name=name,
+        return_updates=return_updates,
     )
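The convenience wrappers simply forward the flag. A sketch mirroring the tests below (the import path for the views module is an assumption):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.scan.views import map as pt_map, reduce as pt_reduce  # assumed path

    v = pt.vector("v")
    s = pt.scalar("s")

    # reduce: only the final accumulator is returned, no updates dict.
    total = pt_reduce(lambda x, y: x + y, v, s, return_updates=False)

    # map: only the mapped sequence is returned.
    abs_v = pt_map(lambda x: abs(x), v, [], return_updates=False)

    f = pytensor.function([v, s], [total, abs_v], allow_input_downcast=True)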
@@ -314,11 +314,12 @@ def _wrap_pad(x: TensorVariable, pad_width: TensorVariable) -> TensorVariable:
 
 def _build_padding_one_direction(array, array_flipped, repeats, *, inner_func, axis):
-    [_, parts], _ = scan(
+    [_, parts] = scan(
         inner_func,
         non_sequences=[array, array_flipped],
         outputs_info=[0, None],
         n_steps=repeats,
+        return_updates=False,
     )
     parts = moveaxis(parts, 0, axis)
...
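With return_updates=False, a scan with several outputs returns just the list of outputs, so it can be unpacked in one step, as the pad helper above now does. A small sketch of the same unpacking (names are illustrative):

    import pytensor
    import pytensor.tensor as pt

    n = pt.iscalar("n")
    # Two outputs; the second (outputs_info=None) is collected but not fed back,
    # so the step function only receives the previous value of the first.
    [counter, doubled] = pytensor.scan(
        lambda c: (c + 1.0, 2.0 * c),
        outputs_info=[pt.constant(0.0), None],
        n_steps=n,
        return_updates=False,
    )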
@@ -27,7 +27,7 @@ from pytensor.compile.monitormode import MonitorMode
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.replace import vectorize_graph
@@ -67,6 +67,7 @@ from pytensor.tensor.type import (
     vector,
 )
 from tests import unittest_tools as utt
+from tests.unittest_tools import assert_equal_computations
 
 
 if config.mode == "FAST_COMPILE":
@@ -4139,3 +4140,43 @@ def test_rng_outputs_info():
         xs_ref.append(rng_ref.normal(xs_ref[-1]))
     assert random_generator_type.values_eq(rng_ref, rng_final_eval)
     np.testing.assert_allclose(xs_eval, xs_ref[1:])
+
+
+@pytest.mark.filterwarnings("error")
+def test_return_updates_api_change():
+    err_msg = "return_updates=False but Scan produced updates"
+    warn_msg = "Scan return signature will change. Updates dict will not be returned"
+
+    x = shared(np.array(0, dtype="float64"))
+
+    with pytest.warns(DeprecationWarning, match=warn_msg):
+        traced1, updates1 = scan(
+            lambda: {x: x + 1},
+            outputs_info=[],
+            n_steps=5,
+        )
+    assert traced1 is None
+    assert len(updates1) == 1 and x in updates1
+
+    with pytest.warns(DeprecationWarning, match=warn_msg):
+        traced2, updates2 = scan(
+            lambda x: x + 1,
+            outputs_info=[x],
+            n_steps=5,
+        )
+    assert isinstance(traced2, Variable)
+    assert isinstance(updates2, dict) and not updates2
+
+    traced3 = scan(
+        lambda x: x + 1,
+        outputs_info=[x],
+        n_steps=5,
+        return_updates=False,
+    )
+    assert isinstance(traced3, Variable)
+
+    assert_equal_computations(list(updates1.values()), [traced2[-1]])
+    assert_equal_computations([traced2], [traced3])
+
+    with pytest.raises(ValueError, match=err_msg):
+        scan(lambda: {x: x + 1}, outputs_info=[], n_steps=5, return_updates=False)
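The test escalates the warning to an error via the filterwarnings mark because the DeprecationWarning is aimed at developers for now. Downstream projects that want to catch deprecated two-element unpacking early can do the same in plain Python:

    import warnings

    # Turn the developer-facing DeprecationWarning into a hard error locally.
    warnings.simplefilter("error", DeprecationWarning)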
@@ -9,44 +9,53 @@ from pytensor.tensor.basic import arange, ones_like
 from pytensor.tensor.type import iscalar, vector
 
 
+@pytest.mark.parametrize("return_updates", [True, False])
 class TestScanCheckpoint:
-    def setup_method(self):
+    def setup_method(self, return_updates):
         self.k = iscalar("k")
         self.A = vector("A")
         seq = arange(self.k, dtype="float32") + 1
-        result, _ = scan(
+        result_raw = scan(
             fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
             sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
+            return_updates=return_updates,
         )
-        result_check, _ = scan_checkpoints(
+        result_check_raw = scan_checkpoints(
             fn=lambda s, prior_result, A: prior_result * A / s,
             outputs_info=ones_like(self.A),
             sequences=[seq],
             non_sequences=self.A,
             n_steps=self.k,
             save_every_N=100,
+            return_updates=return_updates,
         )
+        if return_updates:
+            result, _ = result_raw
+            result_check, _ = result_check_raw
+        else:
+            result = result_raw
+            result_check = result_check_raw
         self.result = result[-1]
         self.result_check = result_check[-1]
         self.grad_A = grad(self.result.sum(), self.A)
         self.grad_A_check = grad(self.result_check.sum(), self.A)
 
-    def test_forward_pass(self):
+    def test_forward_pass(self, return_updates):
         # Test forward computation of A**k.
         f = function(inputs=[self.A, self.k], outputs=[self.result, self.result_check])
         out, out_check = f(range(10), 101)
         assert np.allclose(out, out_check)
 
-    def test_backward_pass(self):
+    def test_backward_pass(self, return_updates):
         # Test gradient computation of A**k.
         f = function(inputs=[self.A, self.k], outputs=[self.grad_A, self.grad_A_check])
         out, out_check = f(range(10), 101)
         assert np.allclose(out, out_check)
 
-    def test_taps_error(self):
+    def test_taps_error(self, return_updates):
         # Test that an error rises if we use taps in outputs_info.
         with pytest.raises(RuntimeError):
             scan_checkpoints(lambda: None, [], {"initial": self.A, "taps": [-2]})
@@ -1,4 +1,5 @@
 import numpy as np
+import pytest
 
 import pytensor.tensor as pt
 from pytensor import config, function, grad, shared
@@ -11,24 +12,41 @@ from tests import unittest_tools as utt
 from tests.scan.test_basic import clone_optimized_graph, grab_scan_node
 
 
-def test_reduce():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_reduce(return_updates):
     v = vector("v")
     s = scalar("s")
-    result, updates = pt_reduce(lambda x, y: x + y, v, s)
+    result_raw = pt_reduce(lambda x, y: x + y, v, s, return_updates=return_updates)
+    if return_updates:
+        result, updates = result_raw
+        assert not updates
+    else:
+        result = result_raw
 
-    f = function([v, s], result, updates=updates, allow_input_downcast=True)
+    f = function([v, s], result, allow_input_downcast=True)
     rng = np.random.default_rng(utt.fetch_seed())
     v_v = rng.uniform(-5.0, 5.0, size=(5,))
     assert abs(np.sum(v_v) - f(v_v, 0.0)) < 1e-3
 
 
-def test_map():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_map(return_updates):
     v = vector("v")
-    abs_expr, abs_updates = pt_map(
-        lambda x: abs(x), v, [], truncate_gradient=-1, go_backwards=False
+    abs_expr_raw = pt_map(
+        lambda x: abs(x),
+        v,
+        [],
+        truncate_gradient=-1,
+        go_backwards=False,
+        return_updates=return_updates,
     )
+    if return_updates:
+        abs_expr, abs_updates = abs_expr_raw
+        assert not abs_updates
+    else:
+        abs_expr = abs_expr_raw
 
-    f = function([v], abs_expr, updates=abs_updates, allow_input_downcast=True)
+    f = function([v], abs_expr, allow_input_downcast=True)
     rng = np.random.default_rng(utt.fetch_seed())
     vals = rng.uniform(-5.0, 5.0, size=(10,))
@@ -39,10 +57,11 @@ def test_map():
 
 def test_reduce_memory_consumption():
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = pt_reduce(
+    o = pt_reduce(
         lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=False,
     )
     mode = FAST_RUN
     mode = mode.excluding("inplace")
@@ -69,13 +88,20 @@ def test_foldl_memory_consumption():
     utt.assert_allclose(f2(), np.ones((10,)))
 
 
-def test_foldl_memory_consumption():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_foldl_memory_consumption(return_updates):
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = foldl(
+    o_raw = foldl(
         lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=return_updates,
     )
+    if return_updates:
+        o, updates = o_raw
+        assert not updates
+    else:
+        o = o_raw
     mode = FAST_RUN
     mode = mode.excluding("inplace")
@@ -102,13 +128,20 @@ def test_foldr_memory_consumption():
     utt.assert_allclose(f2(), np.ones((10,)))
 
 
-def test_foldr_memory_consumption():
+@pytest.mark.parametrize("return_updates", [True, False])
+def test_foldr_memory_consumption(return_updates):
     x = shared(np.asarray(np.random.uniform(size=(10,)), dtype=config.floatX))
-    o, _ = foldr(
+    o_raw = foldr(
         lambda v, acc: acc + v,
         x,
         pt.constant(np.asarray(0.0, dtype=config.floatX)),
+        return_updates=return_updates,
    )
+    if return_updates:
+        o, updates = o_raw
+        assert not updates
+    else:
+        o = o_raw
     mode = FAST_RUN
     mode = mode.excluding("inplace")
...