Commit 9a5deee0 authored by Ricardo Vieira, committed by Ricardo Vieira

Don't run MergeOptimization in Composite.fgraph

This would trigger it for every Composite/ScalarLoop present in the C-cache
Parent: d9b494d2
......@@ -3998,6 +3998,42 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
def __init__(self, *args, **kwargs):
self.prepare_node_called = set()
def _cleanup_graph(self, inputs, outputs):
    """Validate, merge, and de-duplicate the inner scalar graph.

    Builds a ``FunctionGraph`` from shallow copies of *inputs*/*outputs*
    (so the caller's lists are not mutated), checks that every node is a
    ``ScalarOp``, runs ``MergeOptimizer`` to collapse duplicated nodes,
    and then re-clones any outputs that the merge made identical so each
    output is a distinct ``Variable`` object.

    Parameters
    ----------
    inputs, outputs
        Scalar variables defining the inner graph.

    Returns
    -------
    tuple
        ``(inputs, outputs)`` taken from the rewritten graph.

    Raises
    ------
    TypeError
        If any apply node in the graph is not a ``ScalarOp``.
    """
    # TODO: We could convert to TensorVariable, optimize graph,
    # and then convert back to ScalarVariable.
    # This would introduce rewrites like `log(1 + x) -> log1p`.
    fgraph = FunctionGraph(copy(inputs), copy(outputs))

    # Validate node types
    for node in fgraph.apply_nodes:
        if not isinstance(node.op, ScalarOp):
            raise TypeError(
                f"The fgraph of {self.__class__.__name__} must be exclusively "
                "composed of scalar operations."
            )

    # Run MergeOptimization to avoid duplicated nodes
    MergeOptimizer().rewrite(fgraph)
    inputs, outputs = fgraph.inputs, fgraph.outputs

    # Clone identical outputs that may have been merged
    # If fgraph.outputs = [out_A, out_B, out_A], then final outputs = [out_A, out_B, clone(out_A)]
    if len(set(fgraph.outputs)) != len(outputs):
        old_outputs = outputs
        outputs = []
        for old_output in old_outputs:
            if old_output not in outputs:
                outputs.append(old_output)
            else:
                # Rebuild the duplicate from a clone of its owner node so it
                # is a distinct Variable with the same definition.
                # NOTE(review): assumes a merged duplicate always has an
                # owner (i.e. is not itself a graph input) — TODO confirm.
                node = old_output.owner
                output_idx = node.outputs.index(old_output)
                output = node.clone().outputs[output_idx]
                outputs.append(output)
    return inputs, outputs
@property
def fn(self):
    # Inner-graph scalar ops expose no compiled function object; this
    # always returns None.
    return None
......@@ -4187,10 +4223,9 @@ class Composite(ScalarInnerGraphOp):
assert res[0] != inputs
inputs, outputs = res[0], res2[1]
self.inputs = copy(inputs)
self.outputs = copy(outputs)
self.inputs_type = tuple([input.type for input in inputs])
self.outputs_type = tuple([output.type for output in outputs])
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
self.inputs_type = tuple([input.type for input in self.inputs])
self.outputs_type = tuple([output.type for output in self.outputs])
self.nin = len(inputs)
self.nout = len(outputs)
super().__init__()
......@@ -4237,34 +4272,9 @@ class Composite(ScalarInnerGraphOp):
def fgraph(self):
    """Return the (lazily built, memoized) inner ``FunctionGraph``.

    The graph is constructed from ``self.inputs``/``self.outputs``, which
    were already validated and merged by ``_cleanup_graph`` at construction
    time, so no rewriting happens here.
    """
    if hasattr(self, "_fgraph"):
        return self._fgraph
    # fgraph cannot be a property of the base class because it messes up with C caching.
    # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
    fgraph = FunctionGraph(self.inputs, self.outputs)
    self._fgraph = fgraph
    return self._fgraph
......@@ -4389,7 +4399,7 @@ class Composite(ScalarInnerGraphOp):
return self.c_code_template % d
def c_code_cache_version_outer(self) -> Tuple[int, ...]:
    """Cache version for the outer C-code template.

    Bumped 3 -> 4 by this change; the stale ``return (3,)`` left above the
    new value made the bump unreachable and is removed.
    """
    return (4,)
class Compositef32:
......
from copy import copy
from itertools import chain
from typing import Optional, Sequence, Tuple, cast
from typing import Optional, Sequence, Tuple
from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone
from pytensor.graph.rewriting.basic import MergeOptimizer
from pytensor.scalar.basic import ScalarInnerGraphOp, ScalarOp, as_scalar
from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
class ScalarLoop(ScalarInnerGraphOp):
......@@ -62,44 +60,38 @@ class ScalarLoop(ScalarInnerGraphOp):
if not len(init) == len(update):
raise ValueError("An update must be given for each init variable")
if until:
inputs, (*outputs, until) = clone([*init, *constant], [*update, until])
self.outputs = copy([*outputs, until])
inputs, outputs = clone([*init, *constant], [*update, until])
else:
inputs, outputs = clone([*init, *constant], update)
self.outputs = copy(outputs)
self.inputs = copy(inputs)
self.is_while = bool(until)
self.inputs_type = tuple(input.type for input in inputs)
self.outputs_type = tuple(output.type for output in outputs)
if self.is_while:
self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
self.nin = len(inputs) + 1 # n_steps is not part of the inner graph
self.nout = len(outputs) + (1 if self.is_while else 0)
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
self._validate_updates(self.inputs, self.outputs)
self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(self.inputs) + 1 # n_steps is not part of the inner graph
self.nout = len(self.outputs)
self.name = name
self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
super().__init__()
def output_types(self, input_types):
    # Output types are fixed at construction (``self.outputs_type``);
    # ``input_types`` is intentionally ignored.
    return self.outputs_type
def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise TypeError(
"The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes"
)
init = fgraph.inputs
update = fgraph.outputs
def _validate_updates(
self, inputs: Sequence[Variable], outputs: Sequence[Variable]
) -> None:
init = inputs
update: Sequence[Variable]
if self.is_while:
*update, until = update
*update, until = outputs
if not until.type.dtype == "bool":
raise TypeError(
f"Until condition must be boolean, got {until}({until.type.dtype})"
)
else:
update = outputs
for i, u in zip(init, update):
if i.type != u.type:
raise TypeError(
......@@ -116,28 +108,9 @@ class ScalarLoop(ScalarInnerGraphOp):
def fgraph(self):
    """Return the (lazily built, memoized) inner ``FunctionGraph``.

    Built from ``self.inputs``/``self.outputs``, which were already
    cleaned up (validated and merged) at construction time, so no
    rewriting happens here.
    """
    if hasattr(self, "_fgraph"):
        return self._fgraph
    # fgraph cannot be a property of the base class because it messes up with C caching.
    # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
    fgraph = FunctionGraph(self.inputs, self.outputs)
    self._fgraph = fgraph
    return self._fgraph
......
......@@ -200,10 +200,11 @@ class TestComposite:
def test_non_scalar_error(self):
    """Composite must reject non-scalar (tensor) inner ops at construction
    time, not lazily when ``.fgraph`` is first accessed."""
    x = float32("x")
    with pytest.raises(
        TypeError,
        match="The fgraph of Composite must be exclusively composed of scalar operations",
    ):
        Composite([x], [(at.zeros((2,)) + x).sum()])
def test_multi_out_perform(self):
from pytensor.graph.basic import Apply
......
......@@ -151,7 +151,8 @@ def test_non_scalar_error():
x = as_scalar(tensor_exp(x0))
with pytest.raises(
TypeError, match="must be composed exclusively of ScalarOp nodes"
TypeError,
match="The fgraph of ScalarLoop must be exclusively composed of scalar operations",
):
ScalarLoop(init=[x0], constant=[], update=[x])
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment