提交 65cc349a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify scan helper logic

上级 44720f71
import typing
import warnings
from functools import reduce
from itertools import chain
import numpy as np
......@@ -16,7 +17,6 @@ from pytensor.graph.type import HasShape
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft
......@@ -143,31 +143,6 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg)
def isNaN_or_Inf_or_None(x):
isNone = x is None
try:
isNaN = np.isnan(x)
isInf = np.isinf(x)
isStr = isinstance(x, str)
except Exception:
isNaN = False
isInf = False
isStr = False
if not isNaN and not isInf:
try:
val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val)
isNaN = np.isnan(val)
except Exception:
isNaN = False
isInf = False
if isinstance(x, Constant) and isinstance(x.data, str):
isStr = True
else:
isStr = False
return isNone or isNaN or isInf or isStr
def _manage_output_api_change(outputs, updates, return_updates):
if return_updates:
warnings.warn(
......@@ -505,7 +480,7 @@ def scan(
# This helper eagerly skips the Scan if n_steps is known to be 1
single_step_requested = False
if isinstance(n_steps, float | int):
if isinstance(n_steps, int | float):
single_step_requested = n_steps == 1
else:
try:
......@@ -676,33 +651,20 @@ def scan(
if nw_name is not None:
nw_seq.name = nw_name
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec = [seq.shape[0] for seq in scan_seqs]
if not isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
lengths_vec.append(pt.as_tensor(n_steps))
if len(lengths_vec) == 0:
# ^ No information about the number of steps
if n_steps is None:
if not scan_seqs:
raise ValueError(
"No information about the number of steps "
"provided. Either provide a value for "
"n_steps argument of scan or provide an input "
"sequence"
"No information about the number of steps provided. "
"Either provide a value for n_steps argument of scan or provide an input sequence."
)
# If the user has provided the number of steps, do that regardless ( and
# raise an error if the sequences are not long enough )
if isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = minimum(actual_n_steps, contestant)
actual_n_steps = reduce(minimum, [seq.shape[0] for seq in scan_seqs])
else:
actual_n_steps = pt.as_tensor(n_steps)
actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0)
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
......@@ -899,10 +861,8 @@ def scan(
raw_inner_outputs = fn(*args)
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None:
as_while = True
else:
as_while = False
as_while = condition is not None
##
# Step 3. Check if we actually need scan and remove it if we don't
##
......@@ -934,7 +894,7 @@ def scan(
# extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args
if condition is not None:
if as_while:
outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = clone_replace(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论