Unverified 提交 9858b330 authored 作者: Ian Schweer's avatar Ian Schweer 提交者: GitHub

Implement ScalarLoop in torch backend (#958)

* Add for loop based scalar loop
* Pass all loop tests
* Fetch constants from op
* Add while loop test
* Fix while loop and nasty stack over dtypes
* Disable compile here based on CI result
* Fix mypy signature
* Remove unnecessary torch stack
* Only call .cpu when necessary
* Recursive false for torch compiler
* Add elemwise test
* Late import torch
* Do iteration instead of vmap for elemwise
* Clean up and add description
* Add unit test to verify iteration
* Refactor to ravel method
* Fix unpacking

  Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
* Fix comment
* Remove extra return
* Update test
* Add single carry test
* Remove compiler disable
* Better name
* Lint
* Better docstring
* PR comments

---------

Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
上级 07bd48db
...@@ -3,6 +3,7 @@ import importlib ...@@ -3,6 +3,7 @@ import importlib
import torch import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
...@@ -11,6 +12,7 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad ...@@ -11,6 +12,7 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@pytorch_funcify.register(Elemwise) @pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs): def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
def check_special_scipy(func_name): def check_special_scipy(func_name):
...@@ -33,6 +35,9 @@ def pytorch_funcify_Elemwise(op, node, **kwargs): ...@@ -33,6 +35,9 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
Elemwise._check_runtime_broadcast(node, inputs) Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs) return base_fn(*inputs)
elif isinstance(scalar_op, ScalarLoop):
return elemwise_ravel_fn(base_fn, op, node, **kwargs)
else: else:
def elemwise_fn(*inputs): def elemwise_fn(*inputs):
...@@ -176,3 +181,37 @@ def jax_funcify_SoftmaxGrad(op, **kwargs): ...@@ -176,3 +181,37 @@ def jax_funcify_SoftmaxGrad(op, **kwargs):
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
return softmax_grad return softmax_grad
def elemwise_ravel_fn(base_fn, op, node, **kwargs):
    """Build an Elemwise implementation that iterates over raveled inputs.

    Dispatch methods using `.item()` (ScalarLoop + Elemwise) are common, but
    vmap in torch has a limitation:
    https://github.com/pymc-devs/pytensor/issues/1031.
    Instead, the inputs are broadcast against each other with
    ``torch.broadcast_tensors``, raveled, and ``base_fn`` is called once per
    element.

    Parameters
    ----------
    base_fn
        Callable invoked with one element of every raveled input. It must
        return a single value when the node has one output, or an indexable
        sequence with one value per output otherwise.
    op
        Not used by this function.
    node
        Only ``len(node.outputs)`` is read, to know how many outputs
        ``base_fn`` produces.
    **kwargs
        Accepted for signature compatibility and ignored.

    Returns
    -------
    callable
        ``elemwise_fn(*inputs)`` returning one tensor (single output) or a
        tuple of tensors (several outputs), each with the broadcast shape of
        the inputs.
    """
    n_outputs = len(node.outputs)

    def elemwise_fn(*inputs):
        bcasted_inputs = torch.broadcast_tensors(*inputs)
        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

        out_shape = bcasted_inputs[0].size()
        out_size = out_shape.numel()

        # Output buffers are created lazily from the first computed element so
        # that they inherit its dtype and device. Allocating them up front with
        # a bare ``torch.empty(out_size)`` would force the default dtype and
        # silently cast float64 / integer / boolean results.
        raveled_outputs = None

        for i in range(out_size):
            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
            if n_outputs == 1:
                # Normalize to a sequence so both cases share the code below.
                core_outs = (core_outs,)

            if raveled_outputs is None:
                raveled_outputs = []
                for o in range(n_outputs):
                    first = torch.as_tensor(core_outs[o])
                    raveled_outputs.append(
                        torch.empty(
                            out_size, dtype=first.dtype, device=first.device
                        )
                    )

            for o in range(n_outputs):
                raveled_outputs[o][i] = core_outs[o]

        if raveled_outputs is None:
            # Zero-size input: base_fn was never called, so there is nothing
            # to infer a dtype from. Keep the previous default allocation.
            raveled_outputs = [torch.empty(out_size) for _ in range(n_outputs)]

        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
        if n_outputs == 1:
            return outputs[0]
        return outputs

    return elemwise_fn
...@@ -7,6 +7,7 @@ from pytensor.scalar.basic import ( ...@@ -7,6 +7,7 @@ from pytensor.scalar.basic import (
Cast, Cast,
ScalarOp, ScalarOp,
) )
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus from pytensor.scalar.math import Softplus
...@@ -62,3 +63,37 @@ def pytorch_funcify_Cast(op: Cast, node, **kwargs): ...@@ -62,3 +63,37 @@ def pytorch_funcify_Cast(op: Cast, node, **kwargs):
@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, node, **kwargs):
    """Return a freshly constructed ``torch.nn.Softplus`` as the implementation.

    ``op``, ``node`` and ``kwargs`` are accepted for the dispatch signature
    and are not consulted.
    """
    softplus_module = torch.nn.Softplus()
    return softplus_module
@pytorch_funcify.register(ScalarLoop)
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
    """Build the torch callable implementing a ``ScalarLoop`` op.

    The inner graph ``op.fgraph`` is converted with ``pytorch_funcify`` and
    called once per iteration. The returned ``scalar_loop(steps, *args)``
    treats the first ``op.nout`` entries of ``args`` as the carried state and
    passes the remaining entries, unchanged, to every iteration.

    * When ``op.is_while`` is true, the last value returned by each iteration
      is a stop flag. Iteration ends early once ``torch.any`` of that flag is
      true, and the result is ``(*state, done)``. ``done`` stays the Python
      value ``True`` if no iteration runs.
    * Otherwise the loop always runs ``steps`` times and returns the first
      state entry when ``node`` has a single output, or the whole state as
      produced by the last iteration.
    """
    step_fn = pytorch_funcify(op.fgraph, **kwargs)
    n_state = op.nout

    def _split(args):
        # Leading entries are the loop state, trailing ones the constants.
        return args[:n_state], args[n_state:]

    if op.is_while:

        def scalar_loop(steps, *start_and_constants):
            state, consts = _split(start_and_constants)
            done = True
            for _ in range(steps):
                *state, done = step_fn(*state, *consts)
                if torch.any(done):
                    break
            return (*state, done)

        return scalar_loop

    def scalar_loop(steps, *start_and_constants):
        state, consts = _split(start_and_constants)
        for _ in range(steps):
            state = step_fn(*state, *consts)
        return state[0] if len(node.outputs) == 1 else state

    return scalar_loop
...@@ -54,21 +54,22 @@ class PytorchLinker(JITLinker): ...@@ -54,21 +54,22 @@ class PytorchLinker(JITLinker):
self.fn = torch.compile(fn) self.fn = torch.compile(fn)
self.gen_functors = gen_functors.copy() self.gen_functors = gen_functors.copy()
def __call__(self, *inputs, **kwargs):
    """Run the wrapped function with the generated functors made visible.

    Each ``(name, fn)`` pair in ``self.gen_functors`` is temporarily set as
    an attribute of ``pytensor.link.utils`` under ``name[1:]`` (the name
    without its first character) while ``self.fn`` runs, and removed again
    afterwards. Removal happens in a ``finally`` clause so the attributes do
    not stay on the module when ``self.fn`` raises.

    Parameters
    ----------
    *inputs
        Positional inputs; each one is passed through ``pytorch_typify``
        before the call.
    **kwargs
        Forwarded to ``self.fn`` unchanged.

    Returns
    -------
    tuple
        ``out.cpu().numpy()`` for every output of ``self.fn``.
    """
    import pytensor.link.utils

    # set attrs
    for n, fn in self.gen_functors:
        setattr(pytensor.link.utils, n[1:], fn)

    try:
        # Torch does not accept numpy inputs and may return GPU objects
        outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)
    finally:
        # unset attrs, also when the call above fails
        for n, _ in self.gen_functors:
            if getattr(pytensor.link.utils, n[1:], False):
                delattr(pytensor.link.utils, n[1:])

    return tuple(out.cpu().numpy() for out in outs)
def __del__(self):
    """Remove the ``gen_functors`` attribute from this instance."""
    delattr(self, "gen_functors")
...@@ -76,12 +77,7 @@ class PytorchLinker(JITLinker): ...@@ -76,12 +77,7 @@ class PytorchLinker(JITLinker):
inner_fn = wrapper(fn, self.gen_functors) inner_fn = wrapper(fn, self.gen_functors)
self.gen_functors = [] self.gen_functors = []
# Torch does not accept numpy inputs and may return GPU objects return inner_fn
def fn(*inputs, inner_fn=inner_fn):
outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
return tuple(out.cpu().numpy() for out in outs)
return fn
def create_thunk_inputs(self, storage_map): def create_thunk_inputs(self, storage_map):
thunk_inputs = [] thunk_inputs = []
......
...@@ -4,6 +4,7 @@ from functools import partial ...@@ -4,6 +4,7 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
import pytensor.tensor as pt
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
...@@ -17,7 +18,10 @@ from pytensor.graph.op import Op ...@@ -17,7 +18,10 @@ from pytensor.graph.op import Op
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector from pytensor.tensor.type import matrices, matrix, scalar, vector
...@@ -385,3 +389,85 @@ def test_pytorch_softplus(): ...@@ -385,3 +389,85 @@ def test_pytorch_softplus():
out = softplus(x) out = softplus(x)
f = FunctionGraph([x], [out]) f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)]) compare_pytorch_and_py(f, [np.random.rand(3)])
def test_ScalarLoop():
    """A fixed-length ScalarLoop adding a constant each step, compiled with torch."""
    steps = int64("n_steps")
    start = float64("x0")
    increment = float64("const")
    next_value = start + increment

    loop = ScalarLoop(init=[start], constant=[increment], update=[next_value])
    result = loop(steps, start, increment)

    fn = function([steps, start, increment], result, mode=pytorch_mode)

    # (n_steps, x0, const) -> expected final value
    cases = [
        ((5, 0, 1), 5),
        ((5, 0, 2), 10),
        ((4, 3, -1), -1),
    ]
    for call_args, expected in cases:
        np.testing.assert_allclose(fn(*call_args), expected)
def test_ScalarLoop_while():
    """A ScalarLoop with an ``until`` condition returns the state and the stop flag."""
    steps = int64("n_steps")
    start = float64("x0")
    next_value = start + 1
    stop = next_value >= 10

    loop = ScalarLoop(init=[start], update=[next_value], until=stop)
    fn = function([steps, start], loop(steps, start), mode=pytorch_mode)

    call_kwargs = [
        {"n_steps": 20, "x0": 0},
        {"n_steps": 20, "x0": 1},
        {"n_steps": 5, "x0": 1},
    ]
    expected_results = [(10, True), (10, True), (6, False)]

    # All calls are evaluated before any assertion runs.
    results = [fn(**kw) for kw in call_kwargs]
    for res, (exp_state, exp_done) in zip(results, expected_results, strict=True):
        np.testing.assert_allclose(res[0], np.array(exp_state))
        np.testing.assert_allclose(res[1], np.array(exp_done))
def test_ScalarLoop_Elemwise_single_carries():
    """Elemwise over a while-ScalarLoop with one carried value matches the Python backend."""
    # Scalar graph defining the loop body. The scalar ``n_steps`` variable is
    # created but not referenced by the loop definition below.
    _scalar_steps = int64("n_steps")
    scalar_x0 = float64("x0")
    doubled = scalar_x0 * 2
    stop = doubled >= 10

    loop = ScalarLoop(init=[scalar_x0], update=[doubled], until=stop)
    elemwise_loop = Elemwise(loop)

    # Tensor inputs the Elemwise op is applied to.
    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    carried, finished = elemwise_loop(n_steps, x0)

    fgraph = FunctionGraph([n_steps, x0], [carried, finished])
    test_values = [
        np.array(10, dtype="int32"),
        np.arange(0, 5, dtype="float32"),
    ]
    close_enough = partial(np.testing.assert_allclose, rtol=1e-6)
    compare_pytorch_and_py(fgraph, test_values, assert_fn=close_enough)
def test_ScalarLoop_Elemwise_multi_carries():
    """Elemwise over a while-ScalarLoop with two carried values of different shapes."""
    # Scalar graph defining the loop body. The scalar ``n_steps`` variable is
    # created but not referenced by the loop definition below.
    _scalar_steps = int64("n_steps")
    scalar_x0 = float64("x0")
    scalar_x1 = float64("x1")
    doubled = scalar_x0 * 2
    tripled = scalar_x1 * 3
    stop = doubled >= 10

    loop = ScalarLoop(
        init=[scalar_x0, scalar_x1], update=[doubled, tripled], until=stop
    )
    elemwise_loop = Elemwise(loop)

    # Tensor inputs: a vector and a (7, 3, 1) tensor.
    n_steps = pt.scalar("n_steps", dtype="int32")
    x0 = pt.vector("x0", dtype="float32")
    x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
    *carried, finished = elemwise_loop(n_steps, x0, x1)

    fgraph = FunctionGraph([n_steps, x0, x1], [*carried, finished])
    test_values = [
        np.array(10, dtype="int32"),
        np.arange(0, 5, dtype="float32"),
        np.random.rand(7, 3, 1).astype("float32"),
    ]
    close_enough = partial(np.testing.assert_allclose, rtol=1e-6)
    compare_pytorch_and_py(fgraph, test_values, assert_fn=close_enough)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论