提交 8351f902 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename test Op to not confuse pytest

上级 9cf2d181
...@@ -88,7 +88,7 @@ def test_runtime_broadcast(mode): ...@@ -88,7 +88,7 @@ def test_runtime_broadcast(mode):
check_blockwise_runtime_broadcasting(mode) check_blockwise_runtime_broadcasting(mode)
class TestOp(Op): class MyTestOp(Op):
def make_node(self, *inputs): def make_node(self, *inputs):
return Apply(self, inputs, [i.type() for i in inputs]) return Apply(self, inputs, [i.type() for i in inputs])
...@@ -96,7 +96,7 @@ class TestOp(Op): ...@@ -96,7 +96,7 @@ class TestOp(Op):
raise NotImplementedError("Test Op should not be present in final graph") raise NotImplementedError("Test Op should not be present in final graph")
test_op = TestOp() test_op = MyTestOp()
def test_vectorize_node_default_signature(): def test_vectorize_node_default_signature():
...@@ -106,12 +106,12 @@ def test_vectorize_node_default_signature(): ...@@ -106,12 +106,12 @@ def test_vectorize_node_default_signature():
vect_node = vectorize_node(node, mat, mat) vect_node = vectorize_node(node, mat, mat)
assert isinstance(vect_node.op, Blockwise) and isinstance( assert isinstance(vect_node.op, Blockwise) and isinstance(
vect_node.op.core_op, TestOp vect_node.op.core_op, MyTestOp
) )
assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)") assert vect_node.op.signature == ("(i00),(i10,i11)->(o00),(o10,o11)")
with pytest.raises( with pytest.raises(
ValueError, match="Signature not provided nor found in core_op TestOp" ValueError, match="Signature not provided nor found in core_op MyTestOp"
): ):
Blockwise(test_op) Blockwise(test_op)
...@@ -138,7 +138,7 @@ def test_blockwise_shape(): ...@@ -138,7 +138,7 @@ def test_blockwise_shape():
shape_fn = pytensor.function([inp], out.shape) shape_fn = pytensor.function([inp], out.shape)
assert not any( assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
assert tuple(shape_fn(inp_test)) == (5, 3, 4) assert tuple(shape_fn(inp_test)) == (5, 3, 4)
...@@ -150,13 +150,13 @@ def test_blockwise_shape(): ...@@ -150,13 +150,13 @@ def test_blockwise_shape():
shape_fn = pytensor.function([inp], out.shape) shape_fn = pytensor.function([inp], out.shape)
assert any( assert any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
shape_fn = pytensor.function([inp], out.shape[:-1]) shape_fn = pytensor.function([inp], out.shape[:-1])
assert not any( assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
assert tuple(shape_fn(inp_test)) == (5, 4) assert tuple(shape_fn(inp_test)) == (5, 4)
...@@ -174,20 +174,20 @@ def test_blockwise_shape(): ...@@ -174,20 +174,20 @@ def test_blockwise_shape():
shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs]) shape_fn = pytensor.function([inp1, inp2], [out.shape for out in outs])
assert any( assert any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
shape_fn = pytensor.function([inp1, inp2], outs[0].shape) shape_fn = pytensor.function([inp1, inp2], outs[0].shape)
assert not any( assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4) assert tuple(shape_fn(inp1_test, inp2_test)) == (7, 5, 3, 4)
shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]]) shape_fn = pytensor.function([inp1, inp2], [outs[0].shape, outs[1].shape[:-1]])
assert not any( assert not any(
isinstance(getattr(n.op, "core_op", n.op), TestOp) isinstance(getattr(n.op, "core_op", n.op), MyTestOp)
for n in shape_fn.maker.fgraph.apply_nodes for n in shape_fn.maker.fgraph.apply_nodes
) )
assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4) assert tuple(shape_fn(inp1_test, inp2_test)[0]) == (7, 5, 3, 4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论