提交 91d3b7c0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Do not merge while scans with different until condition

The rewrite did not check if nominal variables in the graph of the until condition corresponded to the equivalent outer variables
上级 eb552eef
......@@ -17,7 +17,9 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
ancestors,
apply_depends_on,
equal_computations,
graph_inputs,
......@@ -1950,11 +1952,13 @@ class ScanMerge(GraphRewriter):
Questionable, we should also consider profile ?
"""
rep = set_nodes[0]
op = node.op
rep_node = set_nodes[0]
rep_op = rep_node.op
if (
rep.op.info.as_while != node.op.info.as_while
or node.op.truncate_gradient != rep.op.truncate_gradient
or node.op.mode != rep.op.mode
op.info.as_while != rep_op.info.as_while
or op.truncate_gradient != rep_op.truncate_gradient
or op.mode != rep_op.mode
):
return False
......@@ -1964,7 +1968,7 @@ class ScanMerge(GraphRewriter):
except NotScalarConstantError:
pass
rep_nsteps = rep.inputs[0]
rep_nsteps = rep_node.inputs[0]
try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
......@@ -1978,13 +1982,40 @@ class ScanMerge(GraphRewriter):
if apply_depends_on(node, nd) or apply_depends_on(nd, node):
return False
if not node.op.info.as_while:
if not op.info.as_while:
return True
cond = node.op.inner_outputs[-1]
rep_cond = rep.op.inner_outputs[-1]
return equal_computations(
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs
)
# We need to check the while conditions are identical
conds = [op.inner_outputs[-1]]
rep_conds = [rep_op.inner_outputs[-1]]
if not equal_computations(
conds, rep_conds, op.inner_inputs, rep_op.inner_inputs
):
return False
# If they depend on inner inputs we need to check for equivalence on the respective outer inputs
nominal_inputs = [a for a in ancestors(conds) if isinstance(a, NominalVariable)]
if not nominal_inputs:
return True
rep_nominal_inputs = [
a for a in ancestors(rep_conds) if isinstance(a, NominalVariable)
]
conds = []
rep_conds = []
mapping = op.get_oinp_iinp_iout_oout_mappings()["outer_inp_from_inner_inp"]
rep_mapping = rep_op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
]
inner_inputs = op.inner_inputs
rep_inner_inputs = rep_op.inner_inputs
for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs):
conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]])
rep_conds.append(
rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]]
)
return equal_computations(conds, rep_conds)
def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort
......
......@@ -15,6 +15,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid
......@@ -796,7 +797,13 @@ class TestPushOutAddScan:
class TestScanMerge:
mode = get_default_mode().including("scan")
mode = get_default_mode().including("scan").excluding("scan_pushout_seqs_ops")
@staticmethod
def count_scans(fn):
nodes = fn.maker.fgraph.apply_nodes
scans = [node for node in nodes if isinstance(node.op, Scan)]
return len(scans)
def test_basic(self):
x = vector()
......@@ -808,56 +815,38 @@ class TestScanMerge:
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[y])
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
sx, upx = scan(sum, sequences=[x], n_steps=2)
sy, upy = scan(sum, sequences=[y], n_steps=3)
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
sx, upx = scan(sum, sequences=[x], n_steps=4)
sy, upy = scan(sum, sequences=[y], n_steps=4)
f = function(
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x])
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops"))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
def test_three_scans(self):
r"""
......@@ -877,12 +866,8 @@ class TestScanMerge:
sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y")
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
f = function(
[x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops")
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
f = function([x, y], [sy, sz], mode=self.mode)
assert self.count_scans(f) == 2
rng = np.random.default_rng(utt.fetch_seed())
x_val = rng.uniform(size=(4,)).astype(config.floatX)
......@@ -913,6 +898,112 @@ class TestScanMerge:
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan(self):
x = vector("x")
y = vector("y")
def add(s):
return s + 1, until(s > 5)
def sub(s):
return s - 1, until(s > 5)
def sub_alt(s):
return s - 1, until(s > 4)
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[y])
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[x])
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub_alt, sequences=[x])
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan_nominal_dependency(self):
"""Test case where condition depends on nominal variables.
This is a regression test for #509
"""
c1 = scalar("c1")
c2 = scalar("c2")
x = vector("x", shape=(5,))
y = vector("y", shape=(5,))
z = vector("z", shape=(5,))
def add(s1, s2, const):
return s1 + 1, until(s2 > const)
def sub(s1, s2, const):
return s1 - 1, until(s2 > const)
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1])
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=0,
)
np.testing.assert_array_equal(res_sx, [1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2])
f = pytensor.function(
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=3,
c2=1,
)
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1])
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
def nested_scan(c, x, z):
sx, _ = scan(add, sequences=[x, z], non_sequences=[c])
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c])
return sx.sum() + sy.sum()
sz, _ = scan(
nested_scan,
sequences=[stack([c1, c2])],
non_sequences=[x, z],
mode=self.mode,
)
f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
inner_f = scan_node.op.fn
assert self.count_scans(inner_f) == 1
class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论