提交 0670ac2f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fuse consecutive Elemwise nodes with multiple clients

上级 d5cb23a5
......@@ -652,10 +652,10 @@ class Elemwise(OpenMPOp):
def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs.
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
# But our c code support more.
# - nfunc is reused for scipy and scipy is optional
if len(node.inputs) > 32 and self.ufunc and impl == "py":
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
impl = "c"
if getattr(self, "nfunc_spec", None) and impl != "c":
......@@ -677,7 +677,7 @@ class Elemwise(OpenMPOp):
self.nfunc = module
if (
len(node.inputs) < 32
(len(node.inputs) + len(node.outputs)) <= 32
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
and self.ufunc is None
and impl == "py"
......@@ -727,28 +727,18 @@ class Elemwise(OpenMPOp):
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32:
if (len(node.inputs) + len(node.outputs)) > 32:
# Some versions of NumPy will segfault, others will raise a
# ValueError, if the number of inputs to a ufunc is 32 or more.
# ValueError, if the number of operands in a ufunc is more than 32.
# In that case, the C version should be used, or Elemwise fusion
# should be disabled.
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
# Determine the shape of outputs
out_shape = []
for values in zip(*[input.shape for input in inputs]):
if any(v == 0 for v in values):
# All non-broadcasted dimensions should be zero
assert max(values) <= 1
out_shape.append(0)
else:
out_shape.append(max(values))
out_shape = tuple(out_shape)
ufunc_args = inputs
ufunc_kwargs = {}
# We supported in the past calling manually op.perform.
......
import sys
import time
from collections import defaultdict
from typing import Optional
from collections import defaultdict, deque
from functools import lru_cache
from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
from warnings import warn
import pytensor
import pytensor.scalar.basic as aes
from pytensor import compile
from pytensor import clone_replace, compile
from pytensor.compile.mode import get_target_language
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, io_toposort
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort
from pytensor.graph.features import ReplaceValidate
from pytensor.graph.op import compute_test_value, get_test_value
from pytensor.graph.fg import ApplyOrOutput
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
GraphRewriter,
......@@ -20,7 +21,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
......@@ -592,333 +593,438 @@ def local_add_mul_fusion(fgraph, node):
return [output]
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
r"""Create a recursive function that fuses `Elemwise` `Op`\s.
The basic idea is that we loop through an `Elemwise` node's inputs, find
other `Elemwise` nodes, determine the scalars input types for all of the
`Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types
and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
new "fused" `Elemwise`.
It's parameterized in order to work for `Elemwise` `Op`\s.
Parameters
----------
op_class : type
`Elemwise` class (the one that we want to fuse)
max_input_fct : callable
A function that returns the maximum number of inputs that this `Elemwise`
can take.
On the CPU we limit to 32 input variables since that is the maximum
NumPy support.
maker: callable
A function with the signature ``(node, *args)`` that constructs an
`op_class` instance (e.g. ``op_class(*args)``).
"""
if maker is None:
def maker(node, scalar_op):
return op_class(scalar_op)
def local_fuse(fgraph, node):
r"""Fuse `Elemwise` `Op`\s in a node.
As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the
same shape.
def elemwise_max_operands_fct(node) -> int:
# `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs)
if not config.cxx:
return 32
return 1024
For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
compiler do the cast.
The number of dimensions is validated at call time by PyTensor itself.
class FusionOptimizer(GraphRewriter):
"""Graph optimizer that fuses consecutive Elemwise operations."""
"""
# TODO: use broadcast flag?
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
# TODO: don't do this rewrite as a `NodeRewriter`.
# Analyze the graph in terms of elemwise subgraphs, and then
# replace each subgraph with a Composite version.
@staticmethod
def elemwise_to_scalar(inputs, outputs):
replace_inputs = [(inp, inp.clone()) for inp in inputs]
outputs = clone_replace(outputs, replace=replace_inputs)
# TODO: use malloc and copy to transfer arguments that don't
# fit within the parameter space of 256 bytes
#
# TODO: Merge with multiple output to merge when an inputs
# have multiple clients. This can't be done with a `NodeRewriter`
# TODO: Related: Support composites with multiple outputs
# TODO: Use Composite to combine Elemwise and Reduce
# operations. We have to loop over the data anyway... might
# as well sum it up while we're at it (this can be trickier
# than I'm making it sound here. The data traversal should be
# done contiguously, and the summing-up might not be easy or
# worthwhile if the summation axis doesn't line up with a
# contiguous dimension)
if type(node.op) is not op_class:
return False
if len(node.outputs) > 1:
# We don't support fusion for nodes with multiple outputs.
return
inputs = [] # inputs of the new Elemwise op.
s_inputs = [] # inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node.
s_g = []
# There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function.
max_nb_input = max_input_fct(node)
# The number of inputs to the new fused op if we do not fuse more
# inputs.
new_nb_input = len(node.inputs)
# Did we fuse something?
# Needed as we can fuse unary op that don't change the number of
# inputs.
# And there is a case where the inputs are the same as the current
# node. That won't change the number of inputs of the new op.
fused = False
for i in node.inputs:
scalar_node: Optional[Apply] = None
# Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs).
tmp_input = []
# Same as tmp_input, but for scalars.
tmp_scalar = []
# We should not check the number of inputs here,
# as fusing ops doesn't always change the number of inputs.
# If a variable is used multiple times as input to the same node,
# we still want to fuse. So we take the set.
if (
i.owner
and isinstance(i.owner.op, op_class)
and len({n for n, idx in fgraph.clients[i]}) == 1
and
# Do not merge elemwise that don't have the same
# broadcastable pattern to don't redo duplicate
# computation due to broadcast.
i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable
):
try:
tmp_s_input = []
# we should not put duplicate input into s_inputs and inputs
for ii in i.owner.inputs:
if ii in inputs:
tmp_s_input.append(s_inputs[inputs.index(ii)])
elif ii in tmp_input:
tmp_s_input.append(tmp_scalar[tmp_input.index(ii)])
else:
tmp = aes.get_scalar_type(ii.type.dtype).make_variable()
try:
tv = get_test_value(ii)
# Sometimes the original inputs have
# zero-valued shapes in some dimensions, which
# implies that this whole scalar thing doesn't
# make sense (i.e. we're asking for the scalar
# value of an entry in a zero-dimensional
# array).
# This will eventually lead to an error in the
# `compute_test_value` call below when/if
# `config.compute_test_value_opt` is enabled
# (for debugging, more or less)
tmp.tag.test_value = tv.item()
except (TestValueError, ValueError):
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
# Use the `Op.make_node` interface in case `Op.__call__`
# has been customized
scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input)
if config.compute_test_value_opt != "off":
# This is required because `Op.make_node` won't do it
compute_test_value(scalar_node)
# If the scalar_op doesn't have a C implementation, we skip
# its fusion to allow fusion of the other ops
i.owner.op.scalar_op.c_code(
scalar_node,
"test_presence_of_c_code",
["x" for x in i.owner.inputs],
["z" for z in i.owner.outputs],
{"fail": "%(fail)s"},
)
inputs = [inp for _, inp in replace_inputs]
fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False)
middle_inputs = []
except (NotImplementedError, MethodNotDefined):
warn(
"Rewrite warning: "
f"The Op {i.owner.op.scalar_op} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
scalar_inputs = [
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs
]
middle_scalar_inputs = []
for node in fg.toposort():
node_scalar_inputs = []
for inp in node.inputs:
if inp in inputs:
node_scalar_inputs.append(scalar_inputs[inputs.index(inp)])
elif inp in middle_inputs:
node_scalar_inputs.append(
middle_scalar_inputs[middle_inputs.index(inp)]
)
scalar_node = None
# Compute the number of inputs in case we fuse this input.
# We subtract 1 because we replace the existing input with the new
# inputs from `tmp_input`.
new_nb_input_ = new_nb_input + len(tmp_input) - 1
# If the new input is already an input of the current node, it was
# already counted when `new_nb_input` was initialized to
# len(node.inputs).
# This can happen when a variable is used both by the Elemwise to
# fuse and the current node.
for x in tmp_input:
if x in node.inputs:
new_nb_input_ -= 1
if scalar_node and (new_nb_input_ <= max_nb_input):
fused = True
new_nb_input = new_nb_input_
inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar)
s_g.extend(scalar_node.outputs)
else:
# We must support the case where the same variable appears many
# times within the inputs
if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)]
else:
s = aes.get_scalar_type(i.type.dtype).make_variable()
if config.compute_test_value_opt != "off":
try:
v = get_test_value(i)
# See the zero-dimensional test value situation
# described above.
s.tag.test_value = v.item()
except (TestValueError, ValueError):
pass
inputs.append(i)
s_inputs.append(s)
s_g.append(s)
if not fused:
return False
if new_nb_input != len(inputs) or len(s_inputs) != len(inputs):
# TODO FIXME: This shouldn't be a generic `Exception`
raise Exception(
"Something has gone wrong with the elemwise fusion rewrite; skipping."
)
s_new_out = node.op.scalar_op(*s_g, return_list=True)
try:
s_new_out[0].owner.op.c_code(
s_new_out[0].owner,
"test_presence_of_c_code",
["x" for x in s_g],
["z" for x in s_new_out],
{"fail": "%(fail)s"},
)
except (NotImplementedError, MethodNotDefined):
name = str(s_new_out[0].owner.op)
warn(
"Rewrite warning: "
f"The Op {name} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
)
return False
# create the composite op.
composite_op = aes.Composite(s_inputs, s_new_out)
# create the new node.
# Do not call make_node to have test_value
new_node = maker(node, composite_op)(*inputs).owner
assert len(new_node.outputs) == 1
assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype
new_scalar_input = aes.get_scalar_type(
inp.type.dtype
).make_variable()
node_scalar_inputs.append(new_scalar_input)
middle_scalar_inputs.append(new_scalar_input)
middle_inputs.append(inp)
new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs)
middle_scalar_inputs.append(new_scalar_node.outputs[0])
middle_inputs.append(node.outputs[0])
scalar_outputs = [
middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs
]
return scalar_inputs, scalar_outputs
if len(new_node.inputs) > max_nb_input:
warn(
"Loop fusion failed because the resulting node "
"would exceed the kernel argument limit."
)
return False
# we fuse as many as we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate steps.
while True:
ret = local_fuse(fgraph, new_node)
if ret is not False and ret is not None:
assert len(ret) == len(new_node.outputs)
assert len(ret) == 1
new_node = ret[0].owner
else:
break
def apply(self, fgraph):
nb_replacement = 0
return new_node.outputs
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
return local_fuse
max_operands = elemwise_max_operands_fct(None)
def find_next_fuseable_subgraph(
fg: FunctionGraph,
) -> Generator[Tuple[List[Variable], List[Variable]], None, None]:
"""Find all subgraphs in a FunctionGraph that can be fused together
Yields
-------
List of inputs and outputs that determine subgraphs which can be fused.
This generator assumes that such subgraph is replaced by a single
Elemwise Composite before being accessed again in the next iteration.
"""
FUSEABLE_MAPPING = DefaultDict[Variable, List[Apply]]
UNFUSEABLE_MAPPING = DefaultDict[Variable, Set[ApplyOrOutput]]
def initialize_fuseable_mappings(
*, fg: FunctionGraph
) -> Tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]:
@lru_cache(maxsize=None)
def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
# TODO: This should not play a role in non-c backends!
if node.op.scalar_op.supports_c_code(node.inputs, node.outputs):
return True
else:
warn(
"Optimization Warning: "
f"The Op {node.op.scalar_op} does not provide a C implementation."
" As well as being potentially slow, this also disables "
"loop fusion."
)
return False
# Fuseable nodes have to be accessed in a deterministic manner
# to ensure the rewrite remains deterministic.
# This is not a problem for unfuseable ones, as they can never
# become part of the fused subgraph.
fuseable_clients: FUSEABLE_MAPPING = defaultdict(list)
unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set)
for out, clients in fg.clients.items():
out_maybe_fuseable = (
out.owner
and isinstance(out.owner.op, Elemwise)
# and not isinstance(out.owner.op.scalar_op, aes.Composite)
and len(out.owner.outputs) == 1
and elemwise_scalar_op_has_c_code(out.owner)
)
for client, _ in clients:
if (
out_maybe_fuseable
and not isinstance(client, str) # "output"
and isinstance(client.op, Elemwise)
# and not isinstance(client.op.scalar_op, aes.Composite)
and len(client.outputs) == 1
and out.type.broadcastable
== client.outputs[0].type.broadcastable
and elemwise_scalar_op_has_c_code(client)
):
if client not in fuseable_clients[out]:
fuseable_clients[out].append(client)
else:
unfuseable_clients[out].add(client)
return fuseable_clients, unfuseable_clients
def find_fuseable_subgraph(
*,
fg: FunctionGraph,
visited_nodes: Set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
) -> Tuple[List[Variable], List[Variable]]:
KT = TypeVar("KT")
VT = TypeVar("VT", list, set)
def shallow_clone_defaultdict(
d: DefaultDict[KT, VT]
) -> DefaultDict[KT, VT]:
new_dict: DefaultDict[KT, VT] = defaultdict(d.default_factory)
new_dict.update({k: v.copy() for k, v in d.items()})
return new_dict
def variables_depend_on(
variables, depend_on, stop_search_at=None
) -> bool:
return any(
a in depend_on
for a in ancestors(variables, blockers=stop_search_at)
)
toposort = fg.toposort()
for starting_node in toposort:
if starting_node in visited_nodes:
continue
def elemwise_max_input_fct(node):
# `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs.
if not config.cxx:
return 31
return 1024
starting_out = starting_node.outputs[0]
if not fuseable_clients.get(starting_out):
visited_nodes.add(starting_node)
continue
subgraph_inputs: List[Variable] = []
subgraph_outputs: List[Variable] = []
unfuseable_clients_subgraph: Set[Variable] = set()
local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct)
# Shallow cloning of maps so that they can be manipulated in place
fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients)
unfuseable_clients_clone = shallow_clone_defaultdict(
unfuseable_clients
)
fuseable_nodes_to_visit = deque([starting_node])
# We now try to expand as much as possible towards the potentially
# fuseable clients and ancestors to detect the largest possible
# subgraph that can be Composed together into a single `Op`. The
# largest issue to watch out is for cyclical dependencies, where
# some inputs or clients may depend on other nodes of the same
# subgraph via a path that cannot be included in the Composite
# (unfuseable)
while fuseable_nodes_to_visit:
next_node = fuseable_nodes_to_visit.popleft()
visited_nodes.add(next_node)
next_out = next_node.outputs[0]
# If the output variable of next_node has no fuseable clients
# or has unfuseable clients, then next_node must become an output
# if it is to be fused.
must_become_output = (
next_out not in fuseable_clients_temp
or next_out in unfuseable_clients_clone
)
class FusionOptimizer(GraphRewriter):
"""Graph rewriter that simply runs node fusion operations.
# We have backtracked to this node, and it may no longer be a viable output,
# so we remove it and check again as if we had never seen this node
if must_become_output and next_out in subgraph_outputs:
subgraph_outputs.remove(next_out)
required_unfuseable_inputs = [
inp
for inp in next_node.inputs
if next_node in unfuseable_clients_clone.get(inp, ())
]
new_required_unfuseable_inputs = [
inp
for inp in required_unfuseable_inputs
if inp not in subgraph_inputs
]
must_backtrack = False
if new_required_unfuseable_inputs and subgraph_outputs:
# We need to check that any new inputs required by this node
# do not depend on other outputs of the current subgraph,
# via an unfuseable path.
if variables_depend_on(
[next_out],
depend_on=unfuseable_clients_subgraph,
stop_search_at=subgraph_outputs,
):
must_backtrack = True
if not must_backtrack:
implied_unfuseable_clients = {
c
for client in unfuseable_clients_clone.get(next_out, ())
if not isinstance(client, str) # "output"
for c in client.outputs
}
new_implied_unfuseable_clients = (
implied_unfuseable_clients - unfuseable_clients_subgraph
)
TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that.
if new_implied_unfuseable_clients and subgraph_inputs:
# We need to check that any inputs of the current subgraph
# do not depend on other clients of this node,
# via an unfuseable path.
if variables_depend_on(
subgraph_inputs,
depend_on=new_implied_unfuseable_clients,
):
must_backtrack = True
if must_backtrack:
for inp in next_node.inputs:
if (
inp.owner in visited_nodes
# next_node could have the same input repeated
and next_node in fuseable_clients_temp[inp]
):
fuseable_clients_temp[inp].remove(next_node)
unfuseable_clients_clone[inp].add(next_node)
# This input must become an output of the subgraph,
# because it can't be merged with next_node.
# We will revisit it to make sure this is safe.
fuseable_nodes_to_visit.appendleft(inp.owner)
for client in fuseable_clients_temp[next_out]:
if client in visited_nodes:
fuseable_clients_temp[next_out].remove(client)
unfuseable_clients_clone[next_out].add(client)
# next_out must become an input of the subgraph.
# We will revisit any of its clients currently
# in the subgraph to make sure this is safe.
fuseable_nodes_to_visit.appendleft(client)
# Revisit node at a later time
visited_nodes.remove(next_node)
continue
# Adding next_node to the subgraph does not result in any
# immediate dependency problems. Update the subgraph
# mappings as if next_node were part of it.
# Useless inputs will be removed by the useless Composite rewrite
for inp in new_required_unfuseable_inputs:
if inp not in subgraph_inputs:
subgraph_inputs.append(inp)
if must_become_output:
subgraph_outputs.append(next_out)
unfuseable_clients_subgraph.update(
new_implied_unfuseable_clients
)
"""
# Expand through unvisited fuseable ancestors
for inp in sorted(
(
inp
for inp in next_node.inputs
if (
inp not in required_unfuseable_inputs
and inp.owner not in visited_nodes
)
),
key=lambda inp: toposort.index(inp.owner),
reverse=True,
):
fuseable_nodes_to_visit.appendleft(inp.owner)
# Expand through unvisited fuseable clients
for next_node in sorted(
(
node
for node in fuseable_clients_temp.get(next_out, ())
if node not in visited_nodes
),
key=lambda node: toposort.index(node),
):
fuseable_nodes_to_visit.append(next_node)
# Don't return if final subgraph is just the original Elemwise
if len(subgraph_outputs) == 1 and set(
subgraph_outputs[0].owner.inputs
) == set(subgraph_inputs):
# Update global fuseable mappings
# No input was actually fuseable
for inp in starting_node.inputs:
if starting_node in fuseable_clients.get(inp, ()):
fuseable_clients[inp].remove(starting_node)
unfuseable_clients[inp].add(starting_node)
# No client was actually fuseable
unfuseable_clients[starting_out].update(
fuseable_clients.pop(starting_out, ())
)
continue
def __init__(self, node_rewriter):
super().__init__()
self.node_rewriter = node_rewriter
return subgraph_inputs, subgraph_outputs
raise ValueError
def update_fuseable_mappings_after_fg_replace(
*,
fg: FunctionGraph,
visited_nodes: Set[Apply],
fuseable_clients: FUSEABLE_MAPPING,
unfuseable_clients: UNFUSEABLE_MAPPING,
starting_nodes: Set[Apply],
) -> None:
# Find new composite node and dropped intermediate nodes
# by comparing the current fg.apply nodes with the cached
# original nodes
next_nodes = fg.apply_nodes
(new_composite_node,) = next_nodes - starting_nodes
dropped_nodes = starting_nodes - next_nodes
# Remove intermediate Composite nodes from mappings
for dropped_node in dropped_nodes:
(dropped_out,) = dropped_node.outputs
fuseable_clients.pop(dropped_out, None)
unfuseable_clients.pop(dropped_out, None)
visited_nodes.remove(dropped_node)
# Update fuseable information for subgraph inputs
for inp in subgraph_inputs:
if inp in fuseable_clients:
new_fuseable_clients = [
client
for client in fuseable_clients[inp]
if client not in dropped_nodes
]
if new_fuseable_clients:
fuseable_clients[inp] = new_fuseable_clients
else:
fuseable_clients.pop(inp)
unfuseable_clients[inp] = (
unfuseable_clients[inp] - dropped_nodes
) | {new_composite_node}
# Update fuseable information for subgraph outputs
for out in new_composite_node.outputs:
unfuseable_clients[out] = {client for client, _ in fg.clients[out]}
visited_nodes.add(new_composite_node)
return
# We start by creating two maps, 1) from each node to each potentially
# fuseable client (both nodes must be single output Elemwise with same
# broadcast type) and 2) from each node to each certainly unfuseable
# client (those that don't fit into 1))
fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg)
visited_nodes: Set[Apply] = set()
while True:
starting_nodes = fg.apply_nodes.copy()
try:
subgraph_inputs, subgraph_outputs = find_fuseable_subgraph(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
)
except ValueError:
return
else:
# The caller is now expected to update fg in place,
# by replacing the subgraph with a Composite Op
yield subgraph_inputs, subgraph_outputs
# This is where we avoid repeated work by using a stateful
# generator. For large models (as in `TestFusion.test_big_fusion`)
# this can provide huge speedups
update_fuseable_mappings_after_fg_replace(
fg=fg,
visited_nodes=visited_nodes,
fuseable_clients=fuseable_clients,
unfuseable_clients=unfuseable_clients,
starting_nodes=starting_nodes,
)
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
for inputs, outputs in find_next_fuseable_subgraph(fgraph):
if (len(inputs) + len(outputs)) > max_operands:
warn(
"Loop fusion failed because the resulting node would exceed "
"the kernel argument limit."
)
break
def apply(self, fgraph):
did_something = True
nb_iter = 0
nb_replacement = 0
nb_inconsistency_replace = 0
time_toposort = 0
if fgraph.profile:
validate_before = fgraph.profile.validate_time
callbacks_before = fgraph.execute_callbacks_times.copy()
callback_before = fgraph.execute_callbacks_time
while did_something:
t0 = time.perf_counter()
nodelist = list(fgraph.toposort())
time_toposort += time.perf_counter() - t0
nodelist.reverse()
did_something = False
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.apply_nodes:
new_outputs = self.node_rewriter(fgraph, node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
try:
fgraph.replace_all_validate(
list(zip(node.outputs, new_outputs)),
reason=self.__class__.__name__,
)
did_something = True
nb_replacement += 1
except InconsistencyError:
nb_inconsistency_replace += 1
nb_iter += 1
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(aes.Composite(scalar_inputs, scalar_outputs))(
*inputs
)
if not isinstance(composite_outputs, list):
composite_outputs = [composite_outputs]
for old_out, composite_out in zip(outputs, composite_outputs):
if old_out.name:
composite_out.name = old_out.name
fgraph.replace_all_validate(
list(zip(outputs, composite_outputs)),
reason=self.__class__.__name__,
)
nb_replacement += 1
if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before
......@@ -933,21 +1039,22 @@ class FusionOptimizer(GraphRewriter):
validate_time = None
callback_time = None
callbacks_time = {}
return (
self,
nb_iter,
1, # nb_iter
nb_replacement,
nb_inconsistency_replace,
0, # nb_inconsintency_replace
validate_time,
callback_time,
callbacks_time,
time_toposort,
-1, # toposort_time
)
@classmethod
def print_profile(cls, stream, prof, level=0):
@staticmethod
def print_profile(stream, prof, level=0):
blanc = " " * level
print(blanc, cls.__name__, file=stream)
print(blanc, "FusionOptimizer", file=stream)
print(blanc, " nb_iter", prof[1], file=stream)
print(blanc, " nb_replacement", prof[2], file=stream)
print(blanc, " nb_inconsistency_replace", prof[3], file=stream)
......@@ -973,7 +1080,7 @@ if config.tensor__local_elemwise_fusion:
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(local_elemwise_fusion),
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
......@@ -999,7 +1106,9 @@ def local_useless_composite(fgraph, node):
):
return
comp = node.op.scalar_op
used_outputs_idxs = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]]
used_outputs_idxs = [
i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]
]
used_inner_outputs = [comp.outputs[i] for i in used_outputs_idxs]
comp_fgraph = FunctionGraph(
inputs=comp.inputs, outputs=used_inner_outputs, clone=False
......
......@@ -27,7 +27,6 @@ pytensor/tensor/random/basic.py
pytensor/tensor/random/op.py
pytensor/tensor/random/utils.py
pytensor/tensor/rewriting/basic.py
pytensor/tensor/rewriting/elemwise.py
pytensor/tensor/shape.py
pytensor/tensor/slinalg.py
pytensor/tensor/subtensor.py
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import pytensor.tensor as at
from pytensor.compile import UnusedInputError
from pytensor.compile import UnusedInputError, get_mode
from pytensor.compile.function import function, pfunc
from pytensor.compile.function.pfunc import rebuild_collect_shared
from pytensor.compile.io import In
......@@ -200,7 +200,12 @@ class TestPfunc:
bval = np.arange(5)
b.set_value(bval, borrow=True)
bval = data_of(b)
f = pfunc([], [b_out], updates=[(b, (b_out + 3))], mode="FAST_RUN")
f = pfunc(
[],
[b_out],
updates=[(b, (b_out + 3))],
mode=get_mode("FAST_RUN").excluding("fusion"),
)
assert (f() == (np.arange(5) * 2)).all()
# because of the update
assert (b.get_value(borrow=True) == ((np.arange(5) * 2) + 3)).all()
......
import contextlib
import numpy as np
import pytest
......@@ -17,11 +15,14 @@ from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import Composite
from pytensor.tensor.basic import MakeVector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import (
add,
bitwise_and,
bitwise_or,
cos,
......@@ -29,6 +30,7 @@ from pytensor.tensor.math import (
dot,
eq,
exp,
ge,
int_div,
invert,
iround,
......@@ -900,6 +902,72 @@ class TestFusion:
fxv * np.sin(fsv),
"float32",
),
# Multiple output cases # 72
(
(
# sum(logp)
at_sum(-((fx - fy) ** 2) / 2),
# grad(logp)
at.grad(at_sum(-((fx - fy) ** 2) / 2), wrt=fx),
),
(fx, fy),
(fxv, fyv),
3,
(
np.sum(-((fxv - fyv) ** 2) / 2),
-(fxv - fyv),
),
("float32", "float32"),
),
# Two Composite graphs that share the same input, but are split by
# a non-elemwise operation (Assert)
(
(
log(
ge(
assert_op(
at_abs(fx),
at_all(ge(at_abs(fx), 0)),
),
0,
)
),
),
(fx,),
(fxv,),
4,
(np.zeros_like(fxv),),
("float32",),
),
# Two subgraphs that share the same non-fuseable input, but are otherwise
# completely independent
(
(
true_div(
mul(
at_sum(fx + 5), # breaks fusion
exp(fx),
),
(fx + 5),
),
),
(fx,),
(fxv,),
4,
(np.sum(fxv + 5) * np.exp(fxv) / (fxv + 5),),
("float32",),
),
pytest.param(
(
(sin(exp(fx)), exp(sin(fx))),
(fx,),
(fxv,),
1,
(np.sin(np.exp(fxv)), np.exp(np.sin(fxv))),
("float32", "float32"),
),
marks=pytest.mark.xfail, # Not implemented yet
),
],
)
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
......@@ -910,23 +978,34 @@ class TestFusion:
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
if not isinstance(g, (tuple, list)):
g = (g,)
answer = (answer,)
out_dtype = (out_dtype,)
if self._shared is None:
f = function(list(sym_inputs), g, mode=self.mode)
for x in range(nb_repeat):
out = f(*val_inputs)
if not isinstance(out, list):
out = (out,)
else:
out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out")
assert out.dtype == g.dtype
f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode)
out = [
self._shared(np.zeros((5,) * g_.ndim, dtype=od), "out")
for g_, od in zip(g, out_dtype)
]
assert all(o.dtype == g_.dtype for o, g_ in zip(out, g))
f = function(sym_inputs, [], updates=list(zip(out, g)), mode=self.mode)
for x in range(nb_repeat):
f(*val_inputs)
out = out.get_value()
out = [o.get_value() for o in out]
atol = 1e-8
if out_dtype == "float32":
if any(o == "float32" for o in out_dtype):
atol = 1e-6
assert np.allclose(out, answer * nb_repeat, atol=atol)
for o, a in zip(out, answer):
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)
topo = f.maker.fgraph.toposort()
topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)]
......@@ -939,13 +1018,15 @@ class TestFusion:
# input of g,
# check that the number of input to the Composite
# Elemwise is ok
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = sum(
not isinstance(x, Constant) for x in topo_[0].inputs
)
assert expected_len_sym_inputs == len(sym_inputs)
for g_ in g:
if len(set(g_.owner.inputs)) == len(g_.owner.inputs):
expected_len_sym_inputs = sum(
not isinstance(x, Constant) for x in topo_[0].inputs
)
assert expected_len_sym_inputs == len(sym_inputs)
assert out_dtype == out.dtype
for od, o in zip(out_dtype, out):
assert od == o.dtype
def test_fusion_35_inputs(self):
r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit."""
......@@ -1006,6 +1087,30 @@ class TestFusion:
for node in dlogp.maker.fgraph.toposort()
)
@pytest.mark.xfail(reason="Fails due to #1244")
def test_add_mul_fusion_precedence(self):
    """Check that canonicalization collapses chained additions and
    multiplications into single `add`/`mul` nodes before a `Composite`
    `Op` is introduced.
    """
    x, y, z = vectors("x", "y", "z")
    graph_out = log((x + y + z) / (x * y * z))
    fn = pytensor.function([x, y, z], graph_out, mode=self.mode)

    # The whole expression should collapse into a single Composite Elemwise.
    apply_nodes = fn.maker.fgraph.apply_nodes
    assert len(apply_nodes) == 1
    (only_node,) = apply_nodes
    assert isinstance(only_node.op, Elemwise)
    composite_op = only_node.op.scalar_op
    assert isinstance(composite_op, Composite)

    # Canonicalization should have merged the repeated adds and muls, so
    # the inner graph holds exactly one mul and one add before div/log.
    inner_ops = [inner.op for inner in composite_op.fgraph.toposort()]
    assert inner_ops == [aes.mul, aes.add, aes.true_div, aes.log]
def test_add_mul_fusion_inplace(self):
x, y, z = dmatrices("xyz")
out = dot(x, y) + x + y + z
......@@ -1082,11 +1187,8 @@ class TestFusion:
@pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]])
def test_test_values(self, test_value):
"""Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions.
The test values we're talking about are the ones used when C implementations
are checked.
"""Make sure that `local_elemwise_fusion_op` uses test values correctly
when they have zero dimensions.
"""
x, y, z = dmatrices("xyz")
......@@ -1094,27 +1196,20 @@ class TestFusion:
y.tag.test_value = test_value
z.tag.test_value = test_value
if test_value.size == 0:
cm = pytest.raises(ValueError)
else:
cm = contextlib.suppress()
with config.change_flags(
compute_test_value="raise", compute_test_value_opt="raise"
):
out = x * y + z
with cm:
f = function([x, y, z], out, mode=self.mode)
f = function([x, y, z], out, mode=self.mode)
if test_value.size != 0:
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
# Confirm that the fusion happened
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite)
assert len(f.maker.fgraph.toposort()) == 1
x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]]
)
assert np.array_equal(
f.maker.fgraph.outputs[0].tag.test_value,
np.full_like(test_value, 2.0),
)
@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
......@@ -1227,6 +1322,26 @@ class TestFusion:
aes.mul,
}
def test_multiple_outputs_fused_root_elemwise(self):
    """A root Elemwise that is itself an output should be folded into the
    Composite of a dependent fused output rather than recomputed.
    """
    x = at.vector("x")
    root = at.cos(x)

    # A lone single-layer Elemwise is left as-is: no Composite is created
    # by default for a single Elemwise layer.
    fn = pytensor.function([x], root, mode=self.mode)
    applies = tuple(fn.maker.fgraph.apply_nodes)
    assert len(applies) == 1
    assert isinstance(applies[0].op.scalar_op, aes.Cos)

    # Once a second output depends on it, both outputs should share one
    # Composite so the root Elemwise is not evaluated twice.
    dependent = at.log(root)
    fn = pytensor.function([x], [root, dependent], mode=self.mode)
    applies = tuple(fn.maker.fgraph.apply_nodes)
    assert len(applies) == 1
    assert isinstance(applies[0].op.scalar_op, Composite)
class TimesN(aes.basic.UnaryScalarOp):
"""
......
......@@ -887,10 +887,9 @@ class TestLocalSubtensorLift:
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op.scalar_op, aes.Composite) # Composite{add,exp}
assert prog[2].op == add or prog[3].op == add
# first subtensor
assert isinstance(prog[2].op, Subtensor) or isinstance(prog[3].op, Subtensor)
assert len(prog) == 4
assert isinstance(prog[2].op, Subtensor)
assert len(prog) == 3
f([[0, 1], [2, 3]], [4, 5]) # let debugmode test something
def test_basic_7(self):
......
......@@ -273,8 +273,7 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
Elemwise{Composite{(i0 + (i1 - i2))}} 4
|A
Elemwise{Composite{(i2 + (i0 - i1))}} 4
|InplaceDimShuffle{x,0} v={0: [0]} 3
| |CGemv{inplace} d={0: [0]} 2
| |AllocEmpty{dtype='float64'} 1
......@@ -285,6 +284,7 @@ def test_debugprint():
| |<TensorType(float64, (?,))>
| |TensorConstant{0.0}
|D
|A
"""
).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论