提交 60246ad1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement basic Alloc Ops in PyTorch

上级 320bac49
......@@ -6,6 +6,7 @@ from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
@singledispatch
......@@ -58,3 +59,33 @@ def pytorch_funcify_DeepCopyOp(op, **kwargs):
return x.clone()
return deepcopyop
@pytorch_funcify.register(AllocEmpty)
def pytorch_funcify_AllocEmpty(op, **kwargs):
dtype = getattr(torch, op.dtype)
def alloc_empty(*shape):
return torch.empty(shape, dtype=dtype)
return alloc_empty
@pytorch_funcify.register(Alloc)
def pytorch_funcify_alloc(op, **kwargs):
def alloc(value, *shape):
out = torch.empty(shape, dtype=value.dtype)
out[...] = value # broadcast value to shape of out
return out
return alloc
@pytorch_funcify.register(ARange)
def pytorch_funcify_arange(op, **kwargs):
dtype = getattr(torch, op.dtype)
def arange(start, stop, step):
return torch.arange(start, stop, step, dtype=dtype)
return arange
......@@ -12,6 +12,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty
from pytensor.tensor.type import scalar, vector
......@@ -191,7 +192,7 @@ def test_shared_updates(device):
assert isinstance(a.get_value(), np.ndarray)
def test_pytorch_checkandraise():
def test_checkandraise():
check_and_raise = CheckAndRaise(AssertionError, "testing")
x = scalar("x")
......@@ -203,3 +204,34 @@ def test_pytorch_checkandraise():
with pytest.raises(AssertionError, match="testing"):
y_fn(0.0)
assert y_fn(4).item() == 4
def test_alloc_and_empty():
dim0 = as_tensor(5, dtype="int64")
dim1 = scalar("dim1", dtype="int64")
out = empty((dim0, dim1, 3), dtype="float32")
fn = function([dim1], out, mode=pytorch_mode)
res = fn(7)
assert res.shape == (5, 7, 3)
assert res.dtype == torch.float32
v = vector("v", shape=(3,), dtype="float64")
out = alloc(v, (dim0, dim1, 3))
compare_pytorch_and_py(
FunctionGraph([v, dim1], [out]),
[np.array([1, 2, 3]), np.array(7)],
)
def test_arange():
start = scalar("start", dtype="int64")
stop = scalar("stop", dtype="int64")
step = scalar("step", dtype="int64")
out = arange(start, stop, step, dtype="int16")
compare_pytorch_and_py(
FunctionGraph([start, stop, step], [out]),
[np.array(1), np.array(10), np.array(2)],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论