提交 6de31513 authored 作者: Ch0ronomato's avatar Ch0ronomato 提交者: Ricardo Vieira

Improve torch elemwise operator

上级 0ba554b3
......@@ -11,9 +11,21 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)
if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
# torch can handle this scalar
# broadcast, we'll let it.
def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)
else:
def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
broadcast_inputs = torch.broadcast_tensors(*inputs)
ufunc = base_fn
for _ in range(broadcast_inputs[0].dim()):
ufunc = torch.vmap(ufunc)
return ufunc(*broadcast_inputs)
return elemwise_fn
......
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarOp, get_scalar_type
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -150,3 +153,33 @@ def test_cast():
fgraph, [np.arange(6, dtype="float32").reshape(2, 3)]
)
assert res.dtype == torch.int32
def test_vmap_elemwise():
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
class TestOp(ScalarOp):
def __init__(self):
super().__init__(
output_types_preference=lambda *_: [get_scalar_type("float32")]
)
self.call_shapes = []
self.nin = 1
def perform(self, *_):
raise RuntimeError("In perform")
@pytorch_funcify.register(TestOp)
def relu(op, node, **kwargs):
def relu(row):
op.call_shapes.append(row.size())
return torch.max(torch.zeros_like(row), row)
return relu
x = matrix("x", shape=(2, 3))
op = TestOp()
f = pytensor.function([x], Elemwise(op)(x), mode="PYTORCH")
vals = torch.zeros(2, 3).normal_()
np.testing.assert_allclose(f(vals), torch.relu(vals))
assert op.call_shapes == [torch.Size([])], op.call_shapes
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论