提交 a8cc03c8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove `until_condition_failed` in ScalarLoop

This was problematic when OpenMP was used in the Elemwise outer loop. We add one extra output flag stating whether the iteration converged or not. This, however, breaks the Hyp2F1 grad in python mode because it goes beyond the Elemwise limit on the number of operands. To fix it, we split the grad when in python mode.
上级 df2ffe4b
import warnings
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Literal, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, cast
from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone
......@@ -14,7 +12,33 @@ class ScalarLoop(ScalarInnerGraphOp):
"""Scalar Op that encapsulates a scalar loop operation.
This Op can be used for the gradient of other Scalar Ops.
It is much more restricted that `Scan` in that the entire inner graph must be composed of Scalar operations.
It is much more restricted than `Scan` in that the entire inner graph
must be composed of Scalar operations, and all inputs and outputs must be ScalarVariables.
The pseudocode of the computation performed by this Op looks like the following:
```python
def scalar_for_loop(fn, n_steps, init, update, constant):
for i in range(n_steps):
state = fn(*state, *constant)
return state
```
When an until condition is present it behaves like this:
```python
def scalar_while_loop(fn, n_steps, init, update, constant):
# If n_steps <= 0, we skip the loop altogether.
# This does not count as a "failure"
done = True
for i in range(n_steps):
*state, done = fn(*state, *constant)
if done:
break
return *state, done
```
"""
......@@ -23,7 +47,6 @@ class ScalarLoop(ScalarInnerGraphOp):
"update",
"constant",
"until",
"until_condition_failed",
)
def __init__(
......@@ -32,14 +55,8 @@ class ScalarLoop(ScalarInnerGraphOp):
update: Sequence[Variable],
constant: Optional[Sequence[Variable]] = None,
until: Optional[Variable] = None,
until_condition_failed: Literal["ignore", "warn", "raise"] = "warn",
name="ScalarLoop",
):
if until_condition_failed not in ["ignore", "warn", "raise"]:
raise ValueError(
f"Invalid until_condition_failed: {until_condition_failed}"
)
if constant is None:
constant = []
if not len(init) == len(update):
......@@ -52,12 +69,13 @@ class ScalarLoop(ScalarInnerGraphOp):
self.outputs = copy(outputs)
self.inputs = copy(inputs)
self.is_while = bool(until)
self.inputs_type = tuple(input.type for input in inputs)
self.outputs_type = tuple(output.type for output in outputs)
if self.is_while:
self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
self.nin = len(inputs) + 1 # n_steps is not part of the inner graph
self.nout = len(outputs) # until is not output
self.is_while = bool(until)
self.until_condition_failed = until_condition_failed
self.nout = len(outputs) + (1 if self.is_while else 0)
self.name = name
self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
super().__init__()
......@@ -135,7 +153,6 @@ class ScalarLoop(ScalarInnerGraphOp):
update=update,
constant=constant,
until=until,
until_condition_failed=self.until_condition_failed,
name=self.name,
)
......@@ -191,7 +208,6 @@ class ScalarLoop(ScalarInnerGraphOp):
update=cloned_update,
constant=cloned_constant,
until=cloned_until,
until_condition_failed=self.until_condition_failed,
name=self.name,
)
node = op.make_node(n_steps, *inputs)
......@@ -209,17 +225,8 @@ class ScalarLoop(ScalarInnerGraphOp):
*carry, until = inner_fn(*carry, *constant)
if until:
break
carry.append(until)
if not until: # no-break
if self.until_condition_failed == "raise":
raise RuntimeError(
f"Until condition in ScalarLoop {self.name} not reached!"
)
elif self.until_condition_failed == "warn":
warnings.warn(
f"Until condition in ScalarLoop {self.name} not reached!",
RuntimeWarning,
)
else:
if n_steps < 0:
raise ValueError("ScalarLoop does not have a termination condition.")
......@@ -324,27 +331,12 @@ class ScalarLoop(ScalarInnerGraphOp):
if self.is_while:
_c_code += "\nif(until){break;}\n"
# End of the loop
_c_code += "}\n"
# End of the loop
# Output until flag
if self.is_while:
if self.until_condition_failed == "raise":
_c_code += dedent(
f"""
if (!until) {{
PyErr_SetString(PyExc_RuntimeError, "Until condition in ScalarLoop {self.name} not reached!");
%(fail)s
}}
"""
)
elif self.until_condition_failed == "warn":
_c_code += dedent(
f"""
if (!until) {{
PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop {self.name} not reached!", 1);
}}
"""
)
_c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
_c_code += "}\n"
......@@ -376,13 +368,4 @@ class ScalarLoop(ScalarInnerGraphOp):
return res
def c_code_cache_version_outer(self):
return (1,)
def __eq__(self, other):
return (
super().__eq__(other)
and self.until_condition_failed == other.until_condition_failed
)
def __hash__(self):
return hash((super().__hash__(), self.until_condition_failed))
return (2,)
......@@ -703,7 +703,6 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=Scal
constant=constant_,
update=update_,
until=until_,
until_condition_failed="warn",
name=name,
)
return op(n_steps, *init, *constant)
......@@ -747,9 +746,10 @@ def gammainc_grad(k, x):
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
constant = [log_x]
sum_a, *_ = _make_scalar_loop(
sum_a, *_, sum_a_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
)
sum_a = switch(sum_a_converges, sum_a, np.nan)
# Second loop
n = np.array(0, dtype="int32")
......@@ -772,9 +772,10 @@ def gammainc_grad(k, x):
init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
constant = [log_x]
sum_b, *_ = _make_scalar_loop(
sum_b, *_, sum_b_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
)
sum_b = switch(sum_b_converges, sum_b, np.nan)
grad_approx = exp(-x) * (log_x * sum_a - sum_b)
return grad_approx
......@@ -877,9 +878,10 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_b0, log_s, s_sign, log_delta, n]
constant = [k, log_x]
sum_b, *_ = _make_scalar_loop(
sum_b, *_, sum_b_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
)
sum_b = switch(sum_b_converges, sum_b, np.nan)
grad_approx_b = (
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
)
......@@ -1547,10 +1549,10 @@ def betainc_grad(p, q, x, wrtp: bool):
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK]
grad, *_ = _make_scalar_loop(
grad, *_, grad_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad"
)
return grad
return switch(grad_converges, grad, np.nan)
# Input validation
nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
......@@ -1752,10 +1754,10 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
constant = [a, b, c, log_z, sign_z]
loop_outs = _make_scalar_loop(
*loop_outs, converges = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
)
return loop_outs[: len(wrt)]
return *loop_outs[: len(wrt)], converges
def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
......@@ -1792,7 +1794,7 @@ def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z)
grads = _grad_2f1_loop(
*grads, grad_converges = _grad_2f1_loop(
a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
)
......
......@@ -1219,6 +1219,30 @@ compile.optdb.register( # type: ignore
)
def _rebuild_partial_2f1grad_loop(node, wrt):
    """Rebuild an Elemwise hyp2f1-grad loop keeping only the gradients in ``wrt``.

    Parameters
    ----------
    node
        An existing ``Elemwise`` apply node wrapping a ``Grad2F1Loop`` scalar op.
        Its inputs are assumed to be laid out as
        ``[n_steps, *9 grad-related init values, *other carried/constant inputs]``
        with ``[a, b, c, log_z, sign_z]`` as the last five inputs.
    wrt
        Indices (subset of ``{0, 1, 2}``) of the gradients to keep.

    Returns
    -------
    The outputs of a new, smaller ``Elemwise`` loop computing only the requested
    gradients (plus the trailing convergence flag produced by ``_grad_2f1_loop``).
    """
    # The scalar parameters of the loop are the last five inputs of the node.
    a, b, c, log_z, sign_z = node.inputs[-5:]
    # Recover z from its log-magnitude/sign decomposition.
    z = exp(log_z) * sign_z

    # Reconstruct scalar loop with relevant outputs.
    # Fresh scalar variables of matching types serve as symbolic placeholders;
    # only the resulting Op (graph structure) is reused, not these variables.
    a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
    new_loop_op = _grad_2f1_loop(
        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
    )[0].owner.op

    # Reconstruct elemwise loop around the new (smaller) scalar loop op.
    new_elemwise_op = Elemwise(scalar_op=new_loop_op)
    n_steps = node.inputs[0]
    # Inputs 1..9 are the grad-related initial states: 3 grads, 3 log_gs,
    # 3 log_gs_signs (matching the `init` layout built in `_grad_2f1_loop`).
    init_grad_vars = node.inputs[1:10]
    other_inputs = node.inputs[10:]

    # Keep only the slices corresponding to the requested `wrt` gradients.
    # NOTE(review): this assumes `wrt` is a contiguous prefix-style selection
    # ([0, 1] or [2]) so that `[: len(wrt)]` picks the right entries — confirm.
    init_grads = init_grad_vars[: len(wrt)]
    init_gs = init_grad_vars[3 : 3 + len(wrt)]
    init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
    subset_init_grad_vars = init_grads + init_gs + init_gs_signs

    return new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
@register_specialize
@node_rewriter([Elemwise])
def local_useless_2f1grad_loop(fgraph, node):
......@@ -1240,34 +1264,16 @@ def local_useless_2f1grad_loop(fgraph, node):
if sum(grad_var_is_used) == 3:
return None
# Check that None of the remaining vars is used anywhere
if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
return None
*other_vars, converges = node.outputs[3:]
a, b, c, log_z, sign_z = node.inputs[-5:]
z = exp(log_z) * sign_z
# Check that None of the remaining vars (except the converge flag) is used anywhere
if any(bool(fgraph.clients.get(v)) for v in other_vars):
return None
# Reconstruct scalar loop with relevant outputs
a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
wrt = [i for i, used in enumerate(grad_var_is_used) if used]
new_loop_op = _grad_2f1_loop(
a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
)[0].owner.op
*new_outs, new_converges = _rebuild_partial_2f1grad_loop(node, wrt=wrt)
# Reconstruct elemwise loop
new_elemwise_op = Elemwise(scalar_op=new_loop_op)
n_steps = node.inputs[0]
init_grad_vars = node.inputs[1:10]
other_inputs = node.inputs[10:]
init_grads = init_grad_vars[: len(wrt)]
init_gs = init_grad_vars[3 : 3 + len(wrt)]
init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
subset_init_grad_vars = init_grads + init_gs + init_gs_signs
new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
replacements = {}
replacements = {converges: new_converges}
i = 0
for grad_var, is_used in zip(grad_vars, grad_var_is_used):
if not is_used:
......@@ -1275,3 +1281,48 @@ def local_useless_2f1grad_loop(fgraph, node):
replacements[grad_var] = new_outs[i]
i += 1
return replacements
@node_rewriter([Elemwise])
def split_2f1grad_loop(fgraph, node):
    """Split the hyp2f1 gradient loop into two smaller loops for python mode.

    The full 2f1grad loop has too many operands for the NumPy ``frompyfunc``
    code used by ``Elemwise`` nodes in python mode. This rewrite splits it
    across two loop operations (grads [0, 1] and grad [2]). It is not needed
    if `local_useless_2f1grad_loop` was applied.
    """
    loop_op = node.op.scalar_op
    # Only applies to the dedicated hyp2f1 gradient loop op.
    if not isinstance(loop_op, Grad2F1Loop):
        return None

    # All outputs except the last 4 are grad-related carried states.
    grad_related_vars = node.outputs[:-4]
    # local_useless_2f1grad_loop was used, we should be safe:
    # a partial loop no longer carries all 3 gradient triplets, so skip it.
    if len(grad_related_vars) // 3 != 3:
        return None

    grad_vars = grad_related_vars[:3]
    # Everything after the 3 grads, with the convergence flag last.
    *other_vars, converges = node.outputs[3:]

    # Check that none of the remaining vars is used anywhere; otherwise the
    # split loops (which drop those outputs) could not replace this node.
    if any(bool(fgraph.clients.get(v)) for v in other_vars):
        return None

    # Rebuild two smaller loops: one for grads wrt [0, 1], one wrt [2].
    new_grad0, new_grad1, *_, new_converges01 = _rebuild_partial_2f1grad_loop(
        node, wrt=[0, 1]
    )
    new_grad2, *_, new_converges2 = _rebuild_partial_2f1grad_loop(node, wrt=[2])

    # Overall convergence requires both split loops to converge.
    replacements = {
        converges: new_converges01 & new_converges2,
        grad_vars[0]: new_grad0,
        grad_vars[1]: new_grad1,
        grad_vars[2]: new_grad2,
    }
    return replacements
# Register the split rewrite only in the python-mode ("py_only") rewrite
# database, since the operand limit it works around affects python mode only.
compile.optdb["py_only"].register(  # type: ignore
    "split_2f1grad_loop",
    split_2f1grad_loop,
    "fast_compile",
)
......@@ -103,41 +103,11 @@ def test_until(mode):
x = x0 + 1
until = x >= 10
op = ScalarLoop(init=[x0], update=[x], until=until, until_condition_failed="ignore")
op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10)
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10)
np.testing.assert_allclose(fn(n_steps=5, x0=1), 6)
op = ScalarLoop(
init=[x0],
update=[x],
until=until,
until_condition_failed="warn",
name="TestLoop",
)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10)
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10)
with pytest.warns(
RuntimeWarning, match="Until condition in ScalarLoop TestLoop not reached!"
):
np.testing.assert_allclose(fn(n_steps=5, x0=1), 6)
op = ScalarLoop(
init=[x0],
update=[x],
until=until,
until_condition_failed="raise",
name="TestLoop",
)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10)
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10)
with pytest.raises(
RuntimeError, match="Until condition in ScalarLoop TestLoop not reached!"
):
fn(n_steps=5, x0=1)
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
def test_update_missing_error():
......
......@@ -1075,13 +1075,24 @@ class TestHyp2F1Grad:
mode = get_default_mode().including("local_useless_2f1grad_loop")
f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode)
[scalar_loop_op] = [
node.op.scalar_op
for node in f_grad.maker.fgraph.apply_nodes
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op.nin == 10 + 3 * len(wrt)
if len(wrt) == 3 and config.mode == "FAST_COMPILE" or not config.cxx:
# In this case we actually get two scalar_loops, because the merged one can't be executed in Python
[scalar_loop_op1, scalar_loop_op2] = [
node.op.scalar_op
for node in f_grad.maker.fgraph.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2]
else:
[scalar_loop_op] = [
node.op.scalar_op
for node in f_grad.maker.fgraph.apply_nodes
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op.nin == 10 + 3 * len(wrt)
rtol = 1e-9 if config.floatX == "float64" else 2e-3
np.testing.assert_allclose(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论