Commit 4273eb87, authored by Ricardo Vieira, committed by Ricardo Vieira

Inline constants and merge duplicate inputs in OptimizeOp

Parent commit: 2bb2ec6d
...@@ -199,10 +199,16 @@ class ScipyWrapperOp(Op, HasInnerGraph): ...@@ -199,10 +199,16 @@ class ScipyWrapperOp(Op, HasInnerGraph):
def inner_outputs(self): def inner_outputs(self):
return self.fgraph.outputs return self.fgraph.outputs
def clone_with_new_fgraph(self, fgraph):
    """Return a shallow copy of this Op that wraps *fgraph*.

    The compiled-function caches (``_fn``/``_fn_wrapped``) are reset on the
    copy, since they were built for the previous inner graph and would be
    stale for the new one.
    """
    new_op = copy(self)
    new_op.fgraph = fgraph
    # Invalidate cached compiled functions; they belong to the old fgraph.
    new_op._fn = None
    new_op._fn_wrapped = None
    return new_op
def clone(self): def clone(self):
copy_op = copy(self) clone_fgraph = self.fgraph.clone(clone_inner_graphs=True)
copy_op.fgraph = self.fgraph.clone(clone_inner_graphs=True) return self.clone_with_new_fgraph(clone_fgraph)
return copy_op
def prepare_node( def prepare_node(
self, self,
......
...@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.linalg ...@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.optimize
import pytensor.tensor.rewriting.reshape import pytensor.tensor.rewriting.reshape
import pytensor.tensor.rewriting.shape import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special import pytensor.tensor.rewriting.special
......
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.optimize import ScipyWrapperOp
from pytensor.tensor.rewriting.basic import register_canonicalize
@register_canonicalize
@node_rewriter([ScipyWrapperOp])
def remove_constants_and_duplicate_inputs_scipy(fgraph, node):
    """Inline constants and remove duplicate inputs from ScipyWrapperOp nodes.

    Constants in the outer graph are free symbolic variables in the inner graph.
    Moving them into the inner graph enables constant-folding. Duplicate outer
    inputs can share a single inner variable.

    Only args (inputs[1:]) are candidates — inputs[0] is always the
    optimization variable x.
    """
    op: ScipyWrapperOp = node.op
    inner_x, *inner_args = op.inner_inputs
    outer_x, *outer_args = list(node.inputs)

    replacements = {}  # inner variable -> its substitute (constant or kept inner var)
    kept_inner = []
    kept_outer = []
    for inner_var, outer_var in zip(inner_args, outer_args):
        if isinstance(outer_var, Constant):
            # Push the constant into the inner graph.
            replacements[inner_var] = outer_var
            continue
        if outer_var in kept_outer:
            # Outer variable already seen: reuse its inner counterpart.
            pos = kept_outer.index(outer_var)
            replacements[inner_var] = kept_inner[pos]
            continue
        kept_inner.append(inner_var)
        kept_outer.append(outer_var)

    if not replacements:
        # Nothing to inline or deduplicate.
        return None

    new_outputs = clone_replace(op.inner_outputs, replace=replacements)
    new_fgraph = FunctionGraph((inner_x, *kept_inner), new_outputs, clone=False)
    new_op = op.clone_with_new_fgraph(new_fgraph)
    return new_op.make_node(outer_x, *kept_outer).outputs
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import MinimizeOp, ScipyWrapperOp
def test_inline_constants():
    """Constants passed as args should be inlined into the inner graph."""
    x = pt.scalar("x")
    a = pt.scalar("a")
    b = pt.scalar("b", dtype=int)
    c = pt.scalar("c")

    op = MinimizeOp(
        x,
        a,
        b,
        c,
        objective=(x - c * a) ** b,
        method="BFGS",
    )

    # Feed constants for a and b, keep c symbolic.
    const_a = pt.full((), 2.0, dtype=a.dtype)
    const_b = const_a.astype(b.dtype)
    node = op.make_node(x, const_a, const_b, c)
    assert len(node.inputs) == 4

    fn = function([x, c], node.outputs)

    # After rewriting, the two constants must have been folded into the
    # inner graph, leaving only x and c as outer inputs.
    [scipy_node] = [
        n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, ScipyWrapperOp)
    ]
    assert len(scipy_node.inputs) == 2

    # Numerical check: objective (x - 2c)^2 is minimized at x = 2c.
    c_val = 3.0
    x_opt, converged = fn(np.pi, c_val)
    assert converged
    np.testing.assert_allclose(x_opt, 2 * c_val)
def test_remove_duplicate_inputs():
    """Duplicate outer inputs should be deduplicated.

    Builds a MinimizeOp over two distinct inner args (a, b), then feeds the
    same outer variable ``c`` for both; the canonicalization rewrite must
    collapse them into a single outer input.
    """
    # NOTE: removed the leftover unused parameter ``new_minimize_outs=None``;
    # it was dead code (pytest calls test functions with no arguments).
    x = pt.scalar("x")
    a = pt.scalar("a")
    b = pt.scalar("b")
    objective = (x + a) ** 2 + (x - b) ** 2
    minimize_op = MinimizeOp(
        x,
        a,
        b,
        objective=objective,
        method="BFGS",
    )
    # Use same outer variable for both a, b
    c = pt.scalar("c")
    minimized_node = minimize_op.make_node(x, c, c)
    assert len(minimized_node.inputs) == 3
    f = function([x, c], minimized_node.outputs)
    # Exactly one ScipyWrapperOp node should survive compilation.
    [minimize_node] = [
        node
        for node in f.maker.fgraph.apply_nodes
        if isinstance(node.op, ScipyWrapperOp)
    ]
    # The duplicated c must now appear only once: inputs are (x, c).
    assert len(minimize_node.inputs) == 2
    # Check correctness: minimum of (x+c)^2 + (x-c)^2 = 2x^2 + 2c^2 is at x=0
    minimized_x_val, success_val = f(np.pi, np.e)
    assert success_val
    np.testing.assert_allclose(minimized_x_val, 0.0, atol=1e-8)
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment