提交 91d3b7c0 作者: Ricardo Vieira 提交者: Ricardo Vieira

Do not merge while-scans with different `until` conditions

The ScanMerge rewrite did not check whether the nominal variables in the graph of the `until` condition corresponded to equivalent outer variables in both scans.
上级 eb552eef
...@@ -17,7 +17,9 @@ from pytensor.configdefaults import config ...@@ -17,7 +17,9 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
Constant, Constant,
NominalVariable,
Variable, Variable,
ancestors,
apply_depends_on, apply_depends_on,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
...@@ -1950,11 +1952,13 @@ class ScanMerge(GraphRewriter): ...@@ -1950,11 +1952,13 @@ class ScanMerge(GraphRewriter):
Questionable, we should also consider profile ? Questionable, we should also consider profile ?
""" """
rep = set_nodes[0] op = node.op
rep_node = set_nodes[0]
rep_op = rep_node.op
if ( if (
rep.op.info.as_while != node.op.info.as_while op.info.as_while != rep_op.info.as_while
or node.op.truncate_gradient != rep.op.truncate_gradient or op.truncate_gradient != rep_op.truncate_gradient
or node.op.mode != rep.op.mode or op.mode != rep_op.mode
): ):
return False return False
...@@ -1964,7 +1968,7 @@ class ScanMerge(GraphRewriter): ...@@ -1964,7 +1968,7 @@ class ScanMerge(GraphRewriter):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
rep_nsteps = rep.inputs[0] rep_nsteps = rep_node.inputs[0]
try: try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError: except NotScalarConstantError:
...@@ -1978,14 +1982,41 @@ class ScanMerge(GraphRewriter): ...@@ -1978,14 +1982,41 @@ class ScanMerge(GraphRewriter):
if apply_depends_on(node, nd) or apply_depends_on(nd, node): if apply_depends_on(node, nd) or apply_depends_on(nd, node):
return False return False
if not node.op.info.as_while: if not op.info.as_while:
return True return True
cond = node.op.inner_outputs[-1]
rep_cond = rep.op.inner_outputs[-1] # We need to check the while conditions are identical
return equal_computations( conds = [op.inner_outputs[-1]]
[cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs rep_conds = [rep_op.inner_outputs[-1]]
if not equal_computations(
conds, rep_conds, op.inner_inputs, rep_op.inner_inputs
):
return False
# If they depend on inner inputs we need to check for equivalence on the respective outer inputs
nominal_inputs = [a for a in ancestors(conds) if isinstance(a, NominalVariable)]
if not nominal_inputs:
return True
rep_nominal_inputs = [
a for a in ancestors(rep_conds) if isinstance(a, NominalVariable)
]
conds = []
rep_conds = []
mapping = op.get_oinp_iinp_iout_oout_mappings()["outer_inp_from_inner_inp"]
rep_mapping = rep_op.get_oinp_iinp_iout_oout_mappings()[
"outer_inp_from_inner_inp"
]
inner_inputs = op.inner_inputs
rep_inner_inputs = rep_op.inner_inputs
for nominal_input, rep_nominal_input in zip(nominal_inputs, rep_nominal_inputs):
conds.append(node.inputs[mapping[inner_inputs.index(nominal_input)]])
rep_conds.append(
rep_node.inputs[rep_mapping[rep_inner_inputs.index(rep_nominal_input)]]
) )
return equal_computations(conds, rep_conds)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
scan_nodes = [nd for nd in fgraph.toposort() if isinstance(nd.op, Scan)] scan_nodes = [nd for nd in fgraph.toposort() if isinstance(nd.op, Scan)]
......
...@@ -15,6 +15,7 @@ from pytensor.graph.replace import clone_replace ...@@ -15,6 +15,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge from pytensor.scan.rewriting import ScanInplaceOptimizer, ScanMerge
from pytensor.scan.utils import until from pytensor.scan.utils import until
from pytensor.tensor import stack
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot, dot, sigmoid from pytensor.tensor.math import Dot, dot, sigmoid
...@@ -796,7 +797,13 @@ class TestPushOutAddScan: ...@@ -796,7 +797,13 @@ class TestPushOutAddScan:
class TestScanMerge: class TestScanMerge:
mode = get_default_mode().including("scan") mode = get_default_mode().including("scan").excluding("scan_pushout_seqs_ops")
@staticmethod
def count_scans(fn):
nodes = fn.maker.fgraph.apply_nodes
scans = [node for node in nodes if isinstance(node.op, Scan)]
return len(scans)
def test_basic(self): def test_basic(self):
x = vector() x = vector()
...@@ -808,56 +815,38 @@ class TestScanMerge: ...@@ -808,56 +815,38 @@ class TestScanMerge:
sx, upx = scan(sum, sequences=[x]) sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[y]) sy, upy = scan(sum, sequences=[y])
f = function( f = function([x, y], [sx, sy], mode=self.mode)
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") assert self.count_scans(f) == 2
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
sx, upx = scan(sum, sequences=[x], n_steps=2) sx, upx = scan(sum, sequences=[x], n_steps=2)
sy, upy = scan(sum, sequences=[y], n_steps=3) sy, upy = scan(sum, sequences=[y], n_steps=3)
f = function( f = function([x, y], [sx, sy], mode=self.mode)
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") assert self.count_scans(f) == 2
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
sx, upx = scan(sum, sequences=[x], n_steps=4) sx, upx = scan(sum, sequences=[x], n_steps=4)
sy, upy = scan(sum, sequences=[y], n_steps=4) sy, upy = scan(sum, sequences=[y], n_steps=4)
f = function( f = function([x, y], [sx, sy], mode=self.mode)
[x, y], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops") assert self.count_scans(f) == 1
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x]) sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x]) sy, upy = scan(sum, sequences=[x])
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) f = function([x], [sx, sy], mode=self.mode)
topo = f.maker.fgraph.toposort() assert self.count_scans(f) == 1
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x]) sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE") sy, upy = scan(sum, sequences=[x], mode="FAST_COMPILE")
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) f = function([x], [sx, sy], mode=self.mode)
topo = f.maker.fgraph.toposort() assert self.count_scans(f) == 1
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 1
sx, upx = scan(sum, sequences=[x]) sx, upx = scan(sum, sequences=[x])
sy, upy = scan(sum, sequences=[x], truncate_gradient=1) sy, upy = scan(sum, sequences=[x], truncate_gradient=1)
f = function([x], [sx, sy], mode=self.mode.excluding("scan_pushout_seqs_ops")) f = function([x], [sx, sy], mode=self.mode)
topo = f.maker.fgraph.toposort() assert self.count_scans(f) == 2
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
def test_three_scans(self): def test_three_scans(self):
r""" r"""
...@@ -877,12 +866,8 @@ class TestScanMerge: ...@@ -877,12 +866,8 @@ class TestScanMerge:
sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y") sy, upy = scan(sum, sequences=[2 * y + 2], n_steps=4, name="Y")
sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z") sz, upz = scan(sum, sequences=[sx], n_steps=4, name="Z")
f = function( f = function([x, y], [sy, sz], mode=self.mode)
[x, y], [sy, sz], mode=self.mode.excluding("scan_pushout_seqs_ops") assert self.count_scans(f) == 2
)
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(n.op, Scan)]
assert len(scans) == 2
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
x_val = rng.uniform(size=(4,)).astype(config.floatX) x_val = rng.uniform(size=(4,)).astype(config.floatX)
...@@ -913,6 +898,112 @@ class TestScanMerge: ...@@ -913,6 +898,112 @@ class TestScanMerge:
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2]) assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1]) assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan(self):
x = vector("x")
y = vector("y")
def add(s):
return s + 1, until(s > 5)
def sub(s):
return s - 1, until(s > 5)
def sub_alt(s):
return s - 1, until(s > 4)
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[y])
f = function([x, y], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub, sequences=[x])
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
sx, upx = scan(add, sequences=[x])
sy, upy = scan(sub_alt, sequences=[x])
f = function([x], [sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
@config.change_flags(cxx="") # Just for faster compilation
def test_while_scan_nominal_dependency(self):
"""Test case where condition depends on nominal variables.
This is a regression test for #509
"""
c1 = scalar("c1")
c2 = scalar("c2")
x = vector("x", shape=(5,))
y = vector("y", shape=(5,))
z = vector("z", shape=(5,))
def add(s1, s2, const):
return s1 + 1, until(s2 > const)
def sub(s1, s2, const):
return s1 - 1, until(s2 > const)
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, -z], non_sequences=[c1])
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=0,
)
np.testing.assert_array_equal(res_sx, [1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c2])
f = pytensor.function(
inputs=[x, y, z, c1, c2], outputs=[sx, sy], mode=self.mode
)
assert self.count_scans(f) == 2
res_sx, res_sy = f(
x=[0, 0, 0, 0, 0],
y=[0, 0, 0, 0, 0],
z=[0, 1, 2, 3, 4],
c1=3,
c2=1,
)
np.testing.assert_array_equal(res_sx, [1, 1, 1, 1, 1])
np.testing.assert_array_equal(res_sy, [-1, -1, -1])
sx, _ = scan(add, sequences=[x, z], non_sequences=[c1])
sy, _ = scan(sub, sequences=[y, z], non_sequences=[c1])
f = pytensor.function(inputs=[x, y, z, c1], outputs=[sx, sy], mode=self.mode)
assert self.count_scans(f) == 1
def nested_scan(c, x, z):
sx, _ = scan(add, sequences=[x, z], non_sequences=[c])
sy, _ = scan(sub, sequences=[x, z], non_sequences=[c])
return sx.sum() + sy.sum()
sz, _ = scan(
nested_scan,
sequences=[stack([c1, c2])],
non_sequences=[x, z],
mode=self.mode,
)
f = pytensor.function(inputs=[x, z, c1, c2], outputs=sz, mode=mode)
[scan_node] = [
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
inner_f = scan_node.op.fn
assert self.count_scans(inner_f) == 1
class TestScanInplaceOptimizer: class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace") mode = get_default_mode().including("scan_make_inplace", "inplace")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论