提交 4273eb87 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Inline constants and merge duplicate inputs in OptimizeOp

上级 2bb2ec6d
......@@ -199,10 +199,16 @@ class ScipyWrapperOp(Op, HasInnerGraph):
def inner_outputs(self):
return self.fgraph.outputs
def clone_with_new_fgraph(self, fgraph):
clone_op = copy(self)
clone_op._fn = None
clone_op._fn_wrapped = None
clone_op.fgraph = fgraph
return clone_op
def clone(self):
copy_op = copy(self)
copy_op.fgraph = self.fgraph.clone(clone_inner_graphs=True)
return copy_op
clone_fgraph = self.fgraph.clone(clone_inner_graphs=True)
return self.clone_with_new_fgraph(clone_fgraph)
def prepare_node(
self,
......
......@@ -10,6 +10,7 @@ import pytensor.tensor.rewriting.linalg
import pytensor.tensor.rewriting.math
import pytensor.tensor.rewriting.numba
import pytensor.tensor.rewriting.ofg
import pytensor.tensor.rewriting.optimize
import pytensor.tensor.rewriting.reshape
import pytensor.tensor.rewriting.shape
import pytensor.tensor.rewriting.special
......
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.optimize import ScipyWrapperOp
from pytensor.tensor.rewriting.basic import register_canonicalize
@register_canonicalize
@node_rewriter([ScipyWrapperOp])
def remove_constants_and_duplicate_inputs_scipy(fgraph, node):
"""Inline constants and remove duplicate inputs from ScipyWrapperOp nodes.
Constants in the outer graph are free symbolic variables in the inner graph.
Moving them into the inner graph enables constant-folding. Duplicate outer
inputs can share a single inner variable.
Only args (inputs[1:]) are candidates — inputs[0] is always the
optimization variable x.
"""
op: ScipyWrapperOp = node.op
inner_x, *inner_args = op.inner_inputs
outer_x, *outer_args = list(node.inputs)
givens = {}
new_inner_args = []
new_outer_args = []
for inner_in, outer_in in zip(inner_args, outer_args):
if isinstance(outer_in, Constant):
givens[inner_in] = outer_in
elif outer_in in new_outer_args:
# De-duplicate outer variable
idx = new_outer_args.index(outer_in)
givens[inner_in] = new_inner_args[idx]
else:
new_inner_args.append(inner_in)
new_outer_args.append(outer_in)
if not givens:
return None
new_inner_outputs = clone_replace(op.inner_outputs, replace=givens)
new_inner_inputs = (inner_x, *new_inner_args)
new_fgraph = FunctionGraph(new_inner_inputs, new_inner_outputs, clone=False)
new_op = op.clone_with_new_fgraph(new_fgraph)
new_outer_inputs = (outer_x, *new_outer_args)
return new_op.make_node(*new_outer_inputs).outputs
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import MinimizeOp, ScipyWrapperOp
def test_inline_constants():
"""Constants passed as args should be inlined into the inner graph."""
x = pt.scalar("x")
a = pt.scalar("a")
b = pt.scalar("b", dtype=int)
c = pt.scalar("c")
objective = (x - c * a) ** b
minimize_op = MinimizeOp(
x,
a,
b,
c,
objective=objective,
method="BFGS",
)
two_float = pt.full((), 2.0, dtype=a.dtype)
two_int = two_float.astype(b.dtype)
minimize_node = minimize_op.make_node(x, two_float, two_int, c)
assert len(minimize_node.inputs) == 4
f = function([x, c], minimize_node.outputs)
# Check the two constants are inlined
[minimize_node] = [
node
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, ScipyWrapperOp)
]
assert len(minimize_node.inputs) == 2
# Check correctness
c_val = 3.0
minimized_x_val, success_val = f(np.pi, c_val)
assert success_val
np.testing.assert_allclose(
minimized_x_val,
2 * c_val,
)
def test_remove_duplicate_inputs(new_minimize_outs=None):
"""Duplicate outer inputs should be deduplicated."""
x = pt.scalar("x")
a = pt.scalar("a")
b = pt.scalar("b")
objective = (x + a) ** 2 + (x - b) ** 2
minimize_op = MinimizeOp(
x,
a,
b,
objective=objective,
method="BFGS",
)
# Use same outer variable for both a, b
c = pt.scalar("c")
minimized_node = minimize_op.make_node(x, c, c)
assert len(minimized_node.inputs) == 3
f = function([x, c], minimized_node.outputs)
[minimize_node] = [
node
for node in f.maker.fgraph.apply_nodes
if isinstance(node.op, ScipyWrapperOp)
]
assert len(minimize_node.inputs) == 2
# Check correctness: minimum of (x+a)^2 + (x-a)^2 = 2x^2 + 2a^2 is at x=0
minimized_x_val, success_val = f(np.pi, np.e)
assert success_val
np.testing.assert_allclose(minimized_x_val, 0.0, atol=1e-8)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论