Commit be358ed2, authored by ricardoV94, committed by Thomas Wiecki

Implement Cast in PyTorch backend

Parent: be6a0322
@@ -2,6 +2,7 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
    Cast,
    ScalarOp,
)
@@ -38,3 +39,13 @@ def pytorch_funcify_ScalarOp(op, node, **kwargs):
    )
    return pytorch_func
@pytorch_funcify.register(Cast)
def pytorch_funcify_Cast(op: Cast, node, **kwargs):
    """Build the PyTorch implementation of a scalar ``Cast`` op.

    Resolves the target ``torch`` dtype once, from the op's output type,
    and returns a closure that converts its input to that dtype.
    """
    target_dtype = getattr(torch, op.o_type.dtype)

    def cast(x):
        # torch.Tensor.to performs the dtype conversion.
        return x.to(dtype=target_dtype)

    return cast
@@ -10,6 +10,9 @@ from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py

torch = pytest.importorskip("torch")

def test_pytorch_Dimshuffle():
    a_pt = matrix("a")
@@ -137,3 +140,13 @@ def test_softmax_grad(axis):
    out = SoftmaxGrad(axis=axis)(dy, sm)
    fgraph = FunctionGraph([dy, sm], [out])
    compare_pytorch_and_py(fgraph, [dy_value, sm_value])
def test_cast():
    """Cast a float32 matrix to int32 and compare PyTorch vs Python backends."""
    x = matrix("x", dtype="float32")
    out = pt.cast(x, "int32")
    fgraph = FunctionGraph([x], [out])
    test_value = np.arange(6, dtype="float32").reshape(2, 3)
    _, [res] = compare_pytorch_and_py(fgraph, [test_value])
    # The compiled PyTorch result must carry the requested integer dtype.
    assert res.dtype == torch.int32
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论