Unverified 提交 feccc417 authored 作者: Raj Parekh's avatar Raj Parekh 提交者: GitHub

Allow string keys in `eval` utility (#242)

上级 74dca205
......@@ -597,6 +597,22 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
if inputs_to_values is None:
inputs_to_values = {}
def convert_string_keys_to_variables(input_to_values):
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
matching_vars = get_var_by_name([self], key)
if not matching_vars:
raise Exception(f"{key} not found in graph")
elif len(matching_vars) > 1:
raise Exception(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[key] = value
return new_input_to_values
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
if not hasattr(self, "_fn_cache"):
self._fn_cache = dict()
......
......@@ -302,6 +302,24 @@ class TestEval:
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
), "temporary functions must not be serialized"
def test_eval_with_strings(self):
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
assert self.w.eval({self.z: 3}) == 6.0
def test_eval_with_strings_multiple_matches(self):
e = scalars("e")
t = e + 1
t.name = "e"
with pytest.raises(Exception, match="Found multiple variables with name e"):
t.eval({"e": 1})
def test_eval_with_strings_no_match(self):
e = scalars("e")
t = e + 1
t.name = "p"
with pytest.raises(Exception, match="o not found in graph"):
t.eval({"o": 1})
class TestAutoName:
def test_auto_name(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论