提交 7393b744 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

Allow subtypes by default in make_node

上级 d6858fe2
......@@ -222,21 +222,22 @@ class Op(MetaObject):
"""
if self.itypes is None:
raise NotImplementedError(
"You can either define itypes and otypes,\
or implement make_node"
"You can either define itypes and otypes, or implement make_node"
)
if self.otypes is None:
raise NotImplementedError(
"You can either define itypes and otypes,\
or implement make_node"
"You can either define itypes and otypes, or implement make_node"
)
if len(inputs) != len(self.itypes):
raise ValueError(
f"We expected {len(self.itypes)} inputs but got {len(inputs)}."
)
if not all(it.in_same_class(inp.type) for inp, it in zip(inputs, self.itypes)):
if not all(
expected_type.is_super(var.type)
for var, expected_type in zip(inputs, self.itypes)
):
raise TypeError(
f"Invalid input types for Op {self}:\n"
+ "\n".join(
......
......@@ -223,3 +223,16 @@ def test_op_invalid_input_types():
msg = r"^Invalid input types for Op.*"
with pytest.raises(TypeError, match=msg):
TestOp()(dvector(), dscalar(), dvector())
def test_op_input_broadcastable():
# Test that we can create an op with a broadcastable subtype as input
class SomeOp(aesara.tensor.Op):
itypes = [at.dvector]
otypes = [at.dvector]
def perform(self, *_):
raise NotImplementedError()
x = at.TensorType(dtype="float64", shape=(1,))("x")
assert SomeOp()(x).type == at.dvector
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论