提交 0670ac2f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fuse consecutive Elemwise nodes with multiple clients

上级 d5cb23a5
...@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp): ...@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp):
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes due to: # Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs. # - NumPy ufunc support only up to 32 operands (inputs and outputs)
# But our c code support more. # But our c code support more.
# - nfunc is reused for scipy and scipy is optional # - nfunc is reused for scipy and scipy is optional
if len(node.inputs) > 32 and self.ufunc and impl == "py": if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
impl = "c" impl = "c"
if getattr(self, "nfunc_spec", None) and impl != "c": if getattr(self, "nfunc_spec", None) and impl != "c":
...@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp): ...@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp):
self.nfunc = module self.nfunc = module
if ( if (
len(node.inputs) < 32 (len(node.inputs) + len(node.outputs)) <= 32
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
and self.ufunc is None and self.ufunc is None
and impl == "py" and impl == "py"
...@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp): ...@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp):
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32: if (len(node.inputs) + len(node.outputs)) > 32:
# Some versions of NumPy will segfault, other will raise a # Some versions of NumPy will segfault, other will raise a
# ValueError, if the number of inputs to a ufunc is 32 or more. # ValueError, if the number of operands in an ufunc is more than 32.
# In that case, the C version should be used, or Elemwise fusion # In that case, the C version should be used, or Elemwise fusion
# should be disabled. # should be disabled.
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage) super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1: if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}") raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
# Determine the shape of outputs
out_shape = []
for values in zip(*[input.shape for input in inputs]):
if any(v == 0 for v in values):
# All non-broadcasted dimensions should be zero
assert max(values) <= 1
out_shape.append(0)
else:
out_shape.append(max(values))
out_shape = tuple(out_shape)
ufunc_args = inputs ufunc_args = inputs
ufunc_kwargs = {} ufunc_kwargs = {}
# We supported in the past calling manually op.perform. # We supported in the past calling manually op.perform.
......
...@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py ...@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py pytensor/tensor/rewriting/basic.py
pytensor/tensor/rewriting/elemwise.py
pytensor/tensor/shape.py pytensor/tensor/shape.py
pytensor/tensor/slinalg.py pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py pytensor/tensor/subtensor.py
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.compile import UnusedInputError from pytensor.compile import UnusedInputError, get_mode
from pytensor.compile.function import function, pfunc from pytensor.compile.function import function, pfunc
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.io import In from pytensor.compile.io import In
...@@ -200,7 +200,12 @@ class TestPfunc: ...@@ -200,7 +200,12 @@ class TestPfunc:
bval = np.arange(5) bval = np.arange(5)
b.set_value(bval, borrow=True) b.set_value(bval, borrow=True)
bval = data_of(b) bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN") f = pfunc(
[],
[b_out],
updates=[(b, (b_out + 3))],
mode=get_mode("FAST_RUN").excluding("fusion"),
)
assert (f() == (np.arange(5) * 2)).all() assert (f() == (np.arange(5) * 2)).all()
# because of the update # because of the update
assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all() assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all()
......
import contextlib
import numpy as np import numpy as np
import pytest import pytest
...@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in ...@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor.basic import MakeVector from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import ( from pytensor.tensor.math import (
add,
bitwise_and, bitwise_and,
bitwise_or, bitwise_or,
cos, cos,
...@@ -29,6 +30,7 @@ from pytensor.tensor.math import ( ...@@ -29,6 +30,7 @@ from pytensor.tensor.math import (
dot, dot,
eq, eq,
exp, exp,
ge,
int_div, int_div,
invert, invert,
iround, iround,
...@@ -900,6 +902,72 @@ class TestFusion: ...@@ -900,6 +902,72 @@ class TestFusion:
fxv * np.sin(fsv), fxv * np.sin(fsv),
"float32", "float32",
), ),
# Multiple output cases # 72
(
(
# sum(logp)
at_sum(-((fx - fy) ** 2) / 2),
# grad(logp)
at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx),
),
(fx, fy),
(fxv, fyv),
3,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
),
("float32", "float32"),
),
# Two Composite graphs that share the same input, but are split by
# a non-elemwise operation (Assert)
(
(
log(
ge(
assert_op(
at_abs(fx),
at_all(ge(at_abs(fx), 0)),
),
0,
)
),
),
(fx,),
(fxv,),
4,
(np.zeros_like(fxv),),
("float32",),
),
# Two subgraphs that share the same non-fuseable input, but are otherwise
# completely independent
(
(
true_div(
mul(
at_sum(fx + 5), # breaks fusion
exp(fx),
),
(fx + 5),
),
),
(fx,),
(fxv,),
4,
(np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),),
("float32",),
),
pytest.param(
(
(sin(exp(fx)), exp(sin(fx))),
(fx,),
(fxv,),
1,
(np.sin(np.exp(fxv)), np.exp(np.sin(fxv))),
("float32", "float32"),
),
marks=pytest.mark.xfail, # Not implemented yet
),
], ],
) )
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
...@@ -910,23 +978,34 @@ class TestFusion: ...@@ -910,23 +978,34 @@ class TestFusion:
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
if not isinstance(g, (tuple, list)):
g = (g,)
answer = (answer,)
out_dtype = (out_dtype,)
if self._shared is None: if self._shared is None:
f = function(list(sym_inputs), g, mode=self.mode) f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat): for x in range(nb_repeat):
out = f(*val_inputs) out = f(*val_inputs)
if not isinstance(out, list):
out = (out,)
else: else:
out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") out = [
assert out.dtype == g.dtype self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out")
f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) for g_, od in zip(g, out_dtype)
]
assert all(o.dtype == g_.dtype for o, g_ in zip(out, g))
f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode)
for x in range(nb_repeat): for x in range(nb_repeat):
f(*val_inputs) f(*val_inputs)
out = out.get_value() out = [o.get_value() for o in out]
atol = 1e-8 atol = 1e-8
if out_dtype == "float32": if any(o == "float32" for o in out_dtype):
atol = 1e-6 atol = 1e-6
assert np.allclose(out, answer * nb_repeat, atol=atol) for o, a in zip(out, answer):
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
...@@ -939,13 +1018,15 @@ class TestFusion: ...@@ -939,13 +1018,15 @@ class TestFusion:
# input of g, # input of g,
# check that the number of input to the Composite # check that the number of input to the Composite
# Elemwise is ok # Elemwise is ok
if len(set(g.owner.inputs)) == len(g.owner.inputs): for g_ in g:
if len(set(g_.owner.inputs)) == len(g_.owner.inputs):
expected_len_sym_inputs = sum( expected_len_sym_inputs = sum(
not isinstance(x, Constant) for x in topo_[0].inputs not isinstance(x, Constant) for x in topo_[0].inputs
) )
assert expected_len_sym_inputs == len(sym_inputs) assert expected_len_sym_inputs == len(sym_inputs)
assert out_dtype == out.dtype for od, o in zip(out_dtype, out):
assert od == o.dtype
def test_fusion_35_inputs(self): def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
...@@ -1006,6 +1087,30 @@ class TestFusion: ...@@ -1006,6 +1087,30 @@ class TestFusion:
for node in dlogp.maker.fgraph.toposort() for node in dlogp.maker.fgraph.toposort()
) )
@pytest.mark.xfail(reason="Fails due to #1244")
def test_add_mul_fusion_precedence(self):
    """Check that add/mul chains are merged before a `Composite` is built.

    The graph ``log((x + y + z) / (x * y * z))`` is compiled with
    ``self.mode`` and is expected to collapse into one `Elemwise` node
    wrapping a `Composite` whose inner graph holds exactly one ``mul``,
    one ``add``, one ``true_div`` and one ``log``, in that order.

    Currently marked as an expected failure (see the ``xfail`` reason).
    """
    x, y, z = vectors("x", "y", "z")
    expr = log((x + y + z) / (x * y * z))
    compiled = pytensor.function([x, y, z], expr, mode=self.mode)

    # The whole expression should end up in a single fused node
    apply_nodes = compiled.maker.fgraph.apply_nodes
    assert len(apply_nodes) == 1
    (fused_node,) = apply_nodes
    assert isinstance(fused_node.op, Elemwise)

    composite = fused_node.op.scalar_op
    assert isinstance(composite, Composite)

    # Inside the Composite: a single mul, a single add, then the division
    # and the log
    inner_ops = [inner.op for inner in composite.fgraph.toposort()]
    assert inner_ops == [aes.mul, aes.add, aes.true_div, aes.log]
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
x, y, z = dmatrices("xyz") x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z out = dot(x, y) + x + y + z
...@@ -1082,11 +1187,8 @@ class TestFusion: ...@@ -1082,11 +1187,8 @@ class TestFusion:
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value): def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. """Make sure that `local_elemwise_fusion_op` uses test values correctly
when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
""" """
x, y, z = dmatrices("xyz") x, y, z = dmatrices("xyz")
...@@ -1094,26 +1196,19 @@ class TestFusion: ...@@ -1094,26 +1196,19 @@ class TestFusion:
y.tag.test_value = test_value y.tag.test_value = test_value
z.tag.test_value = test_value z.tag.test_value = test_value
if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()
with config.change_flags( with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise" compute_test_value="raise", compute_test_value_opt="raise"
): ):
out = x * y + z out = x * y + z
with cm:
f = function([x, y, z], out, mode=self.mode) f = function([x, y, z], out, mode=self.mode)
if test_value.size != 0:
# Confirm that the fusion happened # Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1 assert len(f.maker.fgraph.toposort()) == 1
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
assert np.array_equal( assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] f.maker.fgraph.outputs[0].tag.test_value,
np.full_like(test_value, 2.0),
) )
@pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("linker", ["cvm", "py"])
...@@ -1227,6 +1322,26 @@ class TestFusion: ...@@ -1227,6 +1322,26 @@ class TestFusion:
aes.mul, aes.mul,
} }
def test_multiple_outputs_fused_root_elemwise(self):
    """Check that a root Elemwise output is shared with a fused output.

    First compiles ``cos(x)`` alone and asserts that it stays a plain
    `Elemwise` with an ``aes.Cos`` scalar op. Then compiles ``cos(x)``
    together with ``log(cos(x))`` and asserts that both outputs come from
    one node whose scalar op is a `Composite`.
    """
    x = at.vector("x")
    cos_x = at.cos(x)

    # A single layer of Elemwise is not wrapped in a Composite by default
    single_fn = pytensor.function([x], cos_x, mode=self.mode)
    single_nodes = tuple(single_fn.maker.fgraph.apply_nodes)
    assert len(single_nodes) == 1
    root_node = single_nodes[0]
    assert isinstance(root_node.op.scalar_op, aes.Cos)

    # Once another output can be composed on top of it, the root Elemwise
    # must not be computed a second time
    log_cos_x = at.log(cos_x)
    joint_fn = pytensor.function([x], [cos_x, log_cos_x], mode=self.mode)
    joint_nodes = tuple(joint_fn.maker.fgraph.apply_nodes)
    assert len(joint_nodes) == 1
    fused_node = joint_nodes[0]
    assert isinstance(fused_node.op.scalar_op, Composite)
class TimesN(aes.basic.UnaryScalarOp): class TimesN(aes.basic.UnaryScalarOp):
""" """
......
...@@ -887,10 +887,9 @@ class TestLocalSubtensorLift: ...@@ -887,10 +887,9 @@ class TestLocalSubtensorLift:
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle) assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp} assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp}
assert prog[2].op == add or prog[3].op == add
# first subtensor # first subtensor
assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor) assert isinstance(prog[2].op, Subtensor)
assert len(prog) == 4 assert len(prog) == 3
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_7(self): def test_basic_7(self):
......
...@@ -273,8 +273,7 @@ def test_debugprint(): ...@@ -273,8 +273,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i0 + (i1 - i2))}} 4 Elemwise{Composite{(i2 + (i0 - i1))}} 4
|A
|InplaceDimShuffle{x,0} v={0: [0]} 3 |InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2 | |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1 | |AllocEmpty{dtype='float64'} 1
...@@ -285,6 +284,7 @@ def test_debugprint(): ...@@ -285,6 +284,7 @@ def test_debugprint():
| |<TensorType(float64, (?,))> | |<TensorType(float64, (?,))>
| |TensorConstant{0.0} | |TensorConstant{0.0}
|D |D
|A
""" """
).lstrip() ).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论