提交 7393b744 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

Allow subtypes by default in make_node

上级 d6858fe2
...@@ -222,21 +222,22 @@ class Op(MetaObject): ...@@ -222,21 +222,22 @@ class Op(MetaObject):
""" """
if self.itypes is None: if self.itypes is None:
raise NotImplementedError( raise NotImplementedError(
"You can either define itypes and otypes,\ "You can either define itypes and otypes, or implement make_node"
or implement make_node"
) )
if self.otypes is None: if self.otypes is None:
raise NotImplementedError( raise NotImplementedError(
"You can either define itypes and otypes,\ "You can either define itypes and otypes, or implement make_node"
or implement make_node"
) )
if len(inputs) != len(self.itypes): if len(inputs) != len(self.itypes):
raise ValueError( raise ValueError(
f"We expected {len(self.itypes)} inputs but got {len(inputs)}." f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
) )
if not all(it.in_same_class(inp.type) for inp, it in zip(inputs, self.itypes)): if not all(
expected_type.is_super(var.type)
for var, expected_type in zip(inputs, self.itypes)
):
raise TypeError( raise TypeError(
f"Invalid input types for Op {self}:\n" f"Invalid input types for Op {self}:\n"
+ "\n".join( + "\n".join(
......
...@@ -223,3 +223,16 @@ def test_op_invalid_input_types(): ...@@ -223,3 +223,16 @@ def test_op_invalid_input_types():
msg = r"^Invalid input types for Op.*" msg = r"^Invalid input types for Op.*"
with pytest.raises(TypeError, match=msg): with pytest.raises(TypeError, match=msg):
TestOp()(dvector(), dscalar(), dvector()) TestOp()(dvector(), dscalar(), dvector())
def test_op_input_broadcastable():
# Test that we can create an op with a broadcastable subtype as input
class SomeOp(aesara.tensor.Op):
itypes = [at.dvector]
otypes = [at.dvector]
def perform(self, *_):
raise NotImplementedError()
x = at.TensorType(dtype="float64", shape=(1,))("x")
assert SomeOp()(x).type == at.dvector
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论