提交 fc0d9ec2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fuse hyp2f1 grads

上级 86fe383d
...@@ -5,7 +5,9 @@ As SciPy is not always available, we treat them separately. ...@@ -5,7 +5,9 @@ As SciPy is not always available, we treat them separately.
""" """
import os import os
from functools import reduce
from textwrap import dedent from textwrap import dedent
from typing import Tuple
import numpy as np import numpy as np
import scipy.special import scipy.special
...@@ -683,14 +685,20 @@ class GammaIncC(BinaryScalarOp): ...@@ -683,14 +685,20 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
init = [as_scalar(x) for x in init] init = [as_scalar(x) if x is not None else None for x in init]
constant = [as_scalar(x) for x in constant] constant = [as_scalar(x) for x in constant]
# Create dummy types, in case some variables have the same initial form # Create dummy types, in case some variables have the same initial form
init_ = [x.type() for x in init] init_ = [x.type() if x is not None else None for x in init]
constant_ = [x.type() for x in constant] constant_ = [x.type() for x in constant]
update_, until_ = inner_loop_fn(*init_, *constant_) update_, until_ = inner_loop_fn(*init_, *constant_)
op = ScalarLoop(
# Filter Nones
init = [i for i in init if i is not None]
init_ = [i for i in init_ if i is not None]
update_ = [u for u in update_ if u is not None]
op = loop_op(
init=init_, init=init_,
constant=constant_, constant=constant_,
update=update_, update=update_,
...@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): ...@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
until_condition_failed="warn", until_condition_failed="warn",
name=name, name=name,
) )
S, *_ = op(n_steps, *init, *constant) return op(n_steps, *init, *constant)
return S
def gammainc_grad(k, x): def gammainc_grad(k, x):
...@@ -740,7 +747,7 @@ def gammainc_grad(k, x): ...@@ -740,7 +747,7 @@ def gammainc_grad(k, x):
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
constant = [log_x] constant = [log_x]
sum_a = _make_scalar_loop( sum_a, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
) )
...@@ -827,7 +834,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -827,7 +834,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
constant = [x] constant = [x]
sum_a = _make_scalar_loop( sum_a, *_ = _make_scalar_loop(
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
) )
grad_approx_a = ( grad_approx_a = (
...@@ -870,7 +877,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -870,7 +877,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_b0, log_s, s_sign, log_delta, n] init = [sum_b0, log_s, s_sign, log_delta, n]
constant = [k, log_x] constant = [k, log_x]
sum_b = _make_scalar_loop( sum_b, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
) )
grad_approx_b = ( grad_approx_b = (
...@@ -1540,7 +1547,7 @@ def betainc_grad(p, q, x, wrtp: bool): ...@@ -1540,7 +1547,7 @@ def betainc_grad(p, q, x, wrtp: bool):
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n] init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK] constant = [f, p, q, K, dK]
grad = _make_scalar_loop( grad, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad" max_iters, init, constant, inner_loop, name="betainc_grad"
) )
return grad return grad
...@@ -1579,10 +1586,11 @@ class Hyp2F1(ScalarOp): ...@@ -1579,10 +1586,11 @@ class Hyp2F1(ScalarOp):
def grad(self, inputs, grads): def grad(self, inputs, grads):
a, b, c, z = inputs a, b, c, z = inputs
(gz,) = grads (gz,) = grads
grad_a, grad_b, grad_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2])
return [ return [
gz * hyp2f1_grad(a, b, c, z, wrt=0), gz * grad_a,
gz * hyp2f1_grad(a, b, c, z, wrt=1), gz * grad_b,
gz * hyp2f1_grad(a, b, c, z, wrt=2), gz * grad_c,
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z), gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
] ]
...@@ -1598,92 +1606,55 @@ def _unsafe_sign(x): ...@@ -1598,92 +1606,55 @@ def _unsafe_sign(x):
return switch(x > 0, 1, -1) return switch(x > 0, 1, -1)
def hyp2f1_grad(a, b, c, z, wrt: int): class Grad2F1Loop(ScalarLoop):
dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32") """Subclass of ScalarLoop for easier targeting in rewrites"""
def check_2f1_converges(a, b, c, z):
def is_nonpositive_integer(x):
if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
num_terms = switch(
a_is_polynomial,
floor(scalar_abs(a)).astype("int64"),
0,
)
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms) def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
num_terms = switch( """
b_is_polynomial, Notes
floor(scalar_abs(b)).astype("int64"), -----
num_terms, The algorithm can be derived by looking at the ratio of two successive terms in the series
) β_{k+1}/β_{k} = A(k)/B(k)
β_{k+1} = A(k)/B(k) * β_{k}
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms) In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z
is_polynomial = a_is_polynomial | b_is_polynomial
return (~is_undefined) & ( The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b))) by dropping the respective term
) d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
def compute_grad_2f1(a, b, c, z, wrt, skip_loop): The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
""" tracking their signs.
Notes """
-----
The algorithm can be derived by looking at the ratio of two successive terms in the series
β_{k+1}/β_{k} = A(k)/B(k)
β_{k+1} = A(k)/B(k) * β_{k}
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
by dropping the respective term
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
tracking their signs.
"""
wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = np.array(
10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
grad = np.array(0, dtype=dtype) min_steps = np.array(
10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857
max_steps = switch(
skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32")
)
precision = np.array(1e-14, dtype=config.floatX)
log_g = np.array(-np.inf, dtype=dtype) grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)]
log_g_sign = np.array(1, dtype="int8") log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)]
log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)]
log_t = np.array(0.0, dtype=dtype) log_t = np.array(0.0, dtype=dtype)
log_t_sign = np.array(1, dtype="int8") log_t_sign = np.array(1, dtype="int8")
log_z = log(scalar_abs(z)) log_z = log(scalar_abs(z))
sign_z = _unsafe_sign(z) sign_z = _unsafe_sign(z)
sign_zk = sign_z sign_zk = sign_z
k = np.array(0, dtype="int32") k = np.array(0, dtype="int32")
def inner_loop( def inner_loop(*args):
grad, (
log_g, *grads_vars,
log_g_sign,
log_t, log_t,
log_t_sign, log_t_sign,
sign_zk, sign_zk,
...@@ -1693,65 +1664,147 @@ def hyp2f1_grad(a, b, c, z, wrt: int): ...@@ -1693,65 +1664,147 @@ def hyp2f1_grad(a, b, c, z, wrt: int):
c, c,
log_z, log_z,
sign_z, sign_z,
): ) = args
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p.type.dtype != dtype: (
p = p.astype(dtype) grad_a,
grad_b,
term = log_g_sign * log_t_sign * exp(log_g - log_t) grad_c,
if wrt_a: log_g_a,
term += reciprocal(a + k) log_g_b,
elif wrt_b: log_g_c,
term += reciprocal(b + k) log_g_sign_a,
else: log_g_sign_b,
term -= reciprocal(c + k) log_g_sign_c,
) = grads_vars
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p.type.dtype != dtype:
p = p.astype(dtype)
# If p==0, don't update grad and get out of while loop next
p_zero = eq(p, 0)
if 0 in wrt:
term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t)
term_a += reciprocal(a + k)
if term_a.type.dtype != dtype:
term_a = term_a.astype(dtype)
if 1 in wrt:
term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t)
term_b += reciprocal(b + k)
if term_b.type.dtype != dtype:
term_b = term_b.astype(dtype)
if 2 in wrt:
term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t)
term_c -= reciprocal(c + k)
if term_c.type.dtype != dtype:
term_c = term_c.astype(dtype)
log_t = log_t + log(scalar_abs(p)) + log_z
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
grads = [None] * 3
log_gs = [None] * 3
log_gs_signs = [None] * 3
grad_incs = [None] * 3
if 0 in wrt:
log_g_a = log_t + log(scalar_abs(term_a))
log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8")
grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk
grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a)
log_gs[0] = log_g_a
log_gs_signs[0] = log_g_sign_a
grad_incs[0] = grad_inc_a
if 1 in wrt:
log_g_b = log_t + log(scalar_abs(term_b))
log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8")
grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk
grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b)
log_gs[1] = log_g_b
log_gs_signs[1] = log_g_sign_b
grad_incs[1] = grad_inc_b
if 2 in wrt:
log_g_c = log_t + log(scalar_abs(term_c))
log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8")
grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk
grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c)
log_gs[2] = log_g_c
log_gs_signs[2] = log_g_sign_c
grad_incs[2] = grad_inc_c
sign_zk *= sign_z
k += 1
abs_grad_incs = [
scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None
]
if len(grad_incs) == 1:
[max_abs_grad_inc] = grad_incs
else:
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
if term.type.dtype != dtype: return (
term = term.astype(dtype) (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
(eq(p, 0) | ((k > min_steps) & (max_abs_grad_inc <= precision))),
)
log_t = log_t + log(scalar_abs(p)) + log_z init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8") constant = [a, b, c, log_z, sign_z]
log_g = log_t + log(scalar_abs(term)) loop_outs = _make_scalar_loop(
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8") max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
)
return loop_outs[: len(wrt)]
g_current = log_g_sign * exp(log_g) * sign_zk
# If p==0, don't update grad and get out of while loop next def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
grad = switch( dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")
eq(p, 0),
grad,
grad + g_current,
)
sign_zk *= sign_z def check_2f1_converges(a, b, c, z):
k += 1 def is_nonpositive_integer(x):
if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
return ( a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k), num_terms = switch(
(eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))), a_is_polynomial,
) floor(scalar_abs(a)).astype("int64"),
0,
)
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k] b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
constant = [a, b, c, log_z, sign_z] num_terms = switch(
grad = _make_scalar_loop( b_is_polynomial,
max_steps, init, constant, inner_loop, name="hyp2f1_grad" floor(scalar_abs(b)).astype("int64"),
num_terms,
) )
return switch( is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
eq(z, 0), is_polynomial = a_is_polynomial | b_is_polynomial
0,
grad, return (~is_undefined) & (
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
) )
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0) z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z) converges = check_2f1_converges(a, b, c, z)
return switch( grads = _grad_2f1_loop(
z_is_zero, a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
0,
switch(
converges,
compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)),
np.nan,
),
) )
return [
switch(
z_is_zero,
0,
switch(
converges,
grad,
np.nan,
),
)
for grad in grads
]
...@@ -23,6 +23,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -23,6 +23,7 @@ from pytensor.graph.rewriting.basic import (
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc, alloc,
...@@ -31,6 +32,7 @@ from pytensor.tensor.basic import ( ...@@ -31,6 +32,7 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant from pytensor.tensor.var import TensorConstant
...@@ -1215,3 +1217,61 @@ compile.optdb.register( # type: ignore ...@@ -1215,3 +1217,61 @@ compile.optdb.register( # type: ignore
"fusion", "fusion",
position=49, position=49,
) )
@register_specialize
@node_rewriter([Elemwise])
def local_useless_2f1grad_loop(fgraph, node):
    """Remove unused gradient terms from an elemwise hyp2f1 grad loop.

    The fused loop built by ``_grad_2f1_loop`` returns
    ``(*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k)`` and
    takes ``(n_steps, *grads, *log_gs, *log_gs_signs, log_t, log_t_sign,
    sign_zk, k, a, b, c, log_z, sign_z)`` as inputs. When only some of the
    three gradients (wrt ``a``, ``b``, ``c``) have clients in ``fgraph``, the
    loop is rebuilt so that it only carries the state of those gradients.

    Parameters
    ----------
    fgraph
        Function graph being rewritten; only its ``clients`` mapping is read.
    node
        Candidate ``Elemwise`` node.

    Returns
    -------
    dict or None
        Mapping from each used gradient output of ``node`` to the matching
        output of the reduced loop, or ``None`` when the rewrite does not
        apply.
    """
    loop_op = node.op.scalar_op
    if not isinstance(loop_op, Grad2F1Loop):
        return None

    # Everything but the trailing (log_t, log_t_sign, sign_zk, k) outputs
    grad_related_vars = node.outputs[:-4]
    # Rewrite was already applied: the loop no longer carries all 3 gradients
    # (3 grads + 3 log_gs + 3 log_gs_signs). This also guarantees that the
    # hard-coded input slices below are valid.
    if len(grad_related_vars) != 9:
        return None

    grad_vars = grad_related_vars[:3]
    grad_var_is_used = [bool(fgraph.clients.get(v)) for v in grad_vars]
    wrt = [i for i, used in enumerate(grad_var_is_used) if used]

    # Nothing to do here: every gradient is needed, or none is (a reduced
    # loop cannot be built from an empty `wrt`)
    if len(wrt) in (0, 3):
        return None

    # Check that none of the remaining vars is used anywhere
    if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
        return None

    a, b, c, log_z, sign_z = node.inputs[-5:]
    z = exp(log_z) * sign_z

    # Reconstruct scalar loop with relevant outputs.
    # The gradient outputs carry the dtype the original loop was built with,
    # which need not be the dtype of `a` (e.g., integer `a` or mixed
    # precision inputs), so reuse it to keep the replacement type-compatible.
    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
    new_loop_op = _grad_2f1_loop(
        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=grad_vars[0].type.dtype
    )[0].owner.op

    # Reconstruct elemwise loop
    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
    n_steps = node.inputs[0]
    init_grad_vars = node.inputs[1:10]
    other_inputs = node.inputs[10:]

    # Keep the initial state that belongs to each retained gradient, selected
    # by its position in `wrt` within the grads / log_gs / log_gs_signs groups
    subset_init_grad_vars = [
        init_grad_vars[offset + i] for offset in (0, 3, 6) for i in wrt
    ]

    new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)

    # The reduced loop returns the retained gradients first, in `wrt` order
    return dict(zip((grad_vars[i] for i in wrt), new_outs))
...@@ -4,6 +4,8 @@ import numpy as np ...@@ -4,6 +4,8 @@ import numpy as np
import pytest import pytest
from pytensor.gradient import verify_grad from pytensor.gradient import verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise
scipy = pytest.importorskip("scipy") scipy = pytest.importorskip("scipy")
...@@ -1052,3 +1054,38 @@ class TestHyp2F1Grad: ...@@ -1052,3 +1054,38 @@ class TestHyp2F1Grad:
expected_result, expected_result,
rtol=rtol, rtol=rtol,
) )
@pytest.mark.parametrize("wrt", ([0], [1], [2], [0, 1], [1, 2], [0, 2], [0, 1, 2]))
def test_unused_grad_loop_opt(self, wrt):
    """Check that the grad scalar loop only carries the requested gradients."""
    # The trailing entry of the case is the expected gradient wrt z, which
    # is not compared in this test.
    a1_val, a2_val, b1_val, z_val, *expected_dds, _ = self.few_iters_case

    a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
    symbolic_inputs = [a1, a2, b1, z]
    out = at.hyp2f1(a1, a2, b1, z)
    requested = [var for idx, var in enumerate(symbolic_inputs) if idx in wrt]
    grads = at.grad(out, wrt=requested)

    mode = get_default_mode().including("local_useless_2f1grad_loop")
    f_grad = function(symbolic_inputs, grads, mode=mode)

    # Exactly one elemwise scalar loop must remain in the compiled graph
    loop_ops = []
    for apply_node in f_grad.maker.fgraph.apply_nodes:
        op = apply_node.op
        if isinstance(op, Elemwise) and isinstance(op.scalar_op, ScalarLoop):
            loop_ops.append(op.scalar_op)
    (scalar_loop_op,) = loop_ops
    assert scalar_loop_op.nin == 10 + 3 * len(wrt)

    if config.floatX == "float64":
        rtol = 1e-9
    else:
        rtol = 2e-3
    expected = [dd for idx, dd in enumerate(expected_dds) if idx in wrt]
    np.testing.assert_allclose(
        f_grad(a1_val, a2_val, b1_val, z_val),
        expected,
        rtol=rtol,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论