提交 d7edde21 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix constant number of steps reduction in ScanSaveMem rewrite

isinstance(..., int) does not recognize numpy.integers. Also remove the maxsize logic.
上级 b27c59d1
......@@ -29,7 +29,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
# Extract JAX scan inputs
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
mit_sot_init = []
for tap, seq in zip(
......
......@@ -3,7 +3,6 @@
import copy
import dataclasses
from itertools import chain
from sys import maxsize
from typing import cast
import numpy as np
......@@ -1351,10 +1350,9 @@ def scan_save_mem(fgraph, node):
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
if stop == get_scalar_constant_value(length, raise_not_constant=False):
stop = None
global_nsteps = None
else:
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output
......@@ -1366,21 +1364,13 @@ def scan_save_mem(fgraph, node):
# initial state)
stop = stop - init_l[i]
# 2.3.3 we might get away with less number of steps
# 2.3.3 we might get away with fewer steps
if stop is not None and global_nsteps is not None:
# yes if it is a tensor
if isinstance(stop, Variable):
global_nsteps["sym"] += [stop]
# not if it is maxsize
elif isinstance(stop, int) and stop == maxsize:
global_nsteps = None
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and global_nsteps["real"] < stop:
global_nsteps["real"] = stop
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and stop > 0:
pass
# not otherwise
elif isinstance(stop, int | np.integer):
global_nsteps["real"] = max(global_nsteps["real"], stop)
else:
global_nsteps = None
......@@ -1703,10 +1693,7 @@ def scan_save_mem(fgraph, node):
- init_l[pos]
+ store_steps[pos]
)
if (
cnf_slice[0].stop is not None
and cnf_slice[0].stop != maxsize
):
if cnf_slice[0].stop is not None:
stop = (
cnf_slice[0].stop
- nw_steps
......
......@@ -9,7 +9,7 @@ from pytensor.compile.io import In
from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
......@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer:
class TestSaveMem:
mode = get_default_mode().including("scan_save_mem", "scan_save_mem")
mode = get_default_mode().including("scan_save_mem")
def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed())
......@@ -1295,11 +1295,27 @@ class TestSaveMem:
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
updates=updates,
allow_input_downcast=True,
mode=self.mode,
mode=self.mode.excluding("scan_push_out_seq"),
)
# Check we actually have a Scan in the compiled function
[scan_node] = [
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
# get random initial values
rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(20,))
v_u = rng.uniform(-5.0, 5.0, size=(20,)).astype(u.type.dtype)
# Check the number of steps is actually reduced from 20
n_steps = scan_node.inputs[0]
n_steps_fn = pytensor.function(
[u, idx, jdx], n_steps, accept_inplace=True, on_unused_input="ignore"
)
assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11 # x5[const=-10] requires 11 steps
assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18 # x6[jdx=-3] requires 18 steps
assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17 # x3[idx=16] requires 17 steps
assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16 # x3[idx=-5] requires 16 steps
assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20 # x3[idx=19] requires 20 steps
# compute the output in numpy
tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
......@@ -1312,6 +1328,49 @@ class TestSaveMem:
utt.assert_allclose(tx6, v_u[-15] + 6.0)
utt.assert_allclose(tx7, v_u[:-15] + 7.0)
def test_save_mem_reduced_number_of_steps_constant(self):
x0 = pt.scalar("x0")
xs, _ = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
)
fn = function([x0], xs[:5], mode=self.mode)
[scan_node] = [
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
n_steps = scan_node.inputs[0]
assert isinstance(n_steps, Constant) and n_steps.data == 5
np.testing.assert_allclose(fn(0), np.arange(1, 11)[:5])
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
x0 = pt.scalar("x0")
[xs, ys], _ = scan(
lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1),
outputs_info=[x0, x0],
n_steps=10,
)
# Because of ys[-1] we need all the steps!
fn = function([x0], [xs[:5], ys[-1]], mode=self.mode)
[scan_node] = [
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
n_steps = scan_node.inputs[0]
assert isinstance(n_steps, Constant) and n_steps.data == 10
res_x, res_y = fn(0)
np.testing.assert_allclose(
res_x,
np.arange(1, 11)[:5],
)
np.testing.assert_allclose(
res_y,
-np.arange(1, 11)[-1],
)
def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论