Commit 9ef575b7, authored by Ricardo Vieira, committed by Ricardo Vieira

Avoid double cloning of Composite Ops created by FusionOptimizer

Parent commit: 30e19e53
...@@ -13,7 +13,6 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar! ...@@ -13,7 +13,6 @@ you probably want to use pytensor.tensor.[c,z,f,d,b,w,i,l,]scalar!
import builtins import builtins
import math import math
from collections.abc import Callable from collections.abc import Callable
from copy import copy
from itertools import chain from itertools import chain
from textwrap import dedent from textwrap import dedent
from typing import Any, TypeAlias from typing import Any, TypeAlias
...@@ -4093,12 +4092,12 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph): ...@@ -4093,12 +4092,12 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
self.prepare_node_called = set() self.prepare_node_called = set()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _cleanup_graph(self, inputs, outputs): def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True):
# TODO: We could convert to TensorVariable, optimize graph, # TODO: We could convert to TensorVariable, optimize graph,
# and then convert back to ScalarVariable. # and then convert back to ScalarVariable.
# This would introduce rewrites like `log(1 + x) -> log1p`. # This would introduce rewrites like `log(1 + x) -> log1p`.
fgraph = FunctionGraph(copy(inputs), copy(outputs)) fgraph = FunctionGraph(inputs, outputs, clone=clone)
# Validate node types # Validate node types
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
...@@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp): ...@@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp):
init_param: tuple[str, ...] = ("inputs", "outputs") init_param: tuple[str, ...] = ("inputs", "outputs")
def __init__(self, inputs, outputs, name="Composite"): def __init__(
self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True
):
self.name = name self.name = name
self._name = None self._name = None
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
...@@ -4299,10 +4300,13 @@ class Composite(ScalarInnerGraphOp): ...@@ -4299,10 +4300,13 @@ class Composite(ScalarInnerGraphOp):
if len(outputs) > 1 or not any( if len(outputs) > 1 or not any(
isinstance(var.owner.op, Composite) for var in outputs isinstance(var.owner.op, Composite) for var in outputs
): ):
# No inner Composite if clone_graph:
inputs, outputs = clone(inputs, outputs) inputs, outputs = clone(inputs, outputs)
else: else:
# Inner Composite that we need to flatten # Inner Composite that we need to flatten
# FIXME: There could be a composite in the middle of the graph, why is this here?
# If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway.
assert len(outputs) == 1 assert len(outputs) == 1
# 1. Create a new graph from inputs up to the # 1. Create a new graph from inputs up to the
# Composite # Composite
...@@ -4321,7 +4325,8 @@ class Composite(ScalarInnerGraphOp): ...@@ -4321,7 +4325,8 @@ class Composite(ScalarInnerGraphOp):
assert res[0] != inputs assert res[0] != inputs
inputs, outputs = res[0], res2[1] inputs, outputs = res[0], res2[1]
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) # We already cloned the graph, or the user told us there was no need for it
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False)
self.inputs_type = tuple(input.type for input in self.inputs) self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs) self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs) self.nin = len(inputs)
......
...@@ -915,12 +915,13 @@ class FusionOptimizer(GraphRewriter): ...@@ -915,12 +915,13 @@ class FusionOptimizer(GraphRewriter):
break break
scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs)
composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( composite_outputs = Elemwise(
*inputs # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables
) ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False)
if not isinstance(composite_outputs, list): )(*inputs, return_list=True)
composite_outputs = [composite_outputs] assert len(outputs) == len(composite_outputs)
for old_out, composite_out in zip(outputs, composite_outputs, strict=True): for old_out, composite_out in zip(outputs, composite_outputs):
# Preserve any names on the original outputs
if old_out.name: if old_out.name:
composite_out.name = old_out.name composite_out.name = old_out.name
......
Markdown formatting supported
0%
You are adding 0 attachments to this discussion. Proceed with care.
Please finish editing this comment first!
Register or sign in to post a comment