提交 ae66e82a authored 作者: Ian Schweer's avatar Ian Schweer 提交者: Ricardo Vieira

Fix test warning

上级 7300a687
...@@ -12,7 +12,7 @@ torch = pytest.importorskip("torch") ...@@ -12,7 +12,7 @@ torch = pytest.importorskip("torch")
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
class TestOp(Op): class BatchedTestOp(Op):
gufunc_signature = "(m,n),(n,p)->(m,p)" gufunc_signature = "(m,n),(n,p)->(m,p)"
def __init__(self, final_shape): def __init__(self, final_shape):
...@@ -27,7 +27,7 @@ class TestOp(Op): ...@@ -27,7 +27,7 @@ class TestOp(Op):
raise RuntimeError("In perform") raise RuntimeError("In perform")
@basic.pytorch_funcify.register(TestOp) @basic.pytorch_funcify.register(BatchedTestOp)
def evaluate_test_op(op, **_): def evaluate_test_op(op, **_):
def func(a, b): def func(a, b):
op.call_shapes.extend(map(torch.Tensor.size, [a, b])) op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
...@@ -42,7 +42,7 @@ def test_blockwise_broadcast(): ...@@ -42,7 +42,7 @@ def test_blockwise_broadcast():
x = pt.tensor4("x", shape=(5, 1, 2, 3)) x = pt.tensor4("x", shape=(5, 1, 2, 3))
y = pt.tensor3("y", shape=(3, 3, 2)) y = pt.tensor3("y", shape=(3, 3, 2))
op = TestOp((2, 2)) op = BatchedTestOp((2, 2))
z = Blockwise(op)(x, y) z = Blockwise(op)(x, y)
f = pytensor.function([x, y], z, mode="PYTORCH") f = pytensor.function([x, y], z, mode="PYTORCH")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论