提交 5a6d92c3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix more old style exceptions

上级 444f13c7
......@@ -113,9 +113,8 @@ def rebuild_collect_shared(
)
if v_update.type != v.type:
raise TypeError(
"an update must have the same type as "
"the original shared variable",
(v, v.type, v_update, v_update.type),
"An update must have a type compatible with "
"the original shared variable"
)
update_d[v] = v_update
update_expr.append((v, v_update))
......@@ -134,7 +133,7 @@ def rebuild_collect_shared(
for v_orig, v_repl in replace_pairs:
if not isinstance(v_orig, Variable):
raise TypeError("given keys must be Variable", v_orig)
raise TypeError("`givens` keys must be Variables")
if not isinstance(v_repl, Variable):
v_repl = shared(v_repl)
......
......@@ -697,15 +697,13 @@ class Subtensor(COp):
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
raise IndexError(
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
assert len(inputs) == len(input_types)
for input, expected_type in zip(inputs, input_types):
if input.type != expected_type:
raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s."
% (input.type, expected_type)
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
)
# infer the broadcasting pattern
......@@ -1533,8 +1531,7 @@ class IncSubtensor(COp):
for input, expected_type in zip(inputs, input_types):
if input.type != expected_type:
raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s."
% (input.type, expected_type)
f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
)
return Apply(self, (x, y) + inputs, [x.type()])
......
......@@ -4,6 +4,7 @@ import pytest
import aesara.tensor as at
from aesara.compile import UnusedInputError
from aesara.compile.function import function, pfunc
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.io import In
from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
......@@ -1045,3 +1046,12 @@ class TestRebuildStrict:
z_val = f(np.ones((3, 5), dtype="int32"), np.arange(5, dtype="int32"))
assert z_val.ndim == 2
assert np.all(z_val == np.ones((3, 5)) * np.arange(5))
def test_rebuild_collect_shared():
x, y = ivectors("x", "y")
z = x * y
with pytest.raises(TypeError):
rebuild_collect_shared([z], replace={1: 2})
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论