Unverified commit c477732f, authored by Ian Schweer, committed via GitHub

Torch dispatch for scipy-like functions and Softplus (#1066)

* Allow for scipy module resolution * Add softplus * Add tests * Allow scipy scalar handling * Check for scipy in elemwise --------- Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Parent commit ae66e82a
import importlib
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
......@@ -11,12 +13,26 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
def check_special_scipy(func_name):
    """Resolve a ``scipy.``-prefixed nfunc name to a torch callable.

    Parameters
    ----------
    func_name
        An ``nfunc_spec`` name such as ``"scipy.special.expit"``.

    Returns
    -------
    The matching callable from the corresponding ``torch`` submodule
    (e.g. ``torch.special.expit``), or ``False`` when ``func_name`` is not
    scipy-prefixed or no torch equivalent exists.
    """
    if "scipy." not in func_name:
        return False
    # Drop the leading "scipy" segment; the remainder mirrors torch's
    # submodule layout (e.g. ["special", "expit"]).
    loc = func_name.split(".")[1:]
    try:
        # BUG FIX: ``importlib.import_module(name, package)`` only honours
        # ``package`` for *relative* (dot-prefixed) names, so the previous
        # ``import_module("special", "torch")`` tried to import a top-level
        # ``special`` module and always raised ImportError. Build the
        # absolute "torch.<submodule>" path instead, matching the ScalarOp
        # dispatch in scalar.py.
        mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
        return getattr(mod, loc[-1], False)
    except ImportError:
        return False
if hasattr(scalar_op, "nfunc_spec") and (
hasattr(torch, scalar_op.nfunc_spec[0])
or check_special_scipy(scalar_op.nfunc_spec[0])
):
# torch can handle this scalar
# broadcast, we'll let it.
def elemwise_fn(*inputs):
    # torch broadcasts the scalar op natively; we only guard against
    # runtime broadcasting that PyTensor disallows for these inputs,
    # then delegate directly to the dispatched scalar function.
    Elemwise._check_runtime_broadcast(node, inputs)
    return base_fn(*inputs)
else:
def elemwise_fn(*inputs):
......
import importlib
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
......@@ -5,6 +7,7 @@ from pytensor.scalar.basic import (
Cast,
ScalarOp,
)
from pytensor.scalar.math import Softplus
@pytorch_funcify.register(ScalarOp)
......@@ -19,9 +22,14 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
if nfunc_spec is None:
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
func_name = nfunc_spec[0]
func_name = nfunc_spec[0].replace("scipy.", "")
pytorch_func = getattr(torch, func_name)
if "." in func_name:
loc = func_name.split(".")
mod = importlib.import_module(".".join(["torch", *loc[:-1]]))
pytorch_func = getattr(mod, loc[-1])
else:
pytorch_func = getattr(torch, func_name)
if len(node.inputs) > op.nfunc_spec[1]:
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
......@@ -49,3 +57,8 @@ def pytorch_funcify_Cast(op: Cast, node, **kwargs):
return x.to(dtype=dtype)
return cast
@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, node, **kwargs):
    """Dispatch PyTensor's Softplus scalar op to torch's built-in Softplus.

    A default-configured ``torch.nn.Softplus`` module computes the same
    elementwise ``log(1 + exp(x))`` (with torch handling the numerically
    stable large-input path internally).
    """
    activation = torch.nn.Softplus()
    return activation
......@@ -17,7 +17,7 @@ from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.type import matrices, matrix, scalar, vector
......@@ -374,3 +374,17 @@ def test_pytorch_link_references():
f = function([x], out, mode="PYTORCH")
f(torch.ones(3))
assert "inner_fn" not in dir(m), "function call reference leaked"
def test_pytorch_scipy():
    """A scipy-namespaced op (expit) should round-trip through the torch backend."""
    inp = vector("a", shape=(3,))
    graph = FunctionGraph([inp], [expit(inp)])
    compare_pytorch_and_py(graph, [np.random.rand(3)])
def test_pytorch_softplus():
    """Softplus should match the Python-mode result under the torch backend."""
    inp = vector("a", shape=(3,))
    graph = FunctionGraph([inp], [softplus(inp)])
    compare_pytorch_and_py(graph, [np.random.rand(3)])
Markdown formatting supported
0%
You added 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to post a comment