提交 65cc349a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Simplify scan helper logic

上级 44720f71
import typing import typing
import warnings import warnings
from functools import reduce
from itertools import chain from itertools import chain
import numpy as np import numpy as np
...@@ -16,7 +17,6 @@ from pytensor.graph.type import HasShape ...@@ -16,7 +17,6 @@ from pytensor.graph.type import HasShape
from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
...@@ -143,31 +143,6 @@ def get_updates_and_outputs(ls): ...@@ -143,31 +143,6 @@ def get_updates_and_outputs(ls):
raise ValueError(error_msg) raise ValueError(error_msg)
def isNaN_or_Inf_or_None(x):
isNone = x is None
try:
isNaN = np.isnan(x)
isInf = np.isinf(x)
isStr = isinstance(x, str)
except Exception:
isNaN = False
isInf = False
isStr = False
if not isNaN and not isInf:
try:
val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val)
isNaN = np.isnan(val)
except Exception:
isNaN = False
isInf = False
if isinstance(x, Constant) and isinstance(x.data, str):
isStr = True
else:
isStr = False
return isNone or isNaN or isInf or isStr
def _manage_output_api_change(outputs, updates, return_updates): def _manage_output_api_change(outputs, updates, return_updates):
if return_updates: if return_updates:
warnings.warn( warnings.warn(
...@@ -505,7 +480,7 @@ def scan( ...@@ -505,7 +480,7 @@ def scan(
# This helper eagerly skips the Scan if n_steps is known to be 1 # This helper eagerly skips the Scan if n_steps is known to be 1
single_step_requested = False single_step_requested = False
if isinstance(n_steps, float | int): if isinstance(n_steps, int | float):
single_step_requested = n_steps == 1 single_step_requested = n_steps == 1
else: else:
try: try:
...@@ -676,33 +651,20 @@ def scan( ...@@ -676,33 +651,20 @@ def scan(
if nw_name is not None: if nw_name is not None:
nw_seq.name = nw_name nw_seq.name = nw_name
# Since we've added all sequences now we need to level them up based on if n_steps is None:
# n_steps or their different shapes if not scan_seqs:
lengths_vec = [seq.shape[0] for seq in scan_seqs]
if not isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
lengths_vec.append(pt.as_tensor(n_steps))
if len(lengths_vec) == 0:
# ^ No information about the number of steps
raise ValueError( raise ValueError(
"No information about the number of steps " "No information about the number of steps provided. "
"provided. Either provide a value for " "Either provide a value for n_steps argument of scan or provide an input sequence."
"n_steps argument of scan or provide an input "
"sequence"
) )
actual_n_steps = reduce(minimum, [seq.shape[0] for seq in scan_seqs])
# If the user has provided the number of steps, do that regardless ( and
# raise an error if the sequences are not long enough )
if isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = minimum(actual_n_steps, contestant)
else: else:
actual_n_steps = pt.as_tensor(n_steps) actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0)
# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs] scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
# Conventions : # Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided # mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function ) # by the gradient function )
...@@ -899,10 +861,8 @@ def scan( ...@@ -899,10 +861,8 @@ def scan(
raw_inner_outputs = fn(*args) raw_inner_outputs = fn(*args)
condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs) condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None: as_while = condition is not None
as_while = True
else:
as_while = False
## ##
# Step 3. Check if we actually need scan and remove it if we don't # Step 3. Check if we actually need scan and remove it if we don't
## ##
...@@ -934,7 +894,7 @@ def scan( ...@@ -934,7 +894,7 @@ def scan(
# extract still missing inputs (there still might be so) and add them # extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args # as non sequences at the end of our args
if condition is not None: if as_while:
outputs.append(condition) outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs] fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = clone_replace( fake_outputs = clone_replace(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论