Commit fc0d9ec2 authored by Ricardo Vieira, committed by Ricardo Vieira

Fuse hyp2f1 grads

Parent 86fe383d
[Diff collapsed, not shown here.]
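
For orientation, this commit computes the gradients of hyp2f1 with respect to its first three arguments in a single fused ScalarLoop, rather than one loop per input. A minimal sketch of the user-facing graph, assuming the usual `pytensor.tensor as at` alias used in the tests below (variable names illustrative):

import pytensor
import pytensor.tensor as at

a, b, c, z = at.scalars("a", "b", "c", "z")
out = at.hyp2f1(a, b, c, z)

# After this commit, all three gradients share one fused ScalarLoop
grads = at.grad(out, wrt=[a, b, c])
fn = pytensor.function([a, b, c, z], grads)
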
@@ -23,6 +23,7 @@ from pytensor.graph.rewriting.basic import (
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
    MakeVector,
    alloc,
@@ -31,6 +32,7 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import exp
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.var import TensorConstant
@@ -1215,3 +1217,61 @@ compile.optdb.register(  # type: ignore
    "fusion",
    position=49,
)


@register_specialize
@node_rewriter([Elemwise])
def local_useless_2f1grad_loop(fgraph, node):
    # Remove unused terms from the hyp2f1 grad loop
    loop_op = node.op.scalar_op
    if not isinstance(loop_op, Grad2F1Loop):
        return

    grad_related_vars = node.outputs[:-4]
    # Rewrite was already applied
    if len(grad_related_vars) // 3 != 3:
        return None

    grad_vars = grad_related_vars[:3]
    grad_var_is_used = [bool(fgraph.clients.get(v)) for v in grad_vars]

    # Nothing to do here
    if sum(grad_var_is_used) == 3:
        return None

    # Check that none of the remaining vars is used anywhere
    if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
        return None

    # The loop receives z split into log(|z|) and sign(z); rebuild z to
    # instantiate the trimmed loop
    a, b, c, log_z, sign_z = node.inputs[-5:]
    z = exp(log_z) * sign_z

    # Reconstruct scalar loop with relevant outputs
    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
    wrt = [i for i, used in enumerate(grad_var_is_used) if used]
    new_loop_op = _grad_2f1_loop(
        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
    )[0].owner.op

    # Reconstruct elemwise loop
    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
    n_steps = node.inputs[0]
    init_grad_vars = node.inputs[1:10]
    other_inputs = node.inputs[10:]

    # Keep len(wrt) initial values from each group of three (grads, gs, gs_signs)
    init_grads = init_grad_vars[: len(wrt)]
    init_gs = init_grad_vars[3 : 3 + len(wrt)]
    init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
    subset_init_grad_vars = init_grads + init_gs + init_gs_signs

    new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)

    # Map each still-used gradient output to its counterpart in the new loop
    replacements = {}
    i = 0
    for grad_var, is_used in zip(grad_vars, grad_var_is_used):
        if not is_used:
            continue
        replacements[grad_var] = new_outs[i]
        i += 1
    return replacements
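
The rewrite fires when only a subset of the three gradients is requested; a minimal sketch, mirroring the test added below (names illustrative):

import pytensor.tensor as at
from pytensor import function
from pytensor.compile.mode import get_default_mode

a, b, c, z = at.scalars("a", "b", "c", "z")
# Only d/da is requested, so the loop outputs for b and c are dead code
grad_a = at.grad(at.hyp2f1(a, b, c, z), wrt=a)

mode = get_default_mode().including("local_useless_2f1grad_loop")
f = function([a, b, c, z], grad_a, mode=mode)
# The remaining ScalarLoop carries state for one gradient (10 + 3 * 1
# inputs) rather than all three (10 + 3 * 3), per the nin assertion below
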
@@ -4,6 +4,8 @@ import numpy as np
import pytest

from pytensor.gradient import verify_grad
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import Elemwise

scipy = pytest.importorskip("scipy")
@@ -1052,3 +1054,38 @@ class TestHyp2F1Grad:
            expected_result,
            rtol=rtol,
        )

    @pytest.mark.parametrize("wrt", ([0], [1], [2], [0, 1], [1, 2], [0, 2], [0, 1, 2]))
    def test_unused_grad_loop_opt(self, wrt):
        """Test that we don't compute unnecessary outputs in the grad scalar loop"""
        (
            test_a1,
            test_a2,
            test_b1,
            test_z,
            *expected_dds,
            expected_ddz,
        ) = self.few_iters_case

        a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
        hyp2f1_out = at.hyp2f1(a1, a2, b1, z)
        wrt_vars = [v for i, v in enumerate((a1, a2, b1, z)) if i in wrt]
        hyp2f1_grad = at.grad(hyp2f1_out, wrt=wrt_vars)

        mode = get_default_mode().including("local_useless_2f1grad_loop")
        f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode)

        [scalar_loop_op] = [
            node.op.scalar_op
            for node in f_grad.maker.fgraph.apply_nodes
            if isinstance(node.op, Elemwise)
            and isinstance(node.op.scalar_op, ScalarLoop)
        ]
        assert scalar_loop_op.nin == 10 + 3 * len(wrt)

        rtol = 1e-9 if config.floatX == "float64" else 2e-3
        np.testing.assert_allclose(
            f_grad(test_a1, test_a2, test_b1, test_z),
            [dd for i, dd in enumerate(expected_dds) if i in wrt],
            rtol=rtol,
        )
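
To eyeball the trimmed loop in a compiled function, pytensor's debug printer can help; a quick inspection aid, assuming the `f_grad` built in the test above:

import pytensor

# With the rewrite active, the Elemwise node wrapping the ScalarLoop should
# list only the state variables for the requested gradients
pytensor.dprint(f_grad, print_type=True)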