提交 a5f8b693 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Type conversion and checking in IfElse.make_node

上级 8528590d
...@@ -175,6 +175,11 @@ class IfElse(_NoPythonOp): ...@@ -175,6 +175,11 @@ class IfElse(_NoPythonOp):
if not isinstance(input_f, Variable): if not isinstance(input_f, Variable):
input_f = as_symbolic(input_f) input_f = as_symbolic(input_f)
if type(input_f.type) != type(input_t.type): # noqa: E721
raise TypeError(
f"Input types {type(input_t.type)} and {type(input_f.type)} do not match."
)
if isinstance(input_t.type, HasDataType) and isinstance( if isinstance(input_t.type, HasDataType) and isinstance(
input_f.type, HasDataType input_f.type, HasDataType
): ):
...@@ -207,18 +212,18 @@ class IfElse(_NoPythonOp): ...@@ -207,18 +212,18 @@ class IfElse(_NoPythonOp):
# TODO FIXME: The presence of this keyword is a strong # TODO FIXME: The presence of this keyword is a strong
# assumption. Find something that's guaranteed by the/a # assumption. Find something that's guaranteed by the/a
# confirmed interface. # confirmed interface.
output_type_t = input_t.type.clone(shape=new_shape)() output_var_t = input_t.type.clone(shape=new_shape)()
output_type_f = input_f.type.clone(shape=new_shape)() output_var_f = input_f.type.clone(shape=new_shape)()
else: else:
output_type_t = input_t.type() output_var_t = input_t.type()
output_type_f = input_f.type() output_var_f = input_f.type()
input_t = output_type_f.type.convert_variable(input_t) input_t_ = output_var_f.type.filter_variable(input_t)
input_f = output_type_t.type.convert_variable(input_f) input_f_ = output_var_t.type.filter_variable(input_f)
new_inputs_true_branch.append(input_t) new_inputs_true_branch.append(input_t_)
new_inputs_false_branch.append(input_f) new_inputs_false_branch.append(input_f_)
output_vars.append(output_type_t) output_vars.append(output_var_t)
return Apply( return Apply(
self, self,
......
...@@ -325,22 +325,24 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -325,22 +325,24 @@ class TestIfelse(utt.OptimizationTestMixin):
with pytest.raises(TypeError): with pytest.raises(TypeError):
ifelse(cond, y, x) ifelse(cond, y, x)
def test_sparse_tensor_error(self): def test_sparse_conversions(self):
from aesara.sparse import matrix
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
data = rng.random((2, 3)).astype(self.dtype) data = rng.random((2, 3)).astype(self.dtype)
x = self.shared(data) x = self.shared(data)
y = aesara.sparse.matrix("csc", dtype=self.dtype, name="y") y = matrix("csc", dtype=self.dtype, name="y")
z = aesara.sparse.matrix("csr", dtype=self.dtype, name="z") z = matrix("csr", dtype=self.dtype, name="z")
cond = iscalar("cond") cond = iscalar("cond")
with pytest.raises(NotImplementedError): with pytest.raises(TypeError, match=".*do not match."):
ifelse(cond, x, y) ifelse(cond, x, y)
with pytest.raises(NotImplementedError): with pytest.raises(TypeError, match=".*do not match."):
ifelse(cond, y, x) ifelse(cond, y, x)
with pytest.raises(NotImplementedError): with pytest.raises(TypeError):
ifelse(cond, x, z) ifelse(cond, x, z)
with pytest.raises(NotImplementedError): with pytest.raises(TypeError):
ifelse(cond, z, x) ifelse(cond, z, x)
with pytest.raises(TypeError): with pytest.raises(TypeError):
ifelse(cond, y, z) ifelse(cond, y, z)
...@@ -534,6 +536,8 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -534,6 +536,8 @@ class TestIfelse(utt.OptimizationTestMixin):
[ [
((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)), ((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)), ((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((3,), (None,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0], (None,)),
((2, 1), (None, 1), np.c_[[1.0, 2.0]], np.c_[[1.0, 2.0, 3.0]], (None, 1)),
((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)), ((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)),
((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)), ((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)),
], ],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论