提交 2bb2ec6d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Do not include constant branches in minimization arguments

上级 4b9163bc
...@@ -20,6 +20,7 @@ from pytensor.graph.op import ( ...@@ -20,6 +20,7 @@ from pytensor.graph.op import (
from pytensor.graph.replace import graph_replace from pytensor.graph.replace import graph_replace
from pytensor.graph.traversal import ( from pytensor.graph.traversal import (
ancestors, ancestors,
graph_inputs,
truncated_graph_inputs, truncated_graph_inputs,
) )
from pytensor.scalar import ScalarType, ScalarVariable from pytensor.scalar import ScalarType, ScalarVariable
...@@ -147,10 +148,18 @@ def _find_optimization_parameters( ...@@ -147,10 +148,18 @@ def _find_optimization_parameters(
This is used to determine the additional arguments that need to be passed to the objective function. This is used to determine the additional arguments that need to be passed to the objective function.
""" """
def _depends_only_on_constants(var: Variable) -> bool:
if isinstance(var, Constant):
return True
if var.owner is None:
return False
return all(isinstance(v, Constant) for v in graph_inputs([var]))
return [ return [
arg arg
for arg in truncated_graph_inputs([objective], [x]) for arg in truncated_graph_inputs([objective], [x])
if (arg is not x and not isinstance(arg, Constant)) if (arg is not x and not _depends_only_on_constants(arg))
] ]
......
...@@ -17,6 +17,19 @@ from tests import unittest_tools as utt ...@@ -17,6 +17,19 @@ from tests import unittest_tools as utt
floatX = config.floatX floatX = config.floatX
def test_constant_expressions_not_in_args():
x = pt.scalar("x")
a = pt.scalar("a")
# cast(0, int64) is not a Constant, but depends only on constants
constant_expr = pt.cast(pt.constant(0), "int64")
out = (x - a) ** 2 + constant_expr
minimized_x, _ = minimize(out, x)
minimize_op_node = minimized_x.owner
assert minimize_op_node.inputs == [x, a]
def test_minimize_scalar(): def test_minimize_scalar():
x = pt.scalar("x") x = pt.scalar("x")
a = pt.scalar("a") a = pt.scalar("a")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论