Commit 9a5deee0 authored by Ricardo Vieira, committed by Ricardo Vieira

Don't run MergeOptimization in Composite.fgraph

This would trigger it for every Composite/ScalarLoop present in the C-cache
Parent d9b494d2
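For context, a minimal usage sketch (our own example; `x`, `y`, and `comp` are not from the commit, and the public `pytensor.scalar.basic` API is assumed): the inner-graph merge now happens once, at construction time, so merely accessing `.fgraph`, which happens whenever a cached op is reloaded from the C-cache, no longer runs `MergeOptimizer`.

from pytensor.scalar.basic import Composite, float32

x, y = float32("x"), float32("y")
s = x + y
# The duplicated (x + y) subgraphs are merged once, in __init__
# (via _cleanup_graph), rather than on every .fgraph access.
comp = Composite([x, y], [s * 2, (x + y) * 2])
fg = comp.fgraph  # now only wraps self.inputs/self.outputs, no rewriting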
@@ -3998,6 +3998,42 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
     def __init__(self, *args, **kwargs):
         self.prepare_node_called = set()
 
+    def _cleanup_graph(self, inputs, outputs):
+        # TODO: We could convert to TensorVariable, optimize graph,
+        # and then convert back to ScalarVariable.
+        # This would introduce rewrites like `log(1 + x) -> log1p`.
+        fgraph = FunctionGraph(copy(inputs), copy(outputs))
+
+        # Validate node types
+        for node in fgraph.apply_nodes:
+            if not isinstance(node.op, ScalarOp):
+                raise TypeError(
+                    f"The fgraph of {self.__class__.__name__} must be exclusively "
+                    "composed of scalar operations."
+                )
+
+        # Run MergeOptimization to avoid duplicated nodes
+        MergeOptimizer().rewrite(fgraph)
+        inputs, outputs = fgraph.inputs, fgraph.outputs
+
+        # Clone identical outputs that may have been merged
+        # If fgraph.outputs = [out_A, out_B, out_A], then final outputs = [out_A, out_B, clone(out_A)]
+        if len(set(fgraph.outputs)) != len(outputs):
+            old_outputs = outputs
+            outputs = []
+            for old_output in old_outputs:
+                if old_output not in outputs:
+                    outputs.append(old_output)
+                else:
+                    node = old_output.owner
+                    output_idx = node.outputs.index(old_output)
+                    output = node.clone().outputs[output_idx]
+                    outputs.append(output)
+
+        return inputs, outputs
+
     @property
     def fn(self):
         return None
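The cloning branch above exists because `MergeOptimizer` can collapse two equivalent outputs into the very same `Variable` object, while the Op still needs `len(outputs)` distinct entries. A hedged sketch of the trick in isolation (variable names are ours):

from pytensor.scalar.basic import float32

x, y = float32("x"), float32("y")
out = x + y
# Re-create a distinct Variable computing the same thing, exactly as the
# duplicate-output branch of _cleanup_graph does:
node = out.owner
idx = node.outputs.index(out)
out_clone = node.clone().outputs[idx]
assert out_clone is not out and out_clone.type == out.type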
@@ -4187,10 +4223,9 @@ class Composite(ScalarInnerGraphOp):
         assert res[0] != inputs
         inputs, outputs = res[0], res2[1]
 
-        self.inputs = copy(inputs)
-        self.outputs = copy(outputs)
-        self.inputs_type = tuple([input.type for input in inputs])
-        self.outputs_type = tuple([output.type for output in outputs])
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        self.inputs_type = tuple([input.type for input in self.inputs])
+        self.outputs_type = tuple([output.type for output in self.outputs])
         self.nin = len(inputs)
         self.nout = len(outputs)
         super().__init__()
@@ -4237,34 +4272,9 @@ class Composite(ScalarInnerGraphOp):
     def fgraph(self):
         if hasattr(self, "_fgraph"):
             return self._fgraph
-        # The clone done by FunctionGraph is needed as we don't want
-        # the fgraph to be set to the variable as we need to pickle
-        # them for the cache of c module to work.
+        # fgraph cannot be a property of the base class because it messes up with C caching.
+        # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
         fgraph = FunctionGraph(self.inputs, self.outputs)
-        with config.change_flags(optimizer_verbose=False):
-            MergeOptimizer().rewrite(fgraph)
-        for node in fgraph.apply_nodes:
-            if not isinstance(node.op, ScalarOp):
-                raise TypeError(
-                    "The fgraph to Composite must be exclusively"
-                    " composed of ScalarOp instances."
-                )
-        # Clone identical outputs that have been merged
-        if len(set(fgraph.outputs)) != len(self.outputs):
-            old_outputs = fgraph.outputs
-            new_outputs = []
-            for output in old_outputs:
-                if output not in new_outputs:
-                    new_outputs.append(output)
-                else:
-                    node = output.owner
-                    output_idx = node.outputs.index(output)
-                    new_output = node.clone().outputs[output_idx]
-                    new_outputs.append(new_output)
-            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
         self._fgraph = fgraph
         return self._fgraph
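Two properties of the simplified getter, sketched with the hypothetical `comp` from the earlier example: the result is memoized, and the default `clone=True` of `FunctionGraph` keeps the op's own variables free of fgraph back-references, which the C-module cache relies on for pickling.

fg = comp.fgraph
assert comp.fgraph is fg                   # cached after the first access
assert fg.inputs[0] is not comp.inputs[0]  # clone=True left the originals untouched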
@@ -4389,7 +4399,7 @@ class Composite(ScalarInnerGraphOp):
         return self.c_code_template % d
 
     def c_code_cache_version_outer(self) -> Tuple[int, ...]:
-        return (3,)
+        return (4,)
 
 class Compositef32:
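The version bump matters because this tuple participates in the key under which compiled C modules are cached, so modules compiled before `_cleanup_graph` existed are not reused. A purely illustrative sketch (the helper below is hypothetical, not PyTensor's actual key-building code):

def module_cache_key(op) -> tuple:
    # Hypothetical: bumping the outer version (3 -> 4) changes the key,
    # forcing previously cached Composite C code to be recompiled.
    return (type(op).__name__, op.c_code_cache_version_outer())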
...
-from copy import copy
 from itertools import chain
-from typing import Optional, Sequence, Tuple, cast
+from typing import Optional, Sequence, Tuple
 
 from pytensor.compile import rebuild_collect_shared
 from pytensor.graph import Constant, FunctionGraph, Variable, clone
-from pytensor.graph.rewriting.basic import MergeOptimizer
-from pytensor.scalar.basic import ScalarInnerGraphOp, ScalarOp, as_scalar
+from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
 
 class ScalarLoop(ScalarInnerGraphOp):
@@ -62,44 +60,38 @@ class ScalarLoop(ScalarInnerGraphOp):
         if not len(init) == len(update):
             raise ValueError("An update must be given for each init variable")
         if until:
-            inputs, (*outputs, until) = clone([*init, *constant], [*update, until])
-            self.outputs = copy([*outputs, until])
+            inputs, outputs = clone([*init, *constant], [*update, until])
         else:
             inputs, outputs = clone([*init, *constant], update)
-            self.outputs = copy(outputs)
-        self.inputs = copy(inputs)
 
         self.is_while = bool(until)
-        self.inputs_type = tuple(input.type for input in inputs)
-        self.outputs_type = tuple(output.type for output in outputs)
-        if self.is_while:
-            self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
-        self.nin = len(inputs) + 1  # n_steps is not part of the inner graph
-        self.nout = len(outputs) + (1 if self.is_while else 0)
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        self._validate_updates(self.inputs, self.outputs)
+
+        self.inputs_type = tuple(input.type for input in self.inputs)
+        self.outputs_type = tuple(output.type for output in self.outputs)
+        self.nin = len(self.inputs) + 1  # n_steps is not part of the inner graph
+        self.nout = len(self.outputs)
         self.name = name
-        self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
         super().__init__()
 
     def output_types(self, input_types):
         return self.outputs_type
 
-    def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
-        for node in fgraph.apply_nodes:
-            if not isinstance(node.op, ScalarOp):
-                raise TypeError(
-                    "The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes"
-                )
-
-        init = fgraph.inputs
-        update = fgraph.outputs
-
+    def _validate_updates(
+        self, inputs: Sequence[Variable], outputs: Sequence[Variable]
+    ) -> None:
+        init = inputs
+        update: Sequence[Variable]
         if self.is_while:
-            *update, until = update
+            *update, until = outputs
             if not until.type.dtype == "bool":
                 raise TypeError(
                     f"Until condition must be boolean, got {until}({until.type.dtype})"
                 )
-
+        else:
+            update = outputs
         for i, u in zip(init, update):
             if i.type != u.type:
                 raise TypeError(
@@ -116,28 +108,9 @@ class ScalarLoop(ScalarInnerGraphOp):
     def fgraph(self):
         if hasattr(self, "_fgraph"):
             return self._fgraph
+        # fgraph cannot be a property of the base class because it messes up with C caching.
+        # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
         fgraph = FunctionGraph(self.inputs, self.outputs)
-        # TODO: We could convert to TensorVariable, optimize graph,
-        # and then convert back to ScalarVariable.
-        # This would introduce rewrites like `log(1 + x) -> log1p`.
-        MergeOptimizer().rewrite(fgraph)
-        self._validate_fgraph(fgraph)
-
-        # Clone identical outputs that have been merged
-        if len(set(fgraph.outputs)) != len(self.outputs):
-            old_outputs = fgraph.outputs
-            new_outputs = []
-            for output in old_outputs:
-                if output not in new_outputs:
-                    new_outputs.append(output)
-                else:
-                    node = output.owner
-                    output_idx = node.outputs.index(output)
-                    new_output = node.clone().outputs[output_idx]
-                    new_outputs.append(new_output)
-            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
         self._fgraph = fgraph
         return self._fgraph
...
@@ -200,10 +200,11 @@ class TestComposite:
     def test_non_scalar_error(self):
         x = float32("x")
 
-        comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
-
-        with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
-            comp_op.fgraph
+        with pytest.raises(
+            TypeError,
+            match="The fgraph of Composite must be exclusively composed of scalar operations",
+        ):
+            Composite([x], [(at.zeros((2,)) + x).sum()])
 
     def test_multi_out_perform(self):
         from pytensor.graph.basic import Apply
...
@@ -151,7 +151,8 @@ def test_non_scalar_error():
     x = as_scalar(tensor_exp(x0))
 
     with pytest.raises(
-        TypeError, match="must be composed exclusively of ScalarOp nodes"
+        TypeError,
+        match="The fgraph of ScalarLoop must be exclusively composed of scalar operations",
     ):
         ScalarLoop(init=[x0], constant=[], update=[x])
...