提交 d7edde21 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix constant number of steps reduction in ScanSaveMem rewrite

isinstance(..., int) does not recognize numpy.integers. Also remove the maxsize logic.
上级 b27c59d1
...@@ -29,7 +29,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -29,7 +29,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
# Extract JAX scan inputs # Extract JAX scan inputs
outer_inputs = list(outer_inputs) outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length` n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs` seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
mit_sot_init = [] mit_sot_init = []
for tap, seq in zip( for tap, seq in zip(
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import copy import copy
import dataclasses import dataclasses
from itertools import chain from itertools import chain
from sys import maxsize
from typing import cast from typing import cast
import numpy as np import numpy as np
...@@ -1351,10 +1350,9 @@ def scan_save_mem(fgraph, node): ...@@ -1351,10 +1350,9 @@ def scan_save_mem(fgraph, node):
get_scalar_constant_value(cf_slice[0], raise_not_constant=False) get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1 + 1
) )
if stop == maxsize or stop == get_scalar_constant_value( if stop == get_scalar_constant_value(length, raise_not_constant=False):
length, raise_not_constant=False
):
stop = None stop = None
global_nsteps = None
else: else:
# there is a **gotcha** here ! Namely, scan returns an # there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output # array that contains the initial state of the output
...@@ -1366,21 +1364,13 @@ def scan_save_mem(fgraph, node): ...@@ -1366,21 +1364,13 @@ def scan_save_mem(fgraph, node):
# initial state) # initial state)
stop = stop - init_l[i] stop = stop - init_l[i]
# 2.3.3 we might get away with less number of steps # 2.3.3 we might get away with fewer steps
if stop is not None and global_nsteps is not None: if stop is not None and global_nsteps is not None:
# yes if it is a tensor # yes if it is a tensor
if isinstance(stop, Variable): if isinstance(stop, Variable):
global_nsteps["sym"] += [stop] global_nsteps["sym"] += [stop]
# not if it is maxsize elif isinstance(stop, int | np.integer):
elif isinstance(stop, int) and stop == maxsize: global_nsteps["real"] = max(global_nsteps["real"], stop)
global_nsteps = None
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and global_nsteps["real"] < stop:
global_nsteps["real"] = stop
# yes if it is a int k, 0 < k < maxsize
elif isinstance(stop, int) and stop > 0:
pass
# not otherwise
else: else:
global_nsteps = None global_nsteps = None
...@@ -1703,10 +1693,7 @@ def scan_save_mem(fgraph, node): ...@@ -1703,10 +1693,7 @@ def scan_save_mem(fgraph, node):
- init_l[pos] - init_l[pos]
+ store_steps[pos] + store_steps[pos]
) )
if ( if cnf_slice[0].stop is not None:
cnf_slice[0].stop is not None
and cnf_slice[0].stop != maxsize
):
stop = ( stop = (
cnf_slice[0].stop cnf_slice[0].stop
- nw_steps - nw_steps
......
...@@ -9,7 +9,7 @@ from pytensor.compile.io import In ...@@ -9,7 +9,7 @@ from pytensor.compile.io import In
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad, jacobian from pytensor.gradient import grad, jacobian
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer: ...@@ -1208,7 +1208,7 @@ class TestScanInplaceOptimizer:
class TestSaveMem: class TestSaveMem:
mode = get_default_mode().including("scan_save_mem", "scan_save_mem") mode = get_default_mode().including("scan_save_mem")
def test_save_mem(self): def test_save_mem(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -1295,11 +1295,27 @@ class TestSaveMem: ...@@ -1295,11 +1295,27 @@ class TestSaveMem:
[x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]], [x1[:2], x2[4], x3[idx], x4[:idx], x5[-10], x6[-jdx], x7[:-jdx]],
updates=updates, updates=updates,
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode, mode=self.mode.excluding("scan_push_out_seq"),
) )
# Check we actually have a Scan in the compiled function
[scan_node] = [
node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
# get random initial values # get random initial values
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
v_u = rng.uniform(-5.0, 5.0, size=(20,)) v_u = rng.uniform(-5.0, 5.0, size=(20,)).astype(u.type.dtype)
# Check the number of steps is actually reduced from 20
n_steps = scan_node.inputs[0]
n_steps_fn = pytensor.function(
[u, idx, jdx], n_steps, accept_inplace=True, on_unused_input="ignore"
)
assert n_steps_fn(u=v_u, idx=3, jdx=15) == 11 # x5[const=-10] requires 11 steps
assert n_steps_fn(u=v_u, idx=3, jdx=3) == 18 # x6[jdx=-3] requires 18 steps
assert n_steps_fn(u=v_u, idx=16, jdx=15) == 17 # x3[idx=16] requires 17 steps
assert n_steps_fn(u=v_u, idx=-5, jdx=15) == 16 # x3[idx=-5] requires 16 steps
assert n_steps_fn(u=v_u, idx=19, jdx=15) == 20 # x3[idx=19] requires 20 steps
# compute the output in numpy # compute the output in numpy
tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15) tx1, tx2, tx3, tx4, tx5, tx6, tx7 = f2(v_u, 3, 15)
...@@ -1312,6 +1328,49 @@ class TestSaveMem: ...@@ -1312,6 +1328,49 @@ class TestSaveMem:
utt.assert_allclose(tx6, v_u[-15] + 6.0) utt.assert_allclose(tx6, v_u[-15] + 6.0)
utt.assert_allclose(tx7, v_u[:-15] + 7.0) utt.assert_allclose(tx7, v_u[:-15] + 7.0)
def test_save_mem_reduced_number_of_steps_constant(self):
x0 = pt.scalar("x0")
xs, _ = scan(
lambda xtm1: xtm1 + 1,
outputs_info=[x0],
n_steps=10,
)
fn = function([x0], xs[:5], mode=self.mode)
[scan_node] = [
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
n_steps = scan_node.inputs[0]
assert isinstance(n_steps, Constant) and n_steps.data == 5
np.testing.assert_allclose(fn(0), np.arange(1, 11)[:5])
def test_save_mem_cannot_reduce_constant_number_of_steps(self):
x0 = pt.scalar("x0")
[xs, ys], _ = scan(
lambda xtm1, ytm1: (xtm1 + 1, ytm1 - 1),
outputs_info=[x0, x0],
n_steps=10,
)
# Because of ys[-1] we need all the steps!
fn = function([x0], [xs[:5], ys[-1]], mode=self.mode)
[scan_node] = [
node for node in fn.maker.fgraph.toposort() if isinstance(node.op, Scan)
]
n_steps = scan_node.inputs[0]
assert isinstance(n_steps, Constant) and n_steps.data == 10
res_x, res_y = fn(0)
np.testing.assert_allclose(
res_x,
np.arange(1, 11)[:5],
)
np.testing.assert_allclose(
res_y,
-np.arange(1, 11)[-1],
)
def test_save_mem_store_steps(self): def test_save_mem_store_steps(self):
def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1): def f_rnn(u_t, x1_tm1, x1_tm3, x2_tm1, x3tm2, x3_tm1, x4_tm1):
return ( return (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论