提交 fc0d9ec2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fuse hyp2f1 grads

上级 86fe383d
...@@ -5,7 +5,9 @@ As SciPy is not always available, we treat them separately. ...@@ -5,7 +5,9 @@ As SciPy is not always available, we treat them separately.
""" """
import os import os
from functools import reduce
from textwrap import dedent from textwrap import dedent
from typing import Tuple
import numpy as np import numpy as np
import scipy.special import scipy.special
...@@ -683,14 +685,20 @@ class GammaIncC(BinaryScalarOp): ...@@ -683,14 +685,20 @@ class GammaIncC(BinaryScalarOp):
gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") gammaincc = GammaIncC(upgrade_to_float, name="gammaincc")
def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop):
init = [as_scalar(x) for x in init] init = [as_scalar(x) if x is not None else None for x in init]
constant = [as_scalar(x) for x in constant] constant = [as_scalar(x) for x in constant]
# Create dummy types, in case some variables have the same initial form # Create dummy types, in case some variables have the same initial form
init_ = [x.type() for x in init] init_ = [x.type() if x is not None else None for x in init]
constant_ = [x.type() for x in constant] constant_ = [x.type() for x in constant]
update_, until_ = inner_loop_fn(*init_, *constant_) update_, until_ = inner_loop_fn(*init_, *constant_)
op = ScalarLoop(
# Filter Nones
init = [i for i in init if i is not None]
init_ = [i for i in init_ if i is not None]
update_ = [u for u in update_ if u is not None]
op = loop_op(
init=init_, init=init_,
constant=constant_, constant=constant_,
update=update_, update=update_,
...@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): ...@@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name):
until_condition_failed="warn", until_condition_failed="warn",
name=name, name=name,
) )
S, *_ = op(n_steps, *init, *constant) return op(n_steps, *init, *constant)
return S
def gammainc_grad(k, x): def gammainc_grad(k, x):
...@@ -740,7 +747,7 @@ def gammainc_grad(k, x): ...@@ -740,7 +747,7 @@ def gammainc_grad(k, x):
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
constant = [log_x] constant = [log_x]
sum_a = _make_scalar_loop( sum_a, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
) )
...@@ -827,7 +834,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -827,7 +834,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
constant = [x] constant = [x]
sum_a = _make_scalar_loop( sum_a, *_ = _make_scalar_loop(
n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a"
) )
grad_approx_a = ( grad_approx_a = (
...@@ -870,7 +877,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -870,7 +877,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_b0, log_s, s_sign, log_delta, n] init = [sum_b0, log_s, s_sign, log_delta, n]
constant = [k, log_x] constant = [k, log_x]
sum_b = _make_scalar_loop( sum_b, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
) )
grad_approx_b = ( grad_approx_b = (
...@@ -1540,7 +1547,7 @@ def betainc_grad(p, q, x, wrtp: bool): ...@@ -1540,7 +1547,7 @@ def betainc_grad(p, q, x, wrtp: bool):
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n] init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK] constant = [f, p, q, K, dK]
grad = _make_scalar_loop( grad, *_ = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad" max_iters, init, constant, inner_loop, name="betainc_grad"
) )
return grad return grad
...@@ -1579,10 +1586,11 @@ class Hyp2F1(ScalarOp): ...@@ -1579,10 +1586,11 @@ class Hyp2F1(ScalarOp):
def grad(self, inputs, grads):
    """Return the gradient contributions for ``a``, ``b``, ``c`` and ``z``.

    ``inputs`` holds the four arguments of the op and ``grads`` the single
    gradient flowing in from the output. The result is a list with one
    entry per input, in input order.
    """
    a, b, c, z = inputs
    [output_grad] = grads

    # A single call builds the derivatives for all three parameters
    # instead of one call per parameter.
    d_a, d_b, d_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2])
    param_terms = [output_grad * d_param for d_param in (d_a, d_b, d_c)]

    # The derivative with respect to z is expressed through hyp2f1 itself.
    z_term = output_grad * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
    return [*param_terms, z_term]
...@@ -1598,38 +1606,11 @@ def _unsafe_sign(x): ...@@ -1598,38 +1606,11 @@ def _unsafe_sign(x):
return switch(x > 0, 1, -1) return switch(x > 0, 1, -1)
class Grad2F1Loop(ScalarLoop):
    """Marker subclass of ScalarLoop used for the hyp2f1 gradient loop.

    It adds no behavior of its own; having a distinct type makes these
    loops easy to target in graph rewrites with an ``isinstance`` check.
    """
def check_2f1_converges(a, b, c, z):
def is_nonpositive_integer(x):
if x.type.dtype not in integer_types:
return eq(floor(x), x) & (x <= 0)
else:
return x <= 0
a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
num_terms = switch(
a_is_polynomial,
floor(scalar_abs(a)).astype("int64"),
0,
)
b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms) def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
num_terms = switch(
b_is_polynomial,
floor(scalar_abs(b)).astype("int64"),
num_terms,
)
is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
is_polynomial = a_is_polynomial | b_is_polynomial
return (~is_undefined) & (
is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
)
def compute_grad_2f1(a, b, c, z, wrt, skip_loop):
""" """
Notes Notes
----- -----
...@@ -1650,14 +1631,6 @@ def hyp2f1_grad(a, b, c, z, wrt: int): ...@@ -1650,14 +1631,6 @@ def hyp2f1_grad(a, b, c, z, wrt: int):
tracking their signs. tracking their signs.
""" """
wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
min_steps = np.array( min_steps = np.array(
10, dtype="int32" 10, dtype="int32"
) # https://github.com/stan-dev/math/issues/2857 ) # https://github.com/stan-dev/math/issues/2857
...@@ -1666,10 +1639,9 @@ def hyp2f1_grad(a, b, c, z, wrt: int): ...@@ -1666,10 +1639,9 @@ def hyp2f1_grad(a, b, c, z, wrt: int):
) )
precision = np.array(1e-14, dtype=config.floatX) precision = np.array(1e-14, dtype=config.floatX)
grad = np.array(0, dtype=dtype) grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)]
log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)]
log_g = np.array(-np.inf, dtype=dtype) log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)]
log_g_sign = np.array(1, dtype="int8")
log_t = np.array(0.0, dtype=dtype) log_t = np.array(0.0, dtype=dtype)
log_t_sign = np.array(1, dtype="int8") log_t_sign = np.array(1, dtype="int8")
...@@ -1680,10 +1652,9 @@ def hyp2f1_grad(a, b, c, z, wrt: int): ...@@ -1680,10 +1652,9 @@ def hyp2f1_grad(a, b, c, z, wrt: int):
sign_zk = sign_z sign_zk = sign_z
k = np.array(0, dtype="int32") k = np.array(0, dtype="int32")
def inner_loop( def inner_loop(*args):
grad, (
log_g, *grads_vars,
log_g_sign,
log_t, log_t,
log_t_sign, log_t_sign,
sign_zk, sign_zk,
...@@ -1693,65 +1664,147 @@ def hyp2f1_grad(a, b, c, z, wrt: int): ...@@ -1693,65 +1664,147 @@ def hyp2f1_grad(a, b, c, z, wrt: int):
c, c,
log_z, log_z,
sign_z, sign_z,
): ) = args
(
grad_a,
grad_b,
grad_c,
log_g_a,
log_g_b,
log_g_c,
log_g_sign_a,
log_g_sign_b,
log_g_sign_c,
) = grads_vars
p = (a + k) * (b + k) / ((c + k) * (k + 1)) p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p.type.dtype != dtype: if p.type.dtype != dtype:
p = p.astype(dtype) p = p.astype(dtype)
term = log_g_sign * log_t_sign * exp(log_g - log_t) # If p==0, don't update grad and get out of while loop next
if wrt_a: p_zero = eq(p, 0)
term += reciprocal(a + k)
elif wrt_b: if 0 in wrt:
term += reciprocal(b + k) term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t)
else: term_a += reciprocal(a + k)
term -= reciprocal(c + k) if term_a.type.dtype != dtype:
term_a = term_a.astype(dtype)
if term.type.dtype != dtype: if 1 in wrt:
term = term.astype(dtype) term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t)
term_b += reciprocal(b + k)
if term_b.type.dtype != dtype:
term_b = term_b.astype(dtype)
if 2 in wrt:
term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t)
term_c -= reciprocal(c + k)
if term_c.type.dtype != dtype:
term_c = term_c.astype(dtype)
log_t = log_t + log(scalar_abs(p)) + log_z log_t = log_t + log(scalar_abs(p)) + log_z
log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8") log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8")
log_g = log_t + log(scalar_abs(term))
log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8")
g_current = log_g_sign * exp(log_g) * sign_zk grads = [None] * 3
log_gs = [None] * 3
# If p==0, don't update grad and get out of while loop next log_gs_signs = [None] * 3
grad = switch( grad_incs = [None] * 3
eq(p, 0),
grad, if 0 in wrt:
grad + g_current, log_g_a = log_t + log(scalar_abs(term_a))
) log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8")
grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk
grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a)
log_gs[0] = log_g_a
log_gs_signs[0] = log_g_sign_a
grad_incs[0] = grad_inc_a
if 1 in wrt:
log_g_b = log_t + log(scalar_abs(term_b))
log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8")
grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk
grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b)
log_gs[1] = log_g_b
log_gs_signs[1] = log_g_sign_b
grad_incs[1] = grad_inc_b
if 2 in wrt:
log_g_c = log_t + log(scalar_abs(term_c))
log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8")
grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk
grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c)
log_gs[2] = log_g_c
log_gs_signs[2] = log_g_sign_c
grad_incs[2] = grad_inc_c
sign_zk *= sign_z sign_zk *= sign_z
k += 1 k += 1
abs_grad_incs = [
scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None
]
if len(grad_incs) == 1:
[max_abs_grad_inc] = grad_incs
else:
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
return ( return (
(grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k), (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
(eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))), (eq(p, 0) | ((k > min_steps) & (max_abs_grad_inc <= precision))),
) )
init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k] init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
constant = [a, b, c, log_z, sign_z] constant = [a, b, c, log_z, sign_z]
grad = _make_scalar_loop( loop_outs = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad" max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
) )
return loop_outs[: len(wrt)]
return switch(
def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
    """Build the gradients of ``hyp2f1(a, b, c, z)`` w.r.t. ``a``, ``b`` and/or ``c``.

    Parameters
    ----------
    a, b, c, z
        Scalar variables. Their ``type.dtype`` attributes (upcast together
        with ``"float32"``) decide the dtype used by the gradient loop.
    wrt
        Indices of the parameters to differentiate against: 0 for ``a``,
        1 for ``b`` and 2 for ``c``.

    Returns
    -------
    list
        One symbolic gradient per entry of ``wrt``, in the order requested.
        Each gradient is 0 where ``z == 0`` and ``nan`` where the
        convergence check fails.

    Raises
    ------
    ValueError
        If ``wrt`` is empty or contains anything other than 0, 1 or 2.
    """
    # Validate up front: the loop builder silently ignores unknown indices
    # and would otherwise fail later with an unrelated error.
    wrt = list(wrt)
    if not wrt or any(i not in (0, 1, 2) for i in wrt):
        raise ValueError(f"wrt entries must be 0, 1, or 2, got {wrt}")

    dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32")

    def check_2f1_converges(a, b, c, z):
        """Return a symbolic flag that is true when the series can be evaluated."""

        def is_nonpositive_integer(x):
            if x.type.dtype not in integer_types:
                return eq(floor(x), x) & (x <= 0)
            else:
                return x <= 0

        a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0)
        num_terms = switch(
            a_is_polynomial,
            floor(scalar_abs(a)).astype("int64"),
            0,
        )

        b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms)
        num_terms = switch(
            b_is_polynomial,
            floor(scalar_abs(b)).astype("int64"),
            num_terms,
        )

        is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms)
        is_polynomial = a_is_polynomial | b_is_polynomial

        return (~is_undefined) & (
            is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b)))
        )

    # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
    z_is_zero = eq(z, 0)
    converges = check_2f1_converges(a, b, c, z)

    # The loop builder returns its gradients in a, b, c order, one per
    # distinct index, so hand it sorted unique indices and map the results
    # back to the order (and multiplicity) the caller asked for.
    loop_wrt = sorted(set(wrt))
    loop_grads = _grad_2f1_loop(
        a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=loop_wrt, dtype=dtype
    )
    grad_by_index = dict(zip(loop_wrt, loop_grads))

    return [
        switch(
            z_is_zero,
            0,
            switch(
                converges,
                grad_by_index[i],
                np.nan,
            ),
        )
        for i in wrt
    ]
...@@ -23,6 +23,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -23,6 +23,7 @@ from pytensor.graph.rewriting.basic import (
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc, alloc,
...@@ -31,6 +32,7 @@ from pytensor.tensor.basic import ( ...@@ -31,6 +32,7 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant from pytensor.tensor.var import TensorConstant
...@@ -1215,3 +1217,61 @@ compile.optdb.register( # type: ignore ...@@ -1215,3 +1217,61 @@ compile.optdb.register( # type: ignore
"fusion", "fusion",
position=49, position=49,
) )
@register_specialize
@node_rewriter([Elemwise])
def local_useless_2f1grad_loop(fgraph, node):
    """Remove unused gradient terms from the hyp2f1 grad loop.

    When only some of the three gradient outputs of a ``Grad2F1Loop``
    Elemwise have clients, the loop is rebuilt for just those parameters.

    Returns
    -------
    dict or None
        A mapping from each used gradient output to its replacement, or
        ``None`` when the node is not a hyp2f1 gradient loop or there is
        nothing to prune.
    """
    loop_op = node.op.scalar_op
    if not isinstance(loop_op, Grad2F1Loop):
        return None

    # NOTE(review): the slice presumably strips trailing non-gradient
    # outputs, and the floor division tolerates a slice that is not exactly
    # nine long -- confirm against the loop's output layout before tightening.
    grad_related_vars = node.outputs[:-4]
    # Rewrite was already applied
    if len(grad_related_vars) // 3 != 3:
        return None

    grad_vars = grad_related_vars[:3]
    grad_var_is_used = [bool(fgraph.clients.get(v)) for v in grad_vars]

    # Nothing to do here
    if sum(grad_var_is_used) == 3:
        return None

    # Check that None of the remaining vars is used anywhere
    if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
        return None

    wrt = [i for i, used in enumerate(grad_var_is_used) if used]
    # Without any used gradient there is no loop to rebuild
    if not wrt:
        return None

    a, b, c, log_z, sign_z = node.inputs[-5:]
    z = exp(log_z) * sign_z

    # Reconstruct scalar loop with relevant outputs
    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
    new_loop_op = _grad_2f1_loop(
        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
    )[0].owner.op

    # Reconstruct elemwise loop
    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
    n_steps = node.inputs[0]
    init_grad_vars = node.inputs[1:10]
    other_inputs = node.inputs[10:]

    # The nine initial values come in three groups of three (one slot per
    # parameter in each group). Pick the slots of the parameters that are
    # kept, rather than the first len(wrt) slots of each group, so every
    # kept gradient continues from its own initial state.
    subset_init_grad_vars = [
        init_grad_vars[group_start + i] for group_start in (0, 3, 6) for i in wrt
    ]

    new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)

    # The rebuilt loop emits the kept gradients first, in a, b, c order
    used_grad_vars = [
        grad_var for grad_var, is_used in zip(grad_vars, grad_var_is_used) if is_used
    ]
    return dict(zip(used_grad_vars, new_outs))
...@@ -4,6 +4,8 @@ import numpy as np ...@@ -4,6 +4,8 @@ import numpy as np
import pytest import pytest
from pytensor.gradient import verify_grad from pytensor.gradient import verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise
scipy = pytest.importorskip("scipy") scipy = pytest.importorskip("scipy")
...@@ -1052,3 +1054,38 @@ class TestHyp2F1Grad: ...@@ -1052,3 +1054,38 @@ class TestHyp2F1Grad:
expected_result, expected_result,
rtol=rtol, rtol=rtol,
) )
@pytest.mark.parametrize("wrt", ([0], [1], [2], [0, 1], [1, 2], [0, 2], [0, 1, 2]))
def test_unused_grad_loop_opt(self, wrt):
    """Check that the grad scalar loop only carries the requested gradients."""
    (
        a1_value,
        a2_value,
        b1_value,
        z_value,
        *expected_param_grads,
        _expected_z_grad,
    ) = self.few_iters_case

    a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
    symbolic_inputs = (a1, a2, b1, z)
    requested_inputs = [
        var for position, var in enumerate(symbolic_inputs) if position in wrt
    ]
    requested_grads = at.grad(at.hyp2f1(a1, a2, b1, z), wrt=requested_inputs)

    mode = get_default_mode().including("local_useless_2f1grad_loop")
    f_grad = function([a1, a2, b1, z], requested_grads, mode=mode)

    # Exactly one Elemwise-wrapped scalar loop must remain in the compiled graph
    loop_ops = []
    for node in f_grad.maker.fgraph.apply_nodes:
        if not isinstance(node.op, Elemwise):
            continue
        if isinstance(node.op.scalar_op, ScalarLoop):
            loop_ops.append(node.op.scalar_op)
    [scalar_loop_op] = loop_ops

    # Each requested gradient contributes three loop inputs on top of the ten shared ones
    assert scalar_loop_op.nin == 10 + 3 * len(wrt)

    expected = [
        value for position, value in enumerate(expected_param_grads) if position in wrt
    ]
    rtol = 1e-9 if config.floatX == "float64" else 2e-3
    np.testing.assert_allclose(
        f_grad(a1_value, a2_value, b1_value, z_value),
        expected,
        rtol=rtol,
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论