提交 1117ea5e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow keyword arguments in eval method

上级 98070db1
......@@ -555,13 +555,20 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
return [self.owner]
return []
def eval(self, inputs_to_values=None):
r"""Evaluate the `Variable`.
def eval(
self,
inputs_to_values: dict[Union["Variable", str], Any] | None = None,
**kwargs,
):
r"""Evaluate the `Variable` given a set of values for its inputs.
Parameters
----------
inputs_to_values :
A dictionary mapping PyTensor `Variable`\s to values.
A dictionary mapping PyTensor `Variable`\s or names to values.
Not needed if variable has no required inputs.
kwargs :
Optional keyword arguments to pass to the underlying `pytensor.function`
Examples
--------
......@@ -591,10 +598,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
"""
from pytensor.compile.function import function
if inputs_to_values is None:
inputs_to_values = {}
def convert_string_keys_to_variables(input_to_values):
def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
......@@ -608,19 +612,32 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
new_input_to_values[key] = value
return new_input_to_values
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
parsed_inputs_to_values: dict[Variable, Any] = {}
if inputs_to_values is not None:
parsed_inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
if not hasattr(self, "_fn_cache"):
self._fn_cache = dict()
self._fn_cache: dict = dict()
inputs = tuple(sorted(inputs_to_values.keys(), key=id))
if inputs not in self._fn_cache:
self._fn_cache[inputs] = function(inputs, self)
args = [inputs_to_values[param] for param in inputs]
inputs = tuple(sorted(parsed_inputs_to_values.keys(), key=id))
cache_key = (inputs, tuple(kwargs.items()))
try:
fn = self._fn_cache[cache_key]
except (KeyError, TypeError):
fn = None
rval = self._fn_cache[inputs](*args)
if fn is None:
fn = function(inputs, self, **kwargs)
try:
self._fn_cache[cache_key] = fn
except TypeError as exc:
warnings.warn(
"Keyword arguments could not be used to create a cache key for the underlying variable. "
f"A function will be recompiled on every call with such keyword arguments.\n{exc}"
)
return rval
args = [parsed_inputs_to_values[param] for param in inputs]
return fn(*args)
def __getstate__(self):
d = self.__dict__.copy()
......
......@@ -6,6 +6,7 @@ import pytest
from pytensor import shared
from pytensor import tensor as pt
from pytensor.compile import UnusedInputError
from pytensor.graph.basic import (
Apply,
NominalVariable,
......@@ -30,6 +31,7 @@ from pytensor.graph.basic import (
)
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.tensor import constant
from pytensor.tensor.math import max_and_argmax
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
from pytensor.tensor.type_other import NoneConst
......@@ -359,6 +361,24 @@ class TestEval:
with pytest.raises(Exception, match="o not found in graph"):
t.eval({"o": 1})
def test_eval_kwargs(self):
with pytest.raises(UnusedInputError):
self.w.eval({self.z: 3, self.x: 2.5})
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0
@pytest.mark.filterwarnings("error")
def test_eval_unashable_kwargs(self):
y_repl = constant(2.0, dtype="floatX")
assert self.w.eval({self.x: 1.0}, givens=((self.y, y_repl),)) == 6.0
with pytest.warns(
UserWarning,
match="Keyword arguments could not be used to create a cache key",
):
# givens dict is not hashable
assert self.w.eval({self.x: 1.0}, givens={self.y: y_repl}) == 6.0
class TestAutoName:
def test_auto_name(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论