提交 5a6d92c3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix more old style exceptions

上级 444f13c7
...@@ -113,9 +113,8 @@ def rebuild_collect_shared( ...@@ -113,9 +113,8 @@ def rebuild_collect_shared(
) )
if v_update.type != v.type: if v_update.type != v.type:
raise TypeError( raise TypeError(
"an update must have the same type as " "An update must have a type compatible with "
"the original shared variable", "the original shared variable"
(v, v.type, v_update, v_update.type),
) )
update_d[v] = v_update update_d[v] = v_update
update_expr.append((v, v_update)) update_expr.append((v, v_update))
...@@ -134,7 +133,7 @@ def rebuild_collect_shared( ...@@ -134,7 +133,7 @@ def rebuild_collect_shared(
for v_orig, v_repl in replace_pairs: for v_orig, v_repl in replace_pairs:
if not isinstance(v_orig, Variable): if not isinstance(v_orig, Variable):
raise TypeError("given keys must be Variable", v_orig) raise TypeError("`givens` keys must be Variables")
if not isinstance(v_repl, Variable): if not isinstance(v_repl, Variable):
v_repl = shared(v_repl) v_repl = shared(v_repl)
......
...@@ -697,15 +697,13 @@ class Subtensor(COp): ...@@ -697,15 +697,13 @@ class Subtensor(COp):
input_types = get_slice_elements( input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type) idx_list, lambda entry: isinstance(entry, Type)
) )
if len(inputs) != len(input_types):
raise IndexError( assert len(inputs) == len(input_types)
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
for input, expected_type in zip(inputs, input_types): for input, expected_type in zip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s." f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
% (input.type, expected_type)
) )
# infer the broadcasting pattern # infer the broadcasting pattern
...@@ -1533,8 +1531,7 @@ class IncSubtensor(COp): ...@@ -1533,8 +1531,7 @@ class IncSubtensor(COp):
for input, expected_type in zip(inputs, input_types): for input, expected_type in zip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s." f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
% (input.type, expected_type)
) )
return Apply(self, (x, y) + inputs, [x.type()]) return Apply(self, (x, y) + inputs, [x.type()])
......
...@@ -4,6 +4,7 @@ import pytest ...@@ -4,6 +4,7 @@ import pytest
import aesara.tensor as at import aesara.tensor as at
from aesara.compile import UnusedInputError from aesara.compile import UnusedInputError
from aesara.compile.function import function, pfunc from aesara.compile.function import function, pfunc
from aesara.compile.function.pfunc import rebuild_collect_shared
from aesara.compile.io import In from aesara.compile.io import In
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config from aesara.configdefaults import config
...@@ -1045,3 +1046,12 @@ class TestRebuildStrict: ...@@ -1045,3 +1046,12 @@ class TestRebuildStrict:
z_val = f(np.ones((3, 5), dtype="int32"), np.arange(5, dtype="int32")) z_val = f(np.ones((3, 5), dtype="int32"), np.arange(5, dtype="int32"))
assert z_val.ndim == 2 assert z_val.ndim == 2
assert np.all(z_val == np.ones((3, 5)) * np.arange(5)) assert np.all(z_val == np.ones((3, 5)) * np.arange(5))
def test_rebuild_collect_shared():
x, y = ivectors("x", "y")
z = x * y
with pytest.raises(TypeError):
rebuild_collect_shared([z], replace={1: 2})
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论