Unverified 提交 2315e693 authored 作者: Trey Wenger's avatar Trey Wenger 提交者: GitHub

support `on_unused_input` for string parameter names in `eval` (#1085)

上级 d9d8dba8
...@@ -616,16 +616,20 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]): ...@@ -616,16 +616,20 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
""" """
from pytensor.compile.function import function from pytensor.compile.function import function
ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn")
def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]: def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
new_input_to_values = {} new_input_to_values = {}
for key, value in inputs_to_values.items(): for key, value in inputs_to_values.items():
if isinstance(key, str): if isinstance(key, str):
matching_vars = get_var_by_name([self], key) matching_vars = get_var_by_name([self], key)
if not matching_vars: if not matching_vars:
raise ValueError(f"{key} not found in graph") if not ignore_unused_input:
raise ValueError(f"{key} not found in graph")
elif len(matching_vars) > 1: elif len(matching_vars) > 1:
raise ValueError(f"Found multiple variables with name {key}") raise ValueError(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value else:
new_input_to_values[matching_vars[0]] = value
else: else:
new_input_to_values[key] = value new_input_to_values[key] = value
return new_input_to_values return new_input_to_values
......
...@@ -367,6 +367,10 @@ class TestEval: ...@@ -367,6 +367,10 @@ class TestEval:
self.w.eval({self.z: 3, self.x: 2.5}) self.w.eval({self.z: 3, self.x: 2.5})
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0 assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0
# regression test for https://github.com/pymc-devs/pytensor/issues/1084
q = self.x + 1
assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
def test_eval_unashable_kwargs(self): def test_eval_unashable_kwargs(self):
y_repl = constant(2.0, dtype="floatX") y_repl = constant(2.0, dtype="floatX")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论