Commit 0670ac2f authored by Ricardo Vieira, committed by Ricardo Vieira

Fuse consecutive Elemwise nodes with multiple clients

Parent commit: d5cb23a5
......@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp):
def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs.
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
# But our c code support more.
# - nfunc is reused for scipy and scipy is optional
if len(node.inputs) > 32 and self.ufunc and impl == "py":
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
impl = "c"
if getattr(self, "nfunc_spec", None) and impl != "c":
......@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp):
self.nfunc = module
if (
len(node.inputs) < 32
(len(node.inputs) + len(node.outputs)) <= 32
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
and self.ufunc is None
and impl == "py"
......@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp):
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32:
if (len(node.inputs) + len(node.outputs)) > 32:
# Some versions of NumPy will segfault, other will raise a
# ValueError, if the number of inputs to a ufunc is 32 or more.
# ValueError, if the number of operands in an ufunc is more than 32.
# In that case, the C version should be used, or Elemwise fusion
# should be disabled.
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
# Determine the shape of outputs
out_shape = []
for values in zip(*[input.shape for input in inputs]):
if any(v == 0 for v in values):
# All non-broadcasted dimensions should be zero
assert max(values) <= 1
out_shape.append(0)
else:
out_shape.append(max(values))
out_shape = tuple(out_shape)
ufunc_args = inputs
ufunc_kwargs = {}
# We supported in the past calling manually op.perform.
......
......@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py
pytensor/tensor/rewriting/elemwise.py
pytensor/tensor/shape.py
pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import pytensor.tensor as at
from pytensor.compile import UnusedInputError
from pytensor.compile import UnusedInputError, get_mode
from pytensor.compile.function import function, pfunc
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.io import In
......@@ -200,7 +200,12 @@ class TestPfunc:
bval = np.arange(5)
b.set_value(bval, borrow=True)
bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN")
f = pfunc(
[],
[b_out],
updates=[(b, (b_out + 3))],
mode=get_mode("FAST_RUN").excluding("fusion"),
)
assert (f() == (np.arange(5) * 2)).all()
# because of the update
assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all()
......
import contextlib
import numpy as np
import pytest
......@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
add,
bitwise_and,
bitwise_or,
cos,
......@@ -29,6 +30,7 @@ from pytensor.tensor.math import (
dot,
eq,
exp,
ge,
int_div,
invert,
iround,
......@@ -900,6 +902,72 @@ class TestFusion:
fxv * np.sin(fsv),
"float32",
),
# Multiple output cases # 72
(
(
# sum(logp)
at_sum(-((fx - fy) ** 2) / 2),
# grad(logp)
at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx),
),
(fx, fy),
(fxv, fyv),
3,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
),
("float32", "float32"),
),
# Two Composite graphs that share the same input, but are split by
# a non-elemwise operation (Assert)
(
(
log(
ge(
assert_op(
at_abs(fx),
at_all(ge(at_abs(fx), 0)),
),
0,
)
),
),
(fx,),
(fxv,),
4,
(np.zeros_like(fxv),),
("float32",),
),
# Two subgraphs that share the same non-fuseable input, but are otherwise
# completely independent
(
(
true_div(
mul(
at_sum(fx + 5), # breaks fusion
exp(fx),
),
(fx + 5),
),
),
(fx,),
(fxv,),
4,
(np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),),
("float32",),
),
pytest.param(
(
(sin(exp(fx)), exp(sin(fx))),
(fx,),
(fxv,),
1,
(np.sin(np.exp(fxv)), np.exp(np.sin(fxv))),
("float32", "float32"),
),
marks=pytest.mark.xfail, # Not implemented yet
),
],
)
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
......@@ -910,23 +978,34 @@ class TestFusion:
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
if not isinstance(g, (tuple, list)):
g = (g,)
answer = (answer,)
out_dtype = (out_dtype,)
if self._shared is None:
f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat):
out = f(*val_inputs)
if not isinstance(out, list):
out = (out,)
else:
out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
assert out.dtype == g.dtype
f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
out = [
self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out")
for g_, od in zip(g, out_dtype)
]
assert all(o.dtype == g_.dtype for o, g_ in zip(out, g))
f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode)
for x in range(nb_repeat):
f(*val_inputs)
out = out.get_value()
out = [o.get_value() for o in out]
atol = 1e-8
if out_dtype == "float32":
if any(o == "float32" for o in out_dtype):
atol = 1e-6
assert np.allclose(out, answer * nb_repeat, atol=atol)
for o, a in zip(out, answer):
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
......@@ -939,13 +1018,15 @@ class TestFusion:
# input of g,
# check that the number of input to the Composite
# Elemwise is ok
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = sum(
not isinstance(x, Constant) for x in topo_[0].inputs
)
assert expected_len_sym_inputs == len(sym_inputs)
for g_ in g:
if len(set(g_.owner.inputs)) == len(g_.owner.inputs):
expected_len_sym_inputs = sum(
not isinstance(x, Constant) for x in topo_[0].inputs
)
assert expected_len_sym_inputs == len(sym_inputs)
assert out_dtype == out.dtype
for od, o in zip(out_dtype, out):
assert od == o.dtype
def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
......@@ -1006,6 +1087,30 @@ class TestFusion:
for node in dlogp.maker.fgraph.toposort()
)
@pytest.mark.xfail(reason="Fails due to #1244")
def test_add_mul_fusion_precedence(self):
    """Check that canonicalization collapses the repeated additions and
    multiplications before a `Composite` `Op` is introduced, so the fused
    scalar graph ends up with exactly one `add` and one `mul`.
    """
    x, y, z = vectors("x", "y", "z")
    out = log((x + y + z) / (x * y * z))
    f = pytensor.function([x, y, z], out, mode=self.mode)
    # The whole expression should have been fused into a single Composite
    # Elemwise node.
    apply_nodes = list(f.maker.fgraph.apply_nodes)
    assert len(apply_nodes) == 1
    composite_node = apply_nodes[0]
    assert isinstance(composite_node.op, Elemwise)
    inner_scalar_op = composite_node.op.scalar_op
    assert isinstance(inner_scalar_op, Composite)
    # Canonicalization should have merged the three-way add and the
    # three-way mul each into one n-ary scalar op before fusion, so the
    # inner graph is exactly: mul, add, true_div, log.
    inner_ops = [apply.op for apply in inner_scalar_op.fgraph.toposort()]
    assert inner_ops == [aes.mul, aes.add, aes.true_div, aes.log]
def test_add_mul_fusion_inplace(self):
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
......@@ -1082,11 +1187,8 @@ class TestFusion:
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
"""Make sure that `local_elemwise_fusion_op` uses test values correctly
when they have zero dimensions.
"""
x, y, z = dmatrices("xyz")
......@@ -1094,27 +1196,20 @@ class TestFusion:
y.tag.test_value = test_value
z.tag.test_value = test_value
if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()
with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise"
):
out = x * y + z
with cm:
f = function([x, y, z], out, mode=self.mode)
f = function([x, y, z], out, mode=self.mode)
if test_value.size != 0:
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value,
np.full_like(test_value, 2.0),
)
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
......@@ -1227,6 +1322,26 @@ class TestFusion:
aes.mul,
}
def test_multiple_outputs_fused_root_elemwise(self):
    """Test that a root Elemwise output (a single Elemwise layer) is reused
    rather than recomputed when it also feeds another, fused output.
    """
    # A lone layer of Elemwise is left alone: by default no Composite is
    # introduced for it.
    x = at.vector("x")
    out1 = at.cos(x)
    f = pytensor.function([x], out1, mode=self.mode)
    graph_nodes = tuple(f.maker.fgraph.apply_nodes)
    assert len(graph_nodes) == 1
    assert isinstance(graph_nodes[0].op.scalar_op, aes.Cos)
    # When that same root output can be composed with a second output, both
    # should come out of one Composite node so the root Elemwise is not
    # computed twice.
    out2 = at.log(out1)
    f = pytensor.function([x], [out1, out2], mode=self.mode)
    graph_nodes = tuple(f.maker.fgraph.apply_nodes)
    assert len(graph_nodes) == 1
    assert isinstance(graph_nodes[0].op.scalar_op, Composite)
class TimesN(aes.basic.UnaryScalarOp):
"""
......
......@@ -887,10 +887,9 @@ class TestLocalSubtensorLift:
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp}
assert prog[2].op == add or prog[3].op == add
# first subtensor
assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor)
assert len(prog) == 4
assert isinstance(prog[2].op, Subtensor)
assert len(prog) == 3
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_7(self):
......
......@@ -273,8 +273,7 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
Elemwise{Composite{(i0 + (i1 - i2))}} 4
|A
Elemwise{Composite{(i2 + (i0 - i1))}} 4
|InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1
......@@ -285,6 +284,7 @@ def test_debugprint():
| |<TensorType(float64, (?,))>
| |TensorConstant{0.0}
|D
|A
"""
).lstrip()
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment