提交 993c2c64，作者：Ricardo Vieira；提交者：Ricardo Vieira

Numba dispatch of ScalarLoop

上级 23bbabf7
...@@ -15,6 +15,7 @@ from pytensor.link.numba.dispatch.cython_support import wrap_cython_function ...@@ -15,6 +15,7 @@ from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import ( from pytensor.link.utils import (
get_name_for_object, get_name_for_object,
) )
from pytensor.scalar import ScalarLoop
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
Add, Add,
Cast, Cast,
...@@ -364,3 +365,52 @@ def numba_funcify_Softplus(op, node, **kwargs): ...@@ -364,3 +365,52 @@ def numba_funcify_Softplus(op, node, **kwargs):
return numba_basic.direct_cast(value, out_dtype) return numba_basic.direct_cast(value, out_dtype)
return softplus, scalar_op_cache_key(op, cache_version=1) return softplus, scalar_op_cache_key(op, cache_version=1)
@register_funcify_and_cache_key(ScalarLoop)
def numba_funcify_ScalarLoop(op, node, **kwargs):
    """Numba dispatch of `ScalarLoop`.

    Compiles the op's inner scalar graph (``op.fgraph``) once and wraps it in
    an njit-ed driver: a while-style loop when ``op.is_while`` (the inner
    graph's last output is the termination condition) or a plain counted
    for-loop otherwise.

    Returns
    -------
    (callable, str | None)
        The njit-ed loop function and a cache key, or ``None`` for the key
        when the inner function itself is not cacheable.
    """
    # Compile the inner scalar graph once; its cache key (if any) feeds into
    # the key for the whole loop wrapper.
    inner_fn, inner_fn_cache_key = numba_funcify_and_cache_key(op.fgraph)
    if inner_fn_cache_key is None:
        # Inner function is not cacheable, so the loop wrapper cannot be either.
        loop_cache_key = None
    else:
        # Key on the op type, the loop flavor (while vs for), and the inner
        # function's key, so any change to the inner graph invalidates it.
        loop_cache_key = sha256(
            str((type(op), op.is_while, inner_fn_cache_key)).encode()
        ).hexdigest()

    if op.is_while:
        # The last inner-graph output is the `until` flag; the rest is the
        # carried state that is threaded through iterations.
        n_update = len(op.outputs) - 1

        @numba_basic.numba_njit
        def while_loop(n_steps, *inputs):
            # Leading inputs are the carried state; the rest stay constant.
            carry, constant = inputs[:n_update], inputs[n_update:]
            until = False
            for i in range(n_steps):
                outputs = inner_fn(*carry, *constant)
                carry, until = outputs[:-1], outputs[-1]
                if until:
                    break
            # The termination flag is always returned alongside the state,
            # even when n_steps was exhausted (or was 0) without converging.
            return *carry, until

        return while_loop, loop_cache_key
    else:
        # No termination condition: every inner-graph output is carried state.
        n_update = len(op.outputs)

        @numba_basic.numba_njit
        def for_loop(n_steps, *inputs):
            carry, constant = inputs[:n_update], inputs[n_update:]
            if n_steps < 0:
                # Negative step counts would mean "loop until convergence",
                # which a plain for-loop ScalarLoop does not support.
                raise ValueError("ScalarLoop does not have a termination condition.")
            for i in range(n_steps):
                carry = inner_fn(*carry, *constant)
            # Unwrap a single carried value so scalar callers get a scalar,
            # not a 1-tuple.
            if n_update == 1:
                return carry[0]
            else:
                return carry

        return for_loop, loop_cache_key
...@@ -609,18 +609,42 @@ def test_elemwise_multiple_inplace_outs(): ...@@ -609,18 +609,42 @@ def test_elemwise_multiple_inplace_outs():
def test_scalar_loop():
    """An Elemwise-wrapped ScalarLoop with a constant input compiles under
    numba and matches the Python reference implementation.

    (Reconstructed new-side code of a diff hunk whose old/new lines had been
    fused together by the page scrape.)
    """
    a_scalar = float64("a")
    const_scalar = float64("const")
    # Scalar recurrence: a <- a + a + const, with `const` held fixed across steps.
    scalar_loop = pytensor.scalar.ScalarLoop(
        init=[a_scalar],
        update=[a_scalar + a_scalar + const_scalar],
        constant=[const_scalar],
    )

    # Broadcast the scalar loop elementwise over length-3 vectors.
    a = pt.tensor("a", shape=(3,))
    const = pt.tensor("const", shape=(3,))
    n_steps = 3
    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)

    compare_numba_and_py(
        [a, const],
        [elemwise_loop],
        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
    )
def test_gammainc_wrt_k_grad():
    """Gradient of ``gammainc`` w.r.t. its shape parameter ``k`` matches the
    Python reference under the numba backend."""
    x_var = pt.vector("x", dtype="float64")
    k_var = pt.vector("k", dtype="float64")
    dk = grad(pt.gammainc(k_var, x_var).sum(), k_var)

    # These values of x and k trigger all the branches in the gradient of gammainc
    x_vals = np.array([0.0, 29.0, 31.0], dtype="float64")
    k_vals = np.array([1.0, 13.0, 11.0], dtype="float64")

    compare_numba_and_py([x_var, k_var], [dk], [x_vals, k_vals])
class TestsBenchmark: class TestsBenchmark:
......
...@@ -8,7 +8,7 @@ import pytensor.scalar.math as psm ...@@ -8,7 +8,7 @@ import pytensor.scalar.math as psm
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import config, function
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.scalar import UnaryScalarOp from pytensor.scalar import ScalarLoop, UnaryScalarOp
from pytensor.scalar.basic import Composite from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -231,3 +231,74 @@ def test_erf_complex(): ...@@ -231,3 +231,74 @@ def test_erf_complex():
[g], [g],
[np.array(0.5 + 1j, dtype="complex128")], [np.array(0.5 + 1j, dtype="complex128")],
) )
class TestScalarLoop:
    """Numba-backend tests for the ``ScalarLoop`` scalar Op."""

    def test_scalar_for_loop_single_out(self):
        """Counted loop with one carried value: x <- x + const."""
        steps = ps.int64("n_steps")
        x_init = ps.float64("x0")
        c = ps.float64("const")
        op = ScalarLoop(init=[x_init], constant=[c], update=[x_init + c])
        result = op(steps, x_init, c)
        fn = function([steps, x_init, c], [result], mode=numba_mode)

        # Closed form: x0 + n_steps * const.
        for kwargs, expected in [
            (dict(n_steps=5, x0=0, const=1), 5),
            (dict(n_steps=5, x0=0, const=2), 10),
            (dict(n_steps=4, x0=3, const=-1), -1),
        ]:
            np.testing.assert_allclose(fn(**kwargs), expected)

    def test_scalar_for_loop_multiple_outs(self):
        """Counted loop carrying two values of different dtypes."""
        steps = ps.int64("n_steps")
        x_init = ps.float64("x0")
        y_init = ps.int64("y0")
        c = ps.float64("const")
        op = ScalarLoop(
            init=[x_init, y_init],
            constant=[c],
            update=[x_init + c, y_init + 1],
        )
        x_out, y_out = op(steps, x_init, y_init, c)
        fn = function([steps, x_init, y_init, c], [x_out, y_out], mode=numba_mode)

        # x follows x0 + n_steps * const; y simply counts iterations from y0.
        cases = [
            (dict(n_steps=5, x0=0, y0=0, const=1), (5, 5)),
            (dict(n_steps=5, x0=0, y0=0, const=2), (10, 5)),
            (dict(n_steps=4, x0=3, y0=2, const=-1), (-1, 6)),
        ]
        for kwargs, (expected_x, expected_y) in cases:
            got_x, got_y = fn(**kwargs)
            np.testing.assert_allclose(got_x, expected_x)
            np.testing.assert_allclose(got_y, expected_y)

    def test_scalar_while_loop(self):
        """While-style loop: increment x until it reaches 10 (or steps run out)."""
        steps = ps.int64("n_steps")
        x_init = ps.float64("x0")
        nxt = x_init + 1
        op = ScalarLoop(init=[x_init], update=[nxt], until=nxt >= 10)
        fn = function([steps, x_init], op(steps, x_init), mode=numba_mode)

        # Each call yields (final_x, terminated_early_flag).
        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])

    def test_loop_with_cython_wrapped_op(self):
        """A single-step loop whose body is a Cython-wrapped scalar op (psi)."""
        x = ps.float64("x")
        op = ScalarLoop(init=[x], update=[ps.psi(x)])
        compiled = function([x], op(1, x), mode=numba_mode)

        x_val = np.float64(0.5)
        np.testing.assert_allclose(compiled(x_val), ps.psi(x).eval({x: x_val}))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论