提交 0670ac2f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fuse consecutive Elemwise nodes with multiple clients

上级 d5cb23a5
...@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp): ...@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp):
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes due to: # Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs. # - NumPy ufunc support only up to 32 operands (inputs and outputs)
# But our c code support more. # But our c code support more.
# - nfunc is reused for scipy and scipy is optional # - nfunc is reused for scipy and scipy is optional
if len(node.inputs) > 32 and self.ufunc and impl == "py": if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
impl = "c" impl = "c"
if getattr(self, "nfunc_spec", None) and impl != "c": if getattr(self, "nfunc_spec", None) and impl != "c":
...@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp): ...@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp):
self.nfunc = module self.nfunc = module
if ( if (
len(node.inputs) < 32 (len(node.inputs) + len(node.outputs)) <= 32
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs)) and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
and self.ufunc is None and self.ufunc is None
and impl == "py" and impl == "py"
...@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp): ...@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp):
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl) self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32: if (len(node.inputs) + len(node.outputs)) > 32:
# Some versions of NumPy will segfault, other will raise a # Some versions of NumPy will segfault, other will raise a
# ValueError, if the number of inputs to a ufunc is 32 or more. # ValueError, if the number of operands in an ufunc is more than 32.
# In that case, the C version should be used, or Elemwise fusion # In that case, the C version should be used, or Elemwise fusion
# should be disabled. # should be disabled.
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage) super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1: if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}") raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
# Determine the shape of outputs
out_shape = []
for values in zip(*[input.shape for input in inputs]):
if any(v == 0 for v in values):
# All non-broadcasted dimensions should be zero
assert max(values) <= 1
out_shape.append(0)
else:
out_shape.append(max(values))
out_shape = tuple(out_shape)
ufunc_args = inputs ufunc_args = inputs
ufunc_kwargs = {} ufunc_kwargs = {}
# We supported in the past calling manually op.perform. # We supported in the past calling manually op.perform.
......
import sys import sys
import time from collections import defaultdict, deque
from collections import defaultdict from functools import lru_cache
from typing import Optional from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
from warnings import warn from warnings import warn
import pytensor import pytensor
import pytensor.scalar.basic as aes import pytensor.scalar.basic as aes
from pytensor import compile from pytensor import clone_replace, compile
from pytensor.compile.mode import get_target_language from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, io_toposort from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort
from pytensor.graph.features import ReplaceValidate from pytensor.graph.features import ReplaceValidate
from pytensor.graph.op import compute_test_value, get_test_value from pytensor.graph.fg import ApplyOrOutput
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter, EquilibriumGraphRewriter,
GraphRewriter, GraphRewriter,
...@@ -20,7 +21,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -20,7 +21,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -592,333 +593,438 @@ def local_add_mul_fusion(fgraph, node): ...@@ -592,333 +593,438 @@ def local_add_mul_fusion(fgraph, node):
return [output] return [output]
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): def elemwise_max_operands_fct(node) -> int:
r"""Create a recursive function that fuses `Elemwise` `Op`\s. # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs)
if not config.cxx:
The basic idea is that we loop through an `Elemwise` node's inputs, find return 32
other `Elemwise` nodes, determine the scalars input types for all of the return 1024
`Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
new "fused" `Elemwise`.
It's parameterized in order to work for `Elemwise` `Op`\s.
Parameters
----------
op_class : type
`Elemwise` class (the one that we want to fuse)
max_input_fct : callable
A function that returns the maximum number of inputs that this `Elemwise`
can take.
On the CPU we limit to 32 input variables since that is the maximum
NumPy support.
maker: callable
A function with the signature ``(node, *args)`` that constructs an
`op_class` instance (e.g. ``op_class(*args)``).
"""
if maker is None:
def maker(node, scalar_op):
return op_class(scalar_op)
def local_fuse(fgraph, node):
r"""Fuse `Elemwise` `Op`\s in a node.
As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
same shape.
For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
compiler do the cast.
The number of dimensions is validated at call time by PyTensor itself. class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations."""
""" def add_requirements(self, fgraph):
# TODO: use broadcast flag? fgraph.attach_feature(ReplaceValidate())
# TODO: don't do this rewrite as a `NodeRewriter`. @staticmethod
# Analyze the graph in terms of elemwise subgraphs, and then def elemwise_to_scalar(inputs, outputs):
# replace each subgraph with a Composite version. replace_inputs = [(inp, inp.clone()) for inp in inputs]
outputs = clone_replace(outputs, replace=replace_inputs)
# TODO: use malloc and copy to transfer arguments that don't inputs = [inp for _, inp in replace_inputs]
# fit within the parameter space of 256 bytes fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
# middle_inputs = []
# TODO: Merge with multiple output to merge when an inputs
# have multiple clients. This can't be done with a `NodeRewriter`
# TODO: Related: Support composites with multiple outputs
# TODO: Use Composite to combine Elemwise and Reduce
# operations. We have to loop over the data anyway... might
# as well sum it up while we're at it (this can be trickier
# than i'm making it seound here. The data-traversal should be
# done contiguously, and the summing-up might not be easy or
# worthwhile if the summation axis doesn't line up with a
# contiguous dimension)
if type(node.op) is not op_class:
return False
if len(node.outputs) > 1:
# We don't support fusion for nodes with multiple outputs.
return
inputs = [] # inputs of the new Elemwise op.
s_inputs = [] # inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node.
s_g = []
# There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function.
max_nb_input = max_input_fct(node)
# The number of inputs to the new fused op if we do not fuse more
# inputs.
new_nb_input = len(node.inputs)
# Did we fuse something?
# Needed as we can fuse unary op that don't change the number of
# inputs.
# And there is a case where the inputs are the same as the current
# node. That won't change the number of inputs of the new op.
fused = False
for i in node.inputs:
scalar_node: Optional[Apply] = None
# Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs).
tmp_input = []
# Same as tmp_input, but for scalars.
tmp_scalar = []
# We should not check the number of inputs here
# As fusing op don't always change the number of input.
# If a variable is used as multiple into to the same node,
# we still want to fusion. So we take the set.
if (
i.owner
and isinstance(i.owner.op, op_class)
and len({n for n, idx in fgraph.clients[i]}) == 1
and
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
):
try:
tmp_s_input = []
# we should not put duplicate input into s_inputs and inputs
for ii in i.owner.inputs:
if ii in inputs:
tmp_s_input.append(s_inputs[inputs.index(ii)])
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
try:
tv = get_test_value(ii)
# Sometimes the original inputs have
# zero-valued shapes in some dimensions, which
# implies that this whole scalar thing doesn't
# make sense (i.e. we're asking for the scalar
# value of an entry in a zero-dimensional
# array).
# This will eventually lead to an error in the
# `compute_test_value` call below when/if
# `config.compute_test_value_opt` is enabled
# (for debugging, more or less)
tmp.tag.test_value = tv.item()
except (TestValueError, ValueError):
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
# Use the `Op.make_node` interface in case `Op.__call__`
# has been customized
scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
if config.compute_test_value_opt != "off":
# This is required because `Op.make_node` won't do it
compute_test_value(scalar_node)
# If the scalar_op doesn't have a C implementation, we skip
# its fusion to allow fusion of the other ops
i.owner.op.scalar_op.c_code(
scalar_node,
"test_presence_of_c_code",
["x" for x in i.owner.inputs],
["z" for z in i.owner.outputs],
{"fail": "%(fail)s"},
)
except (NotImplementedError, MethodNotDefined): scalar_inputs = [
warn( aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
"Rewrite warning: " ]
f"The Op {i.owner.op.scalar_op} does not provide a C implementation." middle_scalar_inputs = []
" As well as being potentially slow, this also disables "
"loop fusion." for node in fg.toposort():
node_scalar_inputs = []
for inp in node.inputs:
if inp in inputs:
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
elif inp in middle_inputs:
node_scalar_inputs.append(
middle_scalar_inputs[middle_inputs.index(inp)]
) )
scalar_node = None
# Compute the number of inputs in case we fuse this input.
# We subtract 1 because we replace the existing input with the new
# inputs from `tmp_input`.
new_nb_input_ = new_nb_input + len(tmp_input) - 1
# If the new input is already an input of the current node, it was
# already counted when `new_nb_input` was initialized to
# len(node.inputs).
# This can happen when a variable is used both by the Elemwise to
# fuse and the current node.
for x in tmp_input:
if x in node.inputs:
new_nb_input_ -= 1
if scalar_node and (new_nb_input_ <= max_nb_input):
fused = True
new_nb_input = new_nb_input_
inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar)
s_g.extend(scalar_node.outputs)
else:
# We must support the case where the same variable appears many
# times within the inputs
if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)]
else: else:
s = aes.get_scalar_type(i.type.dtype).make_variable() new_scalar_input = aes.get_scalar_type(
if config.compute_test_value_opt != "off": inp.type.dtype
try: ).make_variable()
v = get_test_value(i) node_scalar_inputs.append(new_scalar_input)
# See the zero-dimensional test value situation middle_scalar_inputs.append(new_scalar_input)
# described above. middle_inputs.append(inp)
s.tag.test_value = v.item()
except (TestValueError, ValueError): new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
pass middle_scalar_inputs.append(new_scalar_node.outputs[0])
middle_inputs.append(node.outputs[0])
inputs.append(i)
s_inputs.append(s) scalar_outputs = [
s_g.append(s) middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
]
if not fused: return scalar_inputs, scalar_outputs
return False
if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
# TODO FIXME: This shouldn't be a generic `Exception`
raise Exception(
"Something has gone wrong with the elemwise fusion rewrite; skipping."
)
s_new_out = node.op.scalar_op(*s_g, return_list=True)
try:
s_new_out[0].owner.op.c_code(
s_new_out[0].owner,
"test_presence_of_c_code",
["x" for x in s_g],
["z" for x in s_new_out],
{"fail": "%(fail)s"},
)
except (NotImplementedError, MethodNotDefined):
name = str(s_new_out[0].owner.op)
warn(
"Rewrite warning: "
f"The Op {name} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
)
return False
# create the composite op.
composite_op = aes.Composite(s_inputs, s_new_out)
# create the new node.
# Do not call make_node to have test_value
new_node = maker(node, composite_op)(*inputs).owner
assert len(new_node.outputs) == 1
assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype
if len(new_node.inputs) > max_nb_input: def apply(self, fgraph):
warn( nb_replacement = 0
"Loop fusion failed because the resulting node "
"would exceed the kernel argument limit."
)
return False
# we fuse as many that we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate step.
while True:
ret = local_fuse(fgraph, new_node)
if ret is not False and ret is not None:
assert len(ret) == len(new_node.outputs)
assert len(ret) == 1
new_node = ret[0].owner
else:
break
return new_node.outputs if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
return local_fuse max_operands = elemwise_max_operands_fct(None)
def find_next_fuseable_subgraph(
fg: FunctionGraph,
) -> Generator[Tuple[List[Variable], List[Variable]], None, None]:
"""Find all subgraphs in a FunctionGraph that can be fused together
Yields
-------
List of inputs and outputs that determine subgraphs which can be fused.
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
"""
FUSEABLE_MAPPING = DefaultDict[Variable, List[Apply]]
UNFUSEABLE_MAPPING = DefaultDict[Variable, Set[ApplyOrOutput]]
def initialize_fuseable_mappings(
    *, fg: FunctionGraph
) -> Tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]:
    """Classify every (variable, client) pair of ``fg`` as fuseable or not.

    Walks ``fg.clients`` once. A client is recorded as fuseable for a
    variable when both the variable's owner and the client are
    single-output `Elemwise` nodes whose scalar ops report C support, and
    the variable and the client's output share the same broadcastable
    pattern. Every other client (including the ``"output"`` string marker)
    is recorded as unfuseable.

    Returns
    -------
    A pair ``(fuseable_clients, unfuseable_clients)``: a defaultdict of
    lists (insertion order kept, no duplicates) and a defaultdict of sets.
    """

    @lru_cache(maxsize=None)
    def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
        # Cached per node, so each node is queried (and warned about) once.
        # TODO: This should not play a role in non-c backends!
        scalar_op = node.op.scalar_op
        if scalar_op.supports_c_code(node.inputs, node.outputs):
            return True
        warn(
            "Optimization Warning: "
            f"The Op {scalar_op} does not provide a C implementation."
            " As well as being potentially slow, this also disables "
            "loop fusion."
        )
        return False

    # Fuseable clients are kept in lists so that they are visited in a
    # deterministic order, which keeps the rewrite deterministic.
    # Unfuseable clients can live in sets: they never become part of a
    # fused subgraph, so their order does not matter.
    fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
    unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)

    for out, clients in fg.clients.items():
        producer = out.owner
        # Evaluated once per variable, before looking at any client.
        out_maybe_fuseable = (
            producer
            and isinstance(producer.op, Elemwise)
            # and not isinstance(out.owner.op.scalar_op, aes.Composite)
            and len(producer.outputs) == 1
            and elemwise_scalar_op_has_c_code(producer)
        )

        for client, _ in clients:
            client_is_fuseable = (
                out_maybe_fuseable
                and not isinstance(client, str)  # "output"
                and isinstance(client.op, Elemwise)
                # and not isinstance(client.op.scalar_op, aes.Composite)
                and len(client.outputs) == 1
                and out.type.broadcastable == client.outputs[0].type.broadcastable
                and elemwise_scalar_op_has_c_code(client)
            )
            if not client_is_fuseable:
                unfuseable_clients[out].add(client)
                continue

            # A client may use the same variable several times; list it once.
            known_clients = fuseable_clients[out]
            if client not in known_clients:
                known_clients.append(client)

    return fuseable_clients, unfuseable_clients
def find_fuseable_subgraph(
*,
fg: FunctionGraph,
visited_nodes: Set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
) -> Tuple[List[Variable], List[Variable]]:
# Type variables for the clone helper below: keys are unconstrained, values
# are constrained to ``list`` or ``set`` (both expose ``.copy()``).
KT = TypeVar("KT")
VT = TypeVar("VT", list, set)


def shallow_clone_defaultdict(d: DefaultDict[KT, VT]) -> DefaultDict[KT, VT]:
    """Return a new defaultdict with the same factory and copied values.

    Each value container is duplicated with its own ``.copy()``, so adding
    to or removing from a container in the clone leaves ``d`` untouched.
    The elements inside the containers are shared, not copied.
    """
    clone: DefaultDict[KT, VT] = defaultdict(d.default_factory)
    for key, container in d.items():
        clone[key] = container.copy()
    return clone
def variables_depend_on(variables, depend_on, stop_search_at=None) -> bool:
    """Tell whether any ancestor of ``variables`` belongs to ``depend_on``.

    The ancestor walk is delegated to ``ancestors``, with
    ``stop_search_at`` forwarded as its ``blockers`` argument. The search
    stops at the first match.
    """
    for ancestor in ancestors(variables, blockers=stop_search_at):
        if ancestor in depend_on:
            return True
    return False
toposort = fg.toposort()
for starting_node in toposort:
if starting_node in visited_nodes:
continue
def elemwise_max_input_fct(node): starting_out = starting_node.outputs[0]
# `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. if not fuseable_clients.get(starting_out):
if not config.cxx: visited_nodes.add(starting_node)
return 31 continue
return 1024
subgraph_inputs: List[Variable] = []
subgraph_outputs: List[Variable] = []
unfuseable_clients_subgraph: Set[Variable] = set()
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) # Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
unfuseable_clients_clone = shallow_clone_defaultdict(
unfuseable_clients
)
fuseable_nodes_to_visit = deque([starting_node])
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
# subgraph that can be Composed together into a single `Op`. The
# largest issue to watch out for is cyclical dependencies, where
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node)
next_out = next_node.outputs[0]
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output = (
next_out not in fuseable_clients_temp
or next_out in unfuseable_clients_clone
)
class FusionOptimizer(GraphRewriter): # We have backtracked to this node, and it may no longer be a viable output,
"""Graph rewriter that simply runs node fusion operations. # so we remove it and check again as if we had never seen this node
if must_become_output and next_out in subgraph_outputs:
subgraph_outputs.remove(next_out)
required_unfuseable_inputs = [
inp
for inp in next_node.inputs
if next_node in unfuseable_clients_clone.get(inp, ())
]
new_required_unfuseable_inputs = [
inp
for inp in required_unfuseable_inputs
if inp not in subgraph_inputs
]
must_backtrack = False
if new_required_unfuseable_inputs and subgraph_outputs:
# We need to check that any new inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if variables_depend_on(
[next_out],
depend_on=unfuseable_clients_subgraph,
stop_search_at=subgraph_outputs,
):
must_backtrack = True
if not must_backtrack:
implied_unfuseable_clients = {
c
for client in unfuseable_clients_clone.get(next_out, ())
if not isinstance(client, str) # "output"
for c in client.outputs
}
new_implied_unfuseable_clients = (
implied_unfuseable_clients - unfuseable_clients_subgraph
)
TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that. if new_implied_unfuseable_clients and subgraph_inputs:
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
if variables_depend_on(
subgraph_inputs,
depend_on=new_implied_unfuseable_clients,
):
must_backtrack = True
if must_backtrack:
for inp in next_node.inputs:
if (
inp.owner in visited_nodes
# next_node could have the same input repeated
and next_node in fuseable_clients_temp[inp]
):
fuseable_clients_temp[inp].remove(next_node)
unfuseable_clients_clone[inp].add(next_node)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner)
for client in fuseable_clients_temp[next_out]:
if client in visited_nodes:
fuseable_clients_temp[next_out].remove(client)
unfuseable_clients_clone[next_out].add(client)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# in the subgraph to make sure this is safe.
fuseable_nodes_to_visit.appendleft(client)
# Revisit node at a later time
visited_nodes.remove(next_node)
continue
# Adding next_node to subgraph does not result in any
# immediate dependency problems. Update subgraph
# mappings as if next_node was part of it.
# Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs:
if inp not in subgraph_inputs:
subgraph_inputs.append(inp)
if must_become_output:
subgraph_outputs.append(next_out)
unfuseable_clients_subgraph.update(
new_implied_unfuseable_clients
)
""" # Expand through unvisited fuseable ancestors
for inp in sorted(
(
inp
for inp in next_node.inputs
if (
inp not in required_unfuseable_inputs
and inp.owner not in visited_nodes
)
),
key=lambda inp: toposort.index(inp.owner),
reverse=True,
):
fuseable_nodes_to_visit.appendleft(inp.owner)
# Expand through unvisited fuseable clients
for next_node in sorted(
(
node
for node in fuseable_clients_temp.get(next_out, ())
if node not in visited_nodes
),
key=lambda node: toposort.index(node),
):
fuseable_nodes_to_visit.append(next_node)
# Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set(
subgraph_outputs[0].owner.inputs
) == set(subgraph_inputs):
# Update global fuseable mappings
# No input was actually fuseable
for inp in starting_node.inputs:
if starting_node in fuseable_clients.get(inp, ()):
fuseable_clients[inp].remove(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable
unfuseable_clients[starting_out].update(
fuseable_clients.pop(starting_out, ())
)
continue
def __init__(self, node_rewriter): return subgraph_inputs, subgraph_outputs
super().__init__() raise ValueError
self.node_rewriter = node_rewriter
def update_fuseable_mappings_after_fg_replace(
    *,
    fg: FunctionGraph,
    visited_nodes: Set[Apply],
    fuseable_clients: FUSEABLE_MAPPING,
    unfuseable_clients: UNFUSEABLE_MAPPING,
    starting_nodes: Set[Apply],
) -> None:
    """Bring the fuseable/unfuseable mappings in sync with a rewritten ``fg``.

    ``starting_nodes`` is the set of apply nodes before the replacement.
    Comparing it with the current ``fg.apply_nodes`` identifies the single
    newly added node and the nodes that disappeared. All three mutable
    arguments (``visited_nodes``, ``fuseable_clients``,
    ``unfuseable_clients``) are updated in place.

    Note: ``subgraph_inputs`` is not a parameter; it is read from the
    enclosing scope.
    """
    current_nodes = fg.apply_nodes
    # Unpacking enforces that exactly one node was added by the replacement.
    (new_composite_node,) = current_nodes - starting_nodes
    dropped_nodes = starting_nodes - current_nodes

    # Forget every node that is no longer part of the graph.
    for gone_node in dropped_nodes:
        (gone_out,) = gone_node.outputs
        fuseable_clients.pop(gone_out, None)
        unfuseable_clients.pop(gone_out, None)
        visited_nodes.remove(gone_node)

    # Subgraph inputs lose the dropped nodes as clients and gain the new
    # node, which is recorded as an unfuseable client.
    for inp in subgraph_inputs:
        if inp in fuseable_clients:
            surviving_clients = [
                client
                for client in fuseable_clients[inp]
                if client not in dropped_nodes
            ]
            if surviving_clients:
                fuseable_clients[inp] = surviving_clients
            else:
                fuseable_clients.pop(inp)
        remaining_unfuseable = unfuseable_clients[inp] - dropped_nodes
        unfuseable_clients[inp] = remaining_unfuseable | {new_composite_node}

    # Every current client of the new node's outputs is marked unfuseable.
    for out in new_composite_node.outputs:
        unfuseable_clients[out] = {client for client, _ in fg.clients[out]}

    visited_nodes.add(new_composite_node)
# We start by creating two maps, 1) from each node to each potentially
# fuseable client (both nodes must be single output Elemwise with same
# broadcast type) and 2) from each node to each certainly unfuseable
# client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: Set[Apply] = set()
while True:
starting_nodes = fg.apply_nodes.copy()
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
)
except ValueError:
return
else:
# The caller is now expected to update fg in place,
# by replacing the subgraph with a Composite Op
yield subgraph_inputs, subgraph_outputs
# This is where we avoid repeated work by using a stateful
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
starting_nodes=starting_nodes,
)
def add_requirements(self, fgraph): for inputs, outputs in find_next_fuseable_subgraph(fgraph):
fgraph.attach_feature(ReplaceValidate()) if (len(inputs) + len(outputs)) > max_operands:
warn(
"Loop fusion failed because the resulting node would exceed "
"the kernel argument limit."
)
break
def apply(self, fgraph): scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
did_something = True composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))(
nb_iter = 0 *inputs
nb_replacement = 0 )
nb_inconsistency_replace = 0 if not isinstance(composite_outputs, list):
time_toposort = 0 composite_outputs = [composite_outputs]
if fgraph.profile: for old_out, composite_out in zip(outputs, composite_outputs):
validate_before = fgraph.profile.validate_time if old_out.name:
callbacks_before = fgraph.execute_callbacks_times.copy() composite_out.name = old_out.name
callback_before = fgraph.execute_callbacks_time
while did_something: fgraph.replace_all_validate(
t0 = time.perf_counter() list(zip(outputs, composite_outputs)),
nodelist = list(fgraph.toposort()) reason=self.__class__.__name__,
time_toposort += time.perf_counter() - t0 )
nodelist.reverse() nb_replacement += 1
did_something = False
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.apply_nodes:
new_outputs = self.node_rewriter(fgraph, node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
try:
fgraph.replace_all_validate(
list(zip(node.outputs, new_outputs)),
reason=self.__class__.__name__,
)
did_something = True
nb_replacement += 1
except InconsistencyError:
nb_inconsistency_replace += 1
nb_iter += 1
if fgraph.profile: if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
...@@ -933,21 +1039,22 @@ class FusionOptimizer(GraphRewriter): ...@@ -933,21 +1039,22 @@ class FusionOptimizer(GraphRewriter):
validate_time = None validate_time = None
callback_time = None callback_time = None
callbacks_time = {} callbacks_time = {}
return ( return (
self, self,
nb_iter, 1, # nb_iter
nb_replacement, nb_replacement,
nb_inconsistency_replace, 0, # nb_inconsistency_replace
validate_time, validate_time,
callback_time, callback_time,
callbacks_time, callbacks_time,
time_toposort, -1, # toposort_time
) )
@classmethod @staticmethod
def print_profile(cls, stream, prof, level=0): def print_profile(stream, prof, level=0):
blanc = " " * level blanc = " " * level
print(blanc, cls.__name__, file=stream) print(blanc, "FusionOptimizer", file=stream)
print(blanc, " nb_iter", prof[1], file=stream) print(blanc, " nb_iter", prof[1], file=stream)
print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_replacement", prof[2], file=stream)
print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
...@@ -973,7 +1080,7 @@ if config.tensor__local_elemwise_fusion: ...@@ -973,7 +1080,7 @@ if config.tensor__local_elemwise_fusion:
) )
fuse_seqopt.register( fuse_seqopt.register(
"composite_elemwise_fusion", "composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion), FusionOptimizer(),
"fast_run", "fast_run",
"fusion", "fusion",
position=1, position=1,
...@@ -999,7 +1106,9 @@ def local_useless_composite(fgraph, node): ...@@ -999,7 +1106,9 @@ def local_useless_composite(fgraph, node):
): ):
return return
comp = node.op.scalar_op comp = node.op.scalar_op
used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] used_outputs_idxs = [
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
]
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs] used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
comp_fgraph = FunctionGraph( comp_fgraph = FunctionGraph(
inputs=comp.inputs, outputs=used_inner_outputs, clone=False inputs=comp.inputs, outputs=used_inner_outputs, clone=False
......
...@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py ...@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py pytensor/tensor/rewriting/basic.py
pytensor/tensor/rewriting/elemwise.py
pytensor/tensor/shape.py pytensor/tensor/shape.py
pytensor/tensor/slinalg.py pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py pytensor/tensor/subtensor.py
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytest import pytest
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.compile import UnusedInputError from pytensor.compile import UnusedInputError, get_mode
from pytensor.compile.function import function, pfunc from pytensor.compile.function import function, pfunc
from pytensor.compile.function.pfunc import rebuild_collect_shared from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.io import In from pytensor.compile.io import In
...@@ -200,7 +200,12 @@ class TestPfunc: ...@@ -200,7 +200,12 @@ class TestPfunc:
bval = np.arange(5) bval = np.arange(5)
b.set_value(bval, borrow=True) b.set_value(bval, borrow=True)
bval = data_of(b) bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN") f = pfunc(
[],
[b_out],
updates=[(b, (b_out + 3))],
mode=get_mode("FAST_RUN").excluding("fusion"),
)
assert (f() == (np.arange(5) * 2)).all() assert (f() == (np.arange(5) * 2)).all()
# because of the update # because of the update
assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all() assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all()
......
import contextlib
import numpy as np import numpy as np
import pytest import pytest
...@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in ...@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor.basic import MakeVector from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import ( from pytensor.tensor.math import (
add,
bitwise_and, bitwise_and,
bitwise_or, bitwise_or,
cos, cos,
...@@ -29,6 +30,7 @@ from pytensor.tensor.math import ( ...@@ -29,6 +30,7 @@ from pytensor.tensor.math import (
dot, dot,
eq, eq,
exp, exp,
ge,
int_div, int_div,
invert, invert,
iround, iround,
...@@ -900,6 +902,72 @@ class TestFusion: ...@@ -900,6 +902,72 @@ class TestFusion:
fxv * np.sin(fsv), fxv * np.sin(fsv),
"float32", "float32",
), ),
# Multiple output cases # 72
(
(
# sum(logp)
at_sum(-((fx - fy) ** 2) / 2),
# grad(logp)
at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx),
),
(fx, fy),
(fxv, fyv),
3,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
),
("float32", "float32"),
),
# Two Composite graphs that share the same input, but are split by
# a non-elemwise operation (Assert)
(
(
log(
ge(
assert_op(
at_abs(fx),
at_all(ge(at_abs(fx), 0)),
),
0,
)
),
),
(fx,),
(fxv,),
4,
(np.zeros_like(fxv),),
("float32",),
),
# Two subgraphs that share the same non-fuseable input, but are otherwise
# completely independent
(
(
true_div(
mul(
at_sum(fx + 5), # breaks fusion
exp(fx),
),
(fx + 5),
),
),
(fx,),
(fxv,),
4,
(np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),),
("float32",),
),
pytest.param(
(
(sin(exp(fx)), exp(sin(fx))),
(fx,),
(fxv,),
1,
(np.sin(np.exp(fxv)), np.exp(np.sin(fxv))),
("float32", "float32"),
),
marks=pytest.mark.xfail, # Not implemented yet
),
], ],
) )
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
...@@ -910,23 +978,34 @@ class TestFusion: ...@@ -910,23 +978,34 @@ class TestFusion:
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
if not isinstance(g, (tuple, list)):
g = (g,)
answer = (answer,)
out_dtype = (out_dtype,)
if self._shared is None: if self._shared is None:
f = function(list(sym_inputs), g, mode=self.mode) f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat): for x in range(nb_repeat):
out = f(*val_inputs) out = f(*val_inputs)
if not isinstance(out, list):
out = (out,)
else: else:
out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") out = [
assert out.dtype == g.dtype self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out")
f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) for g_, od in zip(g, out_dtype)
]
assert all(o.dtype == g_.dtype for o, g_ in zip(out, g))
f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode)
for x in range(nb_repeat): for x in range(nb_repeat):
f(*val_inputs) f(*val_inputs)
out = out.get_value() out = [o.get_value() for o in out]
atol = 1e-8 atol = 1e-8
if out_dtype == "float32": if any(o == "float32" for o in out_dtype):
atol = 1e-6 atol = 1e-6
assert np.allclose(out, answer * nb_repeat, atol=atol) for o, a in zip(out, answer):
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
...@@ -939,13 +1018,15 @@ class TestFusion: ...@@ -939,13 +1018,15 @@ class TestFusion:
# input of g, # input of g,
# check that the number of input to the Composite # check that the number of input to the Composite
# Elemwise is ok # Elemwise is ok
if len(set(g.owner.inputs)) == len(g.owner.inputs): for g_ in g:
expected_len_sym_inputs = sum( if len(set(g_.owner.inputs)) == len(g_.owner.inputs):
not isinstance(x, Constant) for x in topo_[0].inputs expected_len_sym_inputs = sum(
) not isinstance(x, Constant) for x in topo_[0].inputs
assert expected_len_sym_inputs == len(sym_inputs) )
assert expected_len_sym_inputs == len(sym_inputs)
assert out_dtype == out.dtype for od, o in zip(out_dtype, out):
assert od == o.dtype
def test_fusion_35_inputs(self): def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
...@@ -1006,6 +1087,30 @@ class TestFusion: ...@@ -1006,6 +1087,30 @@ class TestFusion:
for node in dlogp.maker.fgraph.toposort() for node in dlogp.maker.fgraph.toposort()
) )
@pytest.mark.xfail(reason="Fails due to #1244")
def test_add_mul_fusion_precedence(self):
    """Check that chained additions and chained multiplications are each
    merged into a single variadic op before a `Composite` `Op` is built.

    The merging itself is the job of canonicalization; this test only
    inspects the inner graph of the resulting `Composite`.
    """
    x, y, z = vectors("x", "y", "z")
    numerator = x + y + z
    denominator = x * y * z
    compiled = pytensor.function(
        [x, y, z], log(numerator / denominator), mode=self.mode
    )

    # The whole expression is expected to collapse into one fused node
    apply_nodes = compiled.maker.fgraph.apply_nodes
    assert len(apply_nodes) == 1
    (fused_node,) = apply_nodes
    assert isinstance(fused_node.op, Elemwise)

    composite_op = fused_node.op.scalar_op
    assert isinstance(composite_op, Composite)

    inner_ops = [inner.op for inner in composite_op.fgraph.toposort()]
    assert inner_ops == [
        aes.mul,  # exactly one mul for x * y * z
        aes.add,  # exactly one add for x + y + z
        aes.true_div,
        aes.log,
    ]
def test_add_mul_fusion_inplace(self): def test_add_mul_fusion_inplace(self):
x, y, z = dmatrices("xyz") x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z out = dot(x, y) + x + y + z
...@@ -1082,11 +1187,8 @@ class TestFusion: ...@@ -1082,11 +1187,8 @@ class TestFusion:
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value): def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. """Make sure that `local_elemwise_fusion_op` uses test values correctly
when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
""" """
x, y, z = dmatrices("xyz") x, y, z = dmatrices("xyz")
...@@ -1094,27 +1196,20 @@ class TestFusion: ...@@ -1094,27 +1196,20 @@ class TestFusion:
y.tag.test_value = test_value y.tag.test_value = test_value
z.tag.test_value = test_value z.tag.test_value = test_value
if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()
with config.change_flags( with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise" compute_test_value="raise", compute_test_value_opt="raise"
): ):
out = x * y + z out = x * y + z
with cm: f = function([x, y, z], out, mode=self.mode)
f = function([x, y, z], out, mode=self.mode)
if test_value.size != 0: # Confirm that the fusion happened
# Confirm that the fusion happened assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) assert len(f.maker.fgraph.toposort()) == 1
assert len(f.maker.fgraph.toposort()) == 1
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs assert np.array_equal(
assert np.array_equal( f.maker.fgraph.outputs[0].tag.test_value,
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] np.full_like(test_value, 2.0),
) )
@pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)]) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
...@@ -1227,6 +1322,26 @@ class TestFusion: ...@@ -1227,6 +1322,26 @@ class TestFusion:
aes.mul, aes.mul,
} }
def test_multiple_outputs_fused_root_elemwise(self):
    """Check that a single-layer (root) Elemwise output is folded into the
    `Composite` when a second, fuseable output depends on it.
    """
    x = at.vector("x")
    cos_x = at.cos(x)

    # A lone Elemwise layer is compiled as-is: no Composite is introduced
    single_fn = pytensor.function([x], cos_x, mode=self.mode)
    single_nodes = tuple(single_fn.maker.fgraph.apply_nodes)
    assert len(single_nodes) == 1
    assert isinstance(single_nodes[0].op.scalar_op, aes.Cos)

    # Once another output can be composed with it, both outputs must come
    # from one node so that the root Elemwise is not computed twice
    log_cos_x = at.log(cos_x)
    fused_fn = pytensor.function([x], [cos_x, log_cos_x], mode=self.mode)
    fused_nodes = tuple(fused_fn.maker.fgraph.apply_nodes)
    assert len(fused_nodes) == 1
    assert isinstance(fused_nodes[0].op.scalar_op, Composite)
class TimesN(aes.basic.UnaryScalarOp): class TimesN(aes.basic.UnaryScalarOp):
""" """
......
...@@ -887,10 +887,9 @@ class TestLocalSubtensorLift: ...@@ -887,10 +887,9 @@ class TestLocalSubtensorLift:
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle) assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp} assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp}
assert prog[2].op == add or prog[3].op == add
# first subtensor # first subtensor
assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor) assert isinstance(prog[2].op, Subtensor)
assert len(prog) == 4 assert len(prog) == 3
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_7(self): def test_basic_7(self):
......
...@@ -273,8 +273,7 @@ def test_debugprint(): ...@@ -273,8 +273,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i0 + (i1 - i2))}} 4 Elemwise{Composite{(i2 + (i0 - i1))}} 4
|A
|InplaceDimShuffle{x,0} v={0: [0]} 3 |InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2 | |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1 | |AllocEmpty{dtype='float64'} 1
...@@ -285,6 +284,7 @@ def test_debugprint(): ...@@ -285,6 +284,7 @@ def test_debugprint():
| |<TensorType(float64, (?,))> | |<TensorType(float64, (?,))>
| |TensorConstant{0.0} | |TensorConstant{0.0}
|D |D
|A
""" """
).lstrip() ).lstrip()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论