Commit d10c61ba authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify scan helper logic

return_steps has not been a thing for 14 years
Parent 20e5b721
import warnings import warnings
from itertools import chain
import numpy as np import numpy as np
...@@ -9,7 +10,7 @@ from pytensor.configdefaults import config ...@@ -9,7 +10,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.traversal import graph_inputs from pytensor.graph.traversal import explicit_graph_inputs
from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until from pytensor.scan.utils import expand_empty, safe_new, until
...@@ -475,19 +476,15 @@ def scan( ...@@ -475,19 +476,15 @@ def scan(
else: else:
non_seqs.append(elem) non_seqs.append(elem)
# If we provided a known number of steps ( before compilation) # This helper eagerly skips the Scan if n_steps is known to be 1
# and if that number is 1 or -1, then we can skip the Scan Op, single_step_requested = False
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
n_fixed_steps = None
if isinstance(n_steps, float | int): if isinstance(n_steps, float | int):
n_fixed_steps = int(n_steps) single_step_requested = n_steps == 1
else: else:
try: try:
n_fixed_steps = pt.get_scalar_constant_value(n_steps) single_step_requested = pt.get_scalar_constant_value(n_steps) == 1
except NotScalarConstantError: except NotScalarConstantError:
n_fixed_steps = None pass
# Check n_steps is an int # Check n_steps is an int
if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes: if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes:
...@@ -497,7 +494,6 @@ def scan( ...@@ -497,7 +494,6 @@ def scan(
n_seqs = len(seqs) n_seqs = len(seqs)
n_outs = len(outs_info) n_outs = len(outs_info)
return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries # wrap sequences in a dictionary if they are not already dictionaries
for i in range(n_seqs): for i in range(n_seqs):
if not isinstance(seqs[i], dict): if not isinstance(seqs[i], dict):
...@@ -700,7 +696,6 @@ def scan( ...@@ -700,7 +696,6 @@ def scan(
mit_sot_inner_inputs = [] mit_sot_inner_inputs = []
mit_sot_inner_slices = [] mit_sot_inner_slices = []
mit_sot_inner_outputs = [] mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_tap_array = [] mit_sot_tap_array = []
mit_sot_rightOrder = [] mit_sot_rightOrder = []
...@@ -709,7 +704,6 @@ def scan( ...@@ -709,7 +704,6 @@ def scan(
sit_sot_inner_inputs = [] sit_sot_inner_inputs = []
sit_sot_inner_slices = [] sit_sot_inner_slices = []
sit_sot_inner_outputs = [] sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_rightOrder = [] sit_sot_rightOrder = []
# go through outputs picking up time slices as needed # go through outputs picking up time slices as needed
...@@ -755,8 +749,6 @@ def scan( ...@@ -755,8 +749,6 @@ def scan(
) )
sit_sot_inner_slices.append(actual_arg) sit_sot_inner_slices.append(actual_arg)
if i in return_steps:
sit_sot_return_steps[n_sit_sot] = return_steps[i]
sit_sot_inner_inputs.append(arg) sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i) sit_sot_rightOrder.append(i)
n_sit_sot += 1 n_sit_sot += 1
...@@ -774,8 +766,6 @@ def scan( ...@@ -774,8 +766,6 @@ def scan(
expand_empty(init_out["initial"][:mintap], actual_n_steps) expand_empty(init_out["initial"][:mintap], actual_n_steps)
) )
if i in return_steps:
mit_sot_return_steps[n_mit_sot] = return_steps[i]
mit_sot_rightOrder.append(i) mit_sot_rightOrder.append(i)
n_mit_sot += 1 n_mit_sot += 1
for k in init_out["taps"]: for k in init_out["taps"]:
...@@ -819,7 +809,7 @@ def scan( ...@@ -819,7 +809,7 @@ def scan(
offset = 0 offset = 0
for idx in range(n_mit_sot): for idx in range(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx]) n_inputs = len(mit_sot_tap_array[idx])
if n_fixed_steps in (1, -1): if single_step_requested:
_ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[ _ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[
offset : offset + n_inputs offset : offset + n_inputs
] ]
...@@ -830,17 +820,14 @@ def scan( ...@@ -830,17 +820,14 @@ def scan(
offset += n_inputs offset += n_inputs
for idx in range(n_sit_sot): for idx in range(n_sit_sot):
if n_fixed_steps in (1, -1): if single_step_requested:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]] _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]]
else: else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]] _ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
ordered_args = [] ordered_args = list(chain.from_iterable(_ordered_args))
for ls in _ordered_args: if single_step_requested:
ordered_args += ls
if n_fixed_steps in (1, -1):
args = inner_slices + ordered_args + non_seqs args = inner_slices + ordered_args + non_seqs
else: else:
args = inner_seqs + ordered_args + non_seqs args = inner_seqs + ordered_args + non_seqs
...@@ -863,7 +850,7 @@ def scan( ...@@ -863,7 +850,7 @@ def scan(
# Step 3. Check if we actually need scan and remove it if we don't # Step 3. Check if we actually need scan and remove it if we don't
## ##
if n_fixed_steps in (1, -1): if single_step_requested:
for pos, inner_out in enumerate(outputs): for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an # we need to see if we need to pad our sequences with an
# extra dimension; case example : we return an # extra dimension; case example : we return an
...@@ -871,7 +858,7 @@ def scan( ...@@ -871,7 +858,7 @@ def scan(
# then, if we return the output as given by the innner function # then, if we return the output as given by the innner function
# this will represent only a slice and it will have one # this will represent only a slice and it will have one
# dimension less. # dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: if isinstance(inner_out.type, TensorType):
outputs[pos] = shape_padleft(inner_out) outputs[pos] = shape_padleft(inner_out)
if not return_list and len(outputs) == 1: if not return_list and len(outputs) == 1:
...@@ -896,15 +883,10 @@ def scan( ...@@ -896,15 +883,10 @@ def scan(
fake_outputs = clone_replace( fake_outputs = clone_replace(
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True)) outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
) )
all_inputs = filter( known_inputs = [*args, *fake_nonseqs]
lambda x: ( extra_inputs = [
isinstance(x, Variable) x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
and not isinstance(x, SharedVariable) ]
and not isinstance(x, Constant)
),
graph_inputs(fake_outputs),
)
extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
non_seqs += extra_inputs non_seqs += extra_inputs
# Note we do not use all_inputs directly since the order of variables # Note we do not use all_inputs directly since the order of variables
# in args is quite important # in args is quite important
...@@ -1033,13 +1015,10 @@ def scan( ...@@ -1033,13 +1015,10 @@ def scan(
# Step 5.4 Outputs with no taps used in the input # Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0 n_nit_sot = 0
nit_sot_inner_outputs = [] nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_rightOrder = [] nit_sot_rightOrder = []
for i, out in enumerate(outs_info): for i, out in enumerate(outs_info):
if "taps" not in out: if "taps" not in out:
nit_sot_inner_outputs.append(outputs[i]) nit_sot_inner_outputs.append(outputs[i])
if i in return_steps:
nit_sot_return_steps[n_nit_sot] = return_steps[i]
nit_sot_rightOrder.append(i) nit_sot_rightOrder.append(i)
n_nit_sot += 1 n_nit_sot += 1
...@@ -1173,37 +1152,25 @@ def scan( ...@@ -1173,37 +1152,25 @@ def scan(
update_map = OrderedUpdates() update_map = OrderedUpdates()
def remove_dimensions(outs, steps_return, offsets=None): def remove_dimensions(outs, offsets=None):
out_ls = [] out_ls = []
for idx, out in enumerate(outs): for idx, out in enumerate(outs):
if idx in steps_return: if offsets is None:
if steps_return[idx] > 1: out_ls.append(out)
out_ls.append(out[-steps_return[idx] :])
else:
out_ls.append(out[-1])
else: else:
if offsets is None: out_ls.append(out[offsets[idx] :])
out_ls.append(out)
else:
out_ls.append(out[offsets[idx] :])
return out_ls return out_ls
offset = n_mit_mot offset = n_mit_mot
offsets = [abs(np.min(x)) for x in mit_sot_tap_array] offsets = [abs(np.min(x)) for x in mit_sot_tap_array]
mit_sot_outs = remove_dimensions( mit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_mit_sot], offsets)
scan_outs[offset : offset + n_mit_sot], mit_sot_return_steps, offsets
)
offset += n_mit_sot offset += n_mit_sot
offsets = [1 for x in range(n_sit_sot)] offsets = [1 for x in range(n_sit_sot)]
sit_sot_outs = remove_dimensions( sit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_sit_sot], offsets)
scan_outs[offset : offset + n_sit_sot], sit_sot_return_steps, offsets
)
offset += n_sit_sot offset += n_sit_sot
nit_sot_outs = remove_dimensions( nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
scan_outs[offset : offset + n_nit_sot], nit_sot_return_steps
)
offset += n_nit_sot offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]): for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
...@@ -1232,4 +1199,4 @@ def scan( ...@@ -1232,4 +1199,4 @@ def scan(
elif len(scan_out_list) == 0: elif len(scan_out_list) == 0:
scan_out_list = None scan_out_list = None
return (scan_out_list, update_map) return scan_out_list, update_map
...@@ -3650,67 +3650,6 @@ class TestExamples: ...@@ -3650,67 +3650,6 @@ class TestExamples:
if config.mode != "FAST_COMPILE": if config.mode != "FAST_COMPILE":
assert nb_shape_i == 1 assert nb_shape_i == 1
def test_return_steps(self):
    """Regression test for slicing trailing steps out of scan outputs.

    Builds a small RNN-like scan with three output kinds — a nit-sot
    (no taps), a sit-sot (``x``), and a mit-sot (``y`` with taps -1/-3) —
    runs it for 8 steps, then takes trailing slices (``[-3:]``, ``[-2:]``,
    ``[-4:]``) of the returned sequences and checks them against a NumPy
    reference implementation.
    """
    # Deterministic fixture data shared with the NumPy reference below.
    rng = np.random.default_rng(utt.fetch_seed())
    vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
    vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
    vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
    vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
    v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2)))
    v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,)))
    v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
    # y0 has length 3 because y is used with taps [-1, -3] below.
    v_y0 = asarrayX(rng.uniform(size=(3,)))

    # Mix of shared variables and explicit inputs, so the scan sees both.
    W_in2 = shared(vW_in2, name="win2")
    W = shared(vW, name="w")
    W_out = shared(vWout, name="wout")
    W_in1 = matrix("win")
    u1 = matrix("u1")
    u2 = vector("u2")
    x0 = vector("x0")
    y0 = vector("y0")

    def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
        # Inner step: returns (nit-sot, sit-sot x_t, mit-sot y_t).
        return [
            y_tm3 + 1,
            dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W),
            y_tm1 + dot(x_tm1, W_out),
        ]

    rval, updates = scan(
        f_rnn_cmpl,
        [u1, u2],
        # None -> nit-sot; x0 -> sit-sot; y0 with taps [-1, -3] -> mit-sot.
        [None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
        W_in1,
        n_steps=None,
        truncate_gradient=-1,
        go_backwards=False,
    )

    # Trailing slices of each output sequence (the "return steps").
    outputs = []
    outputs += [rval[0][-3:]]
    outputs += [rval[1][-2:]]
    outputs += [rval[2][-4:]]
    f4 = function(
        [u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
    )

    # compute the values in numpy
    v_x = np.zeros((8, 2), dtype=config.floatX)
    v_y = np.zeros((8,), dtype=config.floatX)
    v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW)
    # y's -3 tap at t=0 is the last entry of the initial window v_y0.
    v_y[0] = np.dot(v_x0, vWout) + v_y0[2]
    for i in range(1, 8):
        v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
        v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]

    # Only the x and y slices are checked; the nit-sot output is ignored.
    (_pytensor_dump, pytensor_x, pytensor_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)
    utt.assert_allclose(pytensor_x, v_x[-2:])
    utt.assert_allclose(pytensor_y, v_y[-4:])
def test_until_random_infer_shape(self): def test_until_random_infer_shape(self):
""" """
Test for a crash in scan.infer_shape when using both Test for a crash in scan.infer_shape when using both
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment