提交 2bb2ec6d authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Do not include constant branches in minimization arguments

上级 4b9163bc
......@@ -20,6 +20,7 @@ from pytensor.graph.op import (
from pytensor.graph.replace import graph_replace
from pytensor.graph.traversal import (
ancestors,
graph_inputs,
truncated_graph_inputs,
)
from pytensor.scalar import ScalarType, ScalarVariable
......@@ -147,10 +148,18 @@ def _find_optimization_parameters(
This is used to determine the additional arguments that need to be passed to the objective function.
"""
def _depends_only_on_constants(var: Variable) -> bool:
if isinstance(var, Constant):
return True
if var.owner is None:
return False
return all(isinstance(v, Constant) for v in graph_inputs([var]))
return [
arg
for arg in truncated_graph_inputs([objective], [x])
if (arg is not x and not isinstance(arg, Constant))
if (arg is not x and not _depends_only_on_constants(arg))
]
......
......@@ -17,6 +17,19 @@ from tests import unittest_tools as utt
floatX = config.floatX
def test_constant_expressions_not_in_args():
x = pt.scalar("x")
a = pt.scalar("a")
# cast(0, int64) is not a Constant, but depends only on constants
constant_expr = pt.cast(pt.constant(0), "int64")
out = (x - a) ** 2 + constant_expr
minimized_x, _ = minimize(out, x)
minimize_op_node = minimized_x.owner
assert minimize_op_node.inputs == [x, a]
def test_minimize_scalar():
x = pt.scalar("x")
a = pt.scalar("a")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论