Unverified commit 9858b330, authored by Ian Schweer, committed by GitHub

Implement ScalarLoop in torch backend (#958)

* Add for loop based scalar loop
* Pass all loop tests
* Fetch constants from op
* Add while loop test
* Fix while loop and nasty stack over dtypes
* Disable compile here based on CI result
* Fix mypy signature
* Remove unnecessary torch stack
* Only call .cpu when necessary
* Recursive false for torch compiler
* Add elemwise test
* Late import torch
* Do iteration instead of vmap for elemwise
* Clean up and add description
* Add unit test to verify iteration
* Refactor to ravel method
* Fix unpacking

  Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Fix comment
* Remove extra return
* Update test
* Add single carry test
* Remove compiler disable
* Better name
* Lint
* Better docstring
* PR comments

---------

Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Parent 07bd48db
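For context, a minimal sketch of what this change enables (it mirrors test_ScalarLoop below; the "PYTORCH" mode string is assumed here, while the tests use their own pytorch_mode helper):

import numpy as np

from pytensor.compile.function import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop

n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")

# One loop step: x <- x + const, repeated n_steps times
op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])
x = op(n_steps, x0, const)

# Compiling for torch now lowers ScalarLoop to a plain Python for loop
fn = function([n_steps, x0, const], x, mode="PYTORCH")
np.testing.assert_allclose(fn(5, 0, 1), 5)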
@@ -3,6 +3,7 @@ import importlib
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,6 +12,7 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    scalar_op = op.scalar_op
    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

    def check_special_scipy(func_name):
@@ -33,6 +35,9 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
            Elemwise._check_runtime_broadcast(node, inputs)
            return base_fn(*inputs)

    elif isinstance(scalar_op, ScalarLoop):
        return elemwise_ravel_fn(base_fn, op, node, **kwargs)

    else:

        def elemwise_fn(*inputs):
@@ -176,3 +181,37 @@ def jax_funcify_SoftmaxGrad(op, **kwargs):
        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

    return softmax_grad

def elemwise_ravel_fn(base_fn, op, node, **kwargs):
    """
    Dispatch an Elemwise whose scalar op relies on `.item()` (e.g. ScalarLoop),
    which torch's vmap does not support
    (see https://github.com/pymc-devs/pytensor/issues/1031).
    Instead, broadcast the inputs with torch, ravel them, and apply the scalar
    function one element at a time.
    """
    n_outputs = len(node.outputs)

    def elemwise_fn(*inputs):
        bcasted_inputs = torch.broadcast_tensors(*inputs)
        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

        out_shape = bcasted_inputs[0].size()
        out_size = out_shape.numel()
        raveled_outputs = [torch.empty(out_size) for out in node.outputs]

        for i in range(out_size):
            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
            if n_outputs == 1:
                raveled_outputs[0][i] = core_outs
            else:
                for o in range(n_outputs):
                    raveled_outputs[o][i] = core_outs[o]

        # Restore the broadcast shape
        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
        if n_outputs == 1:
            return outputs[0]
        else:
            return outputs

    return elemwise_fn
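For intuition, a standalone sketch of the broadcast-ravel-iterate pattern used above, in plain torch; scalar_add is a hypothetical stand-in for a base_fn that only handles 0-d inputs (e.g. because it calls .item() internally):

import torch

def scalar_add(a, b):
    # Hypothetical scalar-only function: valid on 0-d tensors only
    return torch.tensor(a.item() + b.item())

x = torch.arange(3.0).reshape(3, 1)  # shape (3, 1)
y = torch.arange(2.0)                # shape (2,)

bx, by = torch.broadcast_tensors(x, y)  # both (3, 2)
rx, ry = bx.ravel(), by.ravel()         # both flat, 6 elements
out = torch.empty(rx.numel())
for i in range(rx.numel()):
    out[i] = scalar_add(rx[i], ry[i])   # one scalar call per element
print(out.view(bx.size()))              # back to the broadcast shape (3, 2)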
@@ -7,6 +7,7 @@ from pytensor.scalar.basic import (
    Cast,
    ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus
@@ -62,3 +63,37 @@ def pytorch_funcify_Cast(op: Cast, node, **kwargs):
@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, node, **kwargs):
    return torch.nn.Softplus()

@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    update = pytorch_funcify(op.fgraph, **kwargs)
    # Arguments after `steps` are the initial carry values followed by the constants
    state_length = op.nout
    if op.is_while:

        def scalar_loop(steps, *start_and_constants):
            carry, constants = (
                start_and_constants[:state_length],
                start_and_constants[state_length:],
            )
            done = True
            for _ in range(steps):
                *carry, done = update(*carry, *constants)
                # Stop early once the until-condition is met
                if torch.any(done):
                    break

            return *carry, done

    else:

        def scalar_loop(steps, *start_and_constants):
            carry, constants = (
                start_and_constants[:state_length],
                start_and_constants[state_length:],
            )
            for _ in range(steps):
                carry = update(*carry, *constants)

            if len(node.outputs) == 1:
                return carry[0]
            else:
                return carry

    return scalar_loop
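For intuition, a toy version of the while-variant closure above, with a hypothetical update that increments the carry and signals done at 10 (this mirrors test_ScalarLoop_while below):

import torch

def update(x):
    # Hypothetical compiled fgraph: returns (new_carry, done)
    x = x + 1
    return x, x >= 10

def scalar_loop(steps, x0):
    carry, done = (x0,), torch.tensor(True)
    for _ in range(steps):
        *carry, done = update(*carry)
        if torch.any(done):
            break
    return *carry, done

print(scalar_loop(20, torch.tensor(0.0)))  # (tensor(10.), tensor(True))
print(scalar_loop(5, torch.tensor(1.0)))   # (tensor(6.), tensor(False))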
@@ -54,21 +54,22 @@ class PytorchLinker(JITLinker):
        self.fn = torch.compile(fn)
        self.gen_functors = gen_functors.copy()

-   def __call__(self, *args, **kwargs):
+   def __call__(self, *inputs, **kwargs):
        import pytensor.link.utils

        # set attrs
        for n, fn in self.gen_functors:
            setattr(pytensor.link.utils, n[1:], fn)

-       res = self.fn(*args, **kwargs)
+       # Torch does not accept numpy inputs and may return GPU objects
+       outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)

        # unset attrs
        for n, _ in self.gen_functors:
            if getattr(pytensor.link.utils, n[1:], False):
                delattr(pytensor.link.utils, n[1:])

-       return res
+       return tuple(out.cpu().numpy() for out in outs)

    def __del__(self):
        del self.gen_functors
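The conversion that moved into __call__ can be sketched in isolation; here pytorch_typify is simplified to torch.as_tensor (the real dispatch handles more input types):

import numpy as np
import torch

def compiled_fn(x, y):
    # Stand-in for the torch-compiled PyTensor function
    return (x + y, x * y)

def wrapped(*inputs):
    # numpy -> torch going in; torch (possibly on GPU) -> numpy coming out
    outs = compiled_fn(*(torch.as_tensor(inp) for inp in inputs))
    return tuple(out.cpu().numpy() for out in outs)

print(wrapped(np.arange(3.0), np.ones(3)))  # (array([1., 2., 3.]), array([0., 1., 2.]))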
@@ -76,12 +77,7 @@ class PytorchLinker(JITLinker):
        inner_fn = wrapper(fn, self.gen_functors)
        self.gen_functors = []

-       # Torch does not accept numpy inputs and may return GPU objects
-       def fn(*inputs, inner_fn=inner_fn):
-           outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
-           return tuple(out.cpu().numpy() for out in outs)
-
-       return fn
+       return inner_fn

    def create_thunk_inputs(self, storage_map):
        thunk_inputs = []
@@ -4,6 +4,7 @@ from functools import partial
import numpy as np
import pytest
import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
@@ -17,7 +18,10 @@ from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector
@@ -385,3 +389,85 @@ def test_pytorch_softplus():
    out = softplus(x)
    f = FunctionGraph([x], [out])
    compare_pytorch_and_py(f, [np.random.rand(3)])

def test_ScalarLoop():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    x = op(n_steps, x0, const)

    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)

def test_ScalarLoop_while():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x = x0 + 1
    until = x >= 10

    op = ScalarLoop(init=[x0], update=[x], until=until)
    fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
    for res, expected in zip(
        [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
        [[10, True], [10, True], [6, False]],
        strict=True,
    ):
        np.testing.assert_allclose(res[0], np.array(expected[0]))
        np.testing.assert_allclose(res[1], np.array(expected[1]))

def test_ScalarLoop_Elemwise_single_carries():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x = x0 * 2
    until = x >= 10

    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
    op = Elemwise(scalarop)

    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    state, done = op(n_steps, x0)

    f = FunctionGraph([n_steps, x0], [state, done])
    args = [
        np.array(10).astype("int32"),
        np.arange(0, 5).astype("float32"),
    ]
    compare_pytorch_and_py(
        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
    )

def test_ScalarLoop_Elemwise_multi_carries():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    x1 = float64("x1")
    x = x0 * 2
    x1_n = x1 * 3
    until = x >= 10

    scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
    op = Elemwise(scalarop)

    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
    *states, done = op(n_steps, x0, x1)

    f = FunctionGraph([n_steps, x0, x1], [*states, done])
    args = [
        np.array(10).astype("int32"),
        np.arange(0, 5).astype("float32"),
        np.random.rand(7, 3, 1).astype("float32"),
    ]
    compare_pytorch_and_py(
        f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
    )