提交 9ef575b7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid double cloning of Composite Ops created by FusionOptimizer

上级 30e19e53
......@@ -13,7 +13,6 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar!
import builtins
import math
from collections.abc import Callable
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Any, TypeAlias
......@@ -4093,12 +4092,12 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
self.prepare_node_called = set()
super().__init__(*args, **kwargs)
def _cleanup_graph(self, inputs, outputs):
def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
# TODO: We could convert to TensorVariable, optimize graph,
# and then convert back to ScalarVariable.
# This would introduce rewrites like `log(1 + x) -> log1p`.
fgraph = FunctionGraph(copy(inputs), copy(outputs))
fgraph = FunctionGraph(inputs, outputs, clone=clone)
# Validate node types
for node in fgraph.apply_nodes:
......@@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):
init_param: tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs, name="Composite"):
def __init__(
self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
):
self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already
......@@ -4299,10 +4300,13 @@ class Composite(ScalarInnerGraphOp):
if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs
):
# No inner Composite
if clone_graph:
inputs, outputs = clone(inputs, outputs)
else:
# Inner Composite that we need to flatten
# FIXME: There could be a composite in the middle of the graph, why is this here?
# If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
assert len(outputs) == 1
# 1. Create a new graph from inputs up to the
# Composite
......@@ -4321,7 +4325,8 @@ class Composite(ScalarInnerGraphOp):
assert res[0] != inputs
inputs, outputs = res[0], res2[1]
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
# We already cloned the graph, or the user told us there was no need for it
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs)
......
......@@ -915,12 +915,13 @@ class FusionOptimizer(GraphRewriter):
break
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))(
*inputs
)
if not isinstance(composite_outputs, list):
composite_outputs = [composite_outputs]
for old_out, composite_out in zip(outputs, composite_outputs, strict=True):
composite_outputs = Elemwise(
# No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
)(*inputs, return_list=True)
assert len(outputs) == len(composite_outputs)
for old_out, composite_out in zip(outputs, composite_outputs):
# Preserve any names on the original outputs
if old_out.name:
composite_out.name = old_out.name
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论