提交 d10c61ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify scan helper logic

return_steps has not been a thing for 14 years
上级 20e5b721
import warnings
from itertools import chain
import numpy as np
......@@ -9,7 +10,7 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.traversal import graph_inputs
from pytensor.graph.traversal import explicit_graph_inputs
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
......@@ -475,19 +476,15 @@ def scan(
else:
non_seqs.append(elem)
# If we provided a known number of steps ( before compilation)
# and if that number is 1 or -1, then we can skip the Scan Op,
# and just apply the inner function once
# To do that we check here to see the nature of n_steps
n_fixed_steps = None
# This helper eagerly skips the Scan if n_steps is known to be 1
single_step_requested = False
if isinstance(n_steps, float | int):
n_fixed_steps = int(n_steps)
single_step_requested = n_steps == 1
else:
try:
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
single_step_requested = pt.get_scalar_constant_value(n_steps) == 1
except NotScalarConstantError:
n_fixed_steps = None
pass
# Check n_steps is an int
if hasattr(n_steps, "dtype") and str(n_steps.dtype) not in integer_dtypes:
......@@ -497,7 +494,6 @@ def scan(
n_seqs = len(seqs)
n_outs = len(outs_info)
return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries
for i in range(n_seqs):
if not isinstance(seqs[i], dict):
......@@ -700,7 +696,6 @@ def scan(
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = {}
mit_sot_tap_array = []
mit_sot_rightOrder = []
......@@ -709,7 +704,6 @@ def scan(
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = {}
sit_sot_rightOrder = []
# go through outputs picking up time slices as needed
......@@ -755,8 +749,6 @@ def scan(
)
sit_sot_inner_slices.append(actual_arg)
if i in return_steps:
sit_sot_return_steps[n_sit_sot] = return_steps[i]
sit_sot_inner_inputs.append(arg)
sit_sot_rightOrder.append(i)
n_sit_sot += 1
......@@ -774,8 +766,6 @@ def scan(
expand_empty(init_out["initial"][:mintap], actual_n_steps)
)
if i in return_steps:
mit_sot_return_steps[n_mit_sot] = return_steps[i]
mit_sot_rightOrder.append(i)
n_mit_sot += 1
for k in init_out["taps"]:
......@@ -819,7 +809,7 @@ def scan(
offset = 0
for idx in range(n_mit_sot):
n_inputs = len(mit_sot_tap_array[idx])
if n_fixed_steps in (1, -1):
if single_step_requested:
_ordered_args[mit_sot_rightOrder[idx]] = mit_sot_inner_slices[
offset : offset + n_inputs
]
......@@ -830,17 +820,14 @@ def scan(
offset += n_inputs
for idx in range(n_sit_sot):
if n_fixed_steps in (1, -1):
if single_step_requested:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_slices[idx]]
else:
_ordered_args[sit_sot_rightOrder[idx]] = [sit_sot_inner_inputs[idx]]
ordered_args = []
for ls in _ordered_args:
ordered_args += ls
if n_fixed_steps in (1, -1):
ordered_args = list(chain.from_iterable(_ordered_args))
if single_step_requested:
args = inner_slices + ordered_args + non_seqs
else:
args = inner_seqs + ordered_args + non_seqs
......@@ -863,7 +850,7 @@ def scan(
# Step 3. Check if we actually need scan and remove it if we don't
##
if n_fixed_steps in (1, -1):
if single_step_requested:
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
# extra dimension; case example : we return an
......@@ -871,7 +858,7 @@ def scan(
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
if isinstance(inner_out.type, TensorType):
outputs[pos] = shape_padleft(inner_out)
if not return_list and len(outputs) == 1:
......@@ -896,15 +883,10 @@ def scan(
fake_outputs = clone_replace(
outputs, replace=dict(zip(non_seqs, fake_nonseqs, strict=True))
)
all_inputs = filter(
lambda x: (
isinstance(x, Variable)
and not isinstance(x, SharedVariable)
and not isinstance(x, Constant)
),
graph_inputs(fake_outputs),
)
extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
known_inputs = [*args, *fake_nonseqs]
extra_inputs = [
x for x in explicit_graph_inputs(fake_outputs) if x not in known_inputs
]
non_seqs += extra_inputs
# Note we do not use all_inputs directly since the order of variables
# in args is quite important
......@@ -1033,13 +1015,10 @@ def scan(
# Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = {}
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if "taps" not in out:
nit_sot_inner_outputs.append(outputs[i])
if i in return_steps:
nit_sot_return_steps[n_nit_sot] = return_steps[i]
nit_sot_rightOrder.append(i)
n_nit_sot += 1
......@@ -1173,37 +1152,25 @@ def scan(
update_map = OrderedUpdates()
def remove_dimensions(outs, steps_return, offsets=None):
def remove_dimensions(outs, offsets=None):
out_ls = []
for idx, out in enumerate(outs):
if idx in steps_return:
if steps_return[idx] > 1:
out_ls.append(out[-steps_return[idx] :])
else:
out_ls.append(out[-1])
if offsets is None:
out_ls.append(out)
else:
if offsets is None:
out_ls.append(out)
else:
out_ls.append(out[offsets[idx] :])
out_ls.append(out[offsets[idx] :])
return out_ls
offset = n_mit_mot
offsets = [abs(np.min(x)) for x in mit_sot_tap_array]
mit_sot_outs = remove_dimensions(
scan_outs[offset : offset + n_mit_sot], mit_sot_return_steps, offsets
)
mit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_mit_sot], offsets)
offset += n_mit_sot
offsets = [1 for x in range(n_sit_sot)]
sit_sot_outs = remove_dimensions(
scan_outs[offset : offset + n_sit_sot], sit_sot_return_steps, offsets
)
sit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_sit_sot], offsets)
offset += n_sit_sot
nit_sot_outs = remove_dimensions(
scan_outs[offset : offset + n_nit_sot], nit_sot_return_steps
)
nit_sot_outs = remove_dimensions(scan_outs[offset : offset + n_nit_sot])
offset += n_nit_sot
for idx, update_rule in enumerate(scan_outs[offset : offset + n_shared_outs]):
......@@ -1232,4 +1199,4 @@ def scan(
elif len(scan_out_list) == 0:
scan_out_list = None
return (scan_out_list, update_map)
return scan_out_list, update_map
......@@ -3650,67 +3650,6 @@ class TestExamples:
if config.mode != "FAST_COMPILE":
assert nb_shape_i == 1
def test_return_steps(self):
rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
vW = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
vWout = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
vW_in1 = asarrayX(rng.uniform(-0.5, 0.5, size=(2, 2)))
v_u1 = asarrayX(rng.uniform(-0.5, 0.5, size=(8, 2)))
v_u2 = asarrayX(rng.uniform(-0.5, 0.5, size=(8,)))
v_x0 = asarrayX(rng.uniform(-0.5, 0.5, size=(2,)))
v_y0 = asarrayX(rng.uniform(size=(3,)))
W_in2 = shared(vW_in2, name="win2")
W = shared(vW, name="w")
W_out = shared(vWout, name="wout")
W_in1 = matrix("win")
u1 = matrix("u1")
u2 = vector("u2")
x0 = vector("x0")
y0 = vector("y0")
def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, y_tm3, W_in1):
return [
y_tm3 + 1,
dot(u1_t, W_in1) + u2_t * W_in2 + dot(x_tm1, W),
y_tm1 + dot(x_tm1, W_out),
]
rval, updates = scan(
f_rnn_cmpl,
[u1, u2],
[None, dict(initial=x0), dict(initial=y0, taps=[-1, -3])],
W_in1,
n_steps=None,
truncate_gradient=-1,
go_backwards=False,
)
outputs = []
outputs += [rval[0][-3:]]
outputs += [rval[1][-2:]]
outputs += [rval[2][-4:]]
f4 = function(
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
)
# compute the values in numpy
v_x = np.zeros((8, 2), dtype=config.floatX)
v_y = np.zeros((8,), dtype=config.floatX)
v_x[0] = np.dot(v_u1[0], vW_in1) + v_u2[0] * vW_in2 + np.dot(v_x0, vW)
v_y[0] = np.dot(v_x0, vWout) + v_y0[2]
for i in range(1, 8):
v_x[i] = np.dot(v_u1[i], vW_in1) + v_u2[i] * vW_in2 + np.dot(v_x[i - 1], vW)
v_y[i] = np.dot(v_x[i - 1], vWout) + v_y[i - 1]
(_pytensor_dump, pytensor_x, pytensor_y) = f4(v_u1, v_u2, v_x0, v_y0, vW_in1)
utt.assert_allclose(pytensor_x, v_x[-2:])
utt.assert_allclose(pytensor_y, v_y[-4:])
def test_until_random_infer_shape(self):
"""
Test for a crash in scan.infer_shape when using both
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论