Commit a8cc03c8 authored by Ricardo Vieira, committed by Ricardo Vieira

Remove `until_condition_failed` in ScalarLoop

This was problematic when OpenMP was used in the Elemwise outer loop. We add one extra output flag stating whether the iteration converged or not. This, however, breaks the Hyp2F1 grad in python mode because it goes beyond the Elemwise limit on the number of operands. To fix it, we split the grad when on python mode.
Parent: df2ffe4b
import warnings
from copy import copy from copy import copy
from itertools import chain from itertools import chain
from textwrap import dedent from typing import Optional, Sequence, Tuple, cast
from typing import Literal, Optional, Sequence, Tuple
from pytensor.compile import rebuild_collect_shared from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone from pytensor.graph import Constant, FunctionGraph, Variable, clone
...@@ -14,7 +12,33 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -14,7 +12,33 @@ class ScalarLoop(ScalarInnerGraphOp):
"""Scalar Op that encapsulates a scalar loop operation. """Scalar Op that encapsulates a scalar loop operation.
This Op can be used for the gradient of other Scalar Ops. This Op can be used for the gradient of other Scalar Ops.
It is much more restricted that `Scan` in that the entire inner graph must be composed of Scalar operations. It is much more restricted than `Scan` in that the entire inner graph
must be composed of Scalar operations, and all inputs and outputs must be ScalarVariables.
The pseudocode of the computation performed by this Op looks like the following:
```python
def scalar_for_loop(fn, n_steps, init, update, constant):
for i in range(n_steps):
state = fn(*state, *constant)
return state
```
When an until condition is present it behaves like this:
```python
def scalar_while_loop(fn, n_steps, init, update, constant):
# If n_steps <= 0, we skip the loop altogether.
# This does not count as a "failure"
done = True
for i in range(n_steps):
*state, done = fn(*state, *constant)
if done:
break
return *state, done
```
""" """
...@@ -23,7 +47,6 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -23,7 +47,6 @@ class ScalarLoop(ScalarInnerGraphOp):
"update", "update",
"constant", "constant",
"until", "until",
"until_condition_failed",
) )
def __init__( def __init__(
...@@ -32,14 +55,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -32,14 +55,8 @@ class ScalarLoop(ScalarInnerGraphOp):
update: Sequence[Variable], update: Sequence[Variable],
constant: Optional[Sequence[Variable]] = None, constant: Optional[Sequence[Variable]] = None,
until: Optional[Variable] = None, until: Optional[Variable] = None,
until_condition_failed: Literal["ignore", "warn", "raise"] = "warn",
name="ScalarLoop", name="ScalarLoop",
): ):
if until_condition_failed not in ["ignore", "warn", "raise"]:
raise ValueError(
f"Invalid until_condition_failed: {until_condition_failed}"
)
if constant is None: if constant is None:
constant = [] constant = []
if not len(init) == len(update): if not len(init) == len(update):
...@@ -52,12 +69,13 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -52,12 +69,13 @@ class ScalarLoop(ScalarInnerGraphOp):
self.outputs = copy(outputs) self.outputs = copy(outputs)
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.is_while = bool(until)
self.inputs_type = tuple(input.type for input in inputs) self.inputs_type = tuple(input.type for input in inputs)
self.outputs_type = tuple(output.type for output in outputs) self.outputs_type = tuple(output.type for output in outputs)
if self.is_while:
self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
self.nin = len(inputs) + 1 # n_steps is not part of the inner graph self.nin = len(inputs) + 1 # n_steps is not part of the inner graph
self.nout = len(outputs) # until is not output self.nout = len(outputs) + (1 if self.is_while else 0)
self.is_while = bool(until)
self.until_condition_failed = until_condition_failed
self.name = name self.name = name
self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False)) self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
super().__init__() super().__init__()
...@@ -135,7 +153,6 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -135,7 +153,6 @@ class ScalarLoop(ScalarInnerGraphOp):
update=update, update=update,
constant=constant, constant=constant,
until=until, until=until,
until_condition_failed=self.until_condition_failed,
name=self.name, name=self.name,
) )
...@@ -191,7 +208,6 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -191,7 +208,6 @@ class ScalarLoop(ScalarInnerGraphOp):
update=cloned_update, update=cloned_update,
constant=cloned_constant, constant=cloned_constant,
until=cloned_until, until=cloned_until,
until_condition_failed=self.until_condition_failed,
name=self.name, name=self.name,
) )
node = op.make_node(n_steps, *inputs) node = op.make_node(n_steps, *inputs)
...@@ -209,17 +225,8 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -209,17 +225,8 @@ class ScalarLoop(ScalarInnerGraphOp):
*carry, until = inner_fn(*carry, *constant) *carry, until = inner_fn(*carry, *constant)
if until: if until:
break break
carry.append(until)
if not until: # no-break
if self.until_condition_failed == "raise":
raise RuntimeError(
f"Until condition in ScalarLoop {self.name} not reached!"
)
elif self.until_condition_failed == "warn":
warnings.warn(
f"Until condition in ScalarLoop {self.name} not reached!",
RuntimeWarning,
)
else: else:
if n_steps < 0: if n_steps < 0:
raise ValueError("ScalarLoop does not have a termination condition.") raise ValueError("ScalarLoop does not have a termination condition.")
...@@ -324,27 +331,12 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -324,27 +331,12 @@ class ScalarLoop(ScalarInnerGraphOp):
if self.is_while: if self.is_while:
_c_code += "\nif(until){break;}\n" _c_code += "\nif(until){break;}\n"
# End of the loop
_c_code += "}\n" _c_code += "}\n"
# End of the loop # Output until flag
if self.is_while: if self.is_while:
if self.until_condition_failed == "raise": _c_code += f"%(o{len(fgraph.outputs)-1})s = until;\n"
_c_code += dedent(
f"""
if (!until) {{
PyErr_SetString(PyExc_RuntimeError, "Until condition in ScalarLoop {self.name} not reached!");
%(fail)s
}}
"""
)
elif self.until_condition_failed == "warn":
_c_code += dedent(
f"""
if (!until) {{
PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop {self.name} not reached!", 1);
}}
"""
)
_c_code += "}\n" _c_code += "}\n"
...@@ -376,13 +368,4 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -376,13 +368,4 @@ class ScalarLoop(ScalarInnerGraphOp):
return res return res
def c_code_cache_version_outer(self): def c_code_cache_version_outer(self):
return (1,) return (2,)
def __eq__(self, other):
return (
super().__eq__(other)
and self.until_condition_failed == other.until_condition_failed
)
def __hash__(self):
return hash((super().__hash__(), self.until_condition_failed))
...@@ -703,7 +703,6 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=Scal ...@@ -703,7 +703,6 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=Scal
constant=constant_, constant=constant_,
update=update_, update=update_,
until=until_, until=until_,
until_condition_failed="warn",
name=name, name=name,
) )
return op(n_steps, *init, *constant) return op(n_steps, *init, *constant)
...@@ -747,9 +746,10 @@ def gammainc_grad(k, x): ...@@ -747,9 +746,10 @@ def gammainc_grad(k, x):
init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n]
constant = [log_x] constant = [log_x]
sum_a, *_ = _make_scalar_loop( sum_a, *_, sum_a_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" max_iters, init, constant, inner_loop_a, name="gammainc_grad_a"
) )
sum_a = switch(sum_a_converges, sum_a, np.nan)
# Second loop # Second loop
n = np.array(0, dtype="int32") n = np.array(0, dtype="int32")
...@@ -772,9 +772,10 @@ def gammainc_grad(k, x): ...@@ -772,9 +772,10 @@ def gammainc_grad(k, x):
init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n] init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n]
constant = [log_x] constant = [log_x]
sum_b, *_ = _make_scalar_loop( sum_b, *_, sum_b_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammainc_grad_b" max_iters, init, constant, inner_loop_b, name="gammainc_grad_b"
) )
sum_b = switch(sum_b_converges, sum_b, np.nan)
grad_approx = exp(-x) * (log_x * sum_a - sum_b) grad_approx = exp(-x) * (log_x * sum_a - sum_b)
return grad_approx return grad_approx
...@@ -877,9 +878,10 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -877,9 +878,10 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
init = [sum_b0, log_s, s_sign, log_delta, n] init = [sum_b0, log_s, s_sign, log_delta, n]
constant = [k, log_x] constant = [k, log_x]
sum_b, *_ = _make_scalar_loop( sum_b, *_, sum_b_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b"
) )
sum_b = switch(sum_b_converges, sum_b, np.nan)
grad_approx_b = ( grad_approx_b = (
gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k
) )
...@@ -1547,10 +1549,10 @@ def betainc_grad(p, q, x, wrtp: bool): ...@@ -1547,10 +1549,10 @@ def betainc_grad(p, q, x, wrtp: bool):
init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n] init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n]
constant = [f, p, q, K, dK] constant = [f, p, q, K, dK]
grad, *_ = _make_scalar_loop( grad, *_, grad_converges = _make_scalar_loop(
max_iters, init, constant, inner_loop, name="betainc_grad" max_iters, init, constant, inner_loop, name="betainc_grad"
) )
return grad return switch(grad_converges, grad, np.nan)
# Input validation # Input validation
nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0) nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0)
...@@ -1752,10 +1754,10 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype): ...@@ -1752,10 +1754,10 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k] init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k]
constant = [a, b, c, log_z, sign_z] constant = [a, b, c, log_z, sign_z]
loop_outs = _make_scalar_loop( *loop_outs, converges = _make_scalar_loop(
max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop
) )
return loop_outs[: len(wrt)] return *loop_outs[: len(wrt)], converges
def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]): def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
...@@ -1792,7 +1794,7 @@ def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]): ...@@ -1792,7 +1794,7 @@ def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]):
# We have to pass the converges flag to interrupt the loop, as the switch is not lazy # We have to pass the converges flag to interrupt the loop, as the switch is not lazy
z_is_zero = eq(z, 0) z_is_zero = eq(z, 0)
converges = check_2f1_converges(a, b, c, z) converges = check_2f1_converges(a, b, c, z)
grads = _grad_2f1_loop( *grads, grad_converges = _grad_2f1_loop(
a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype
) )
......
...@@ -1219,6 +1219,30 @@ compile.optdb.register( # type: ignore ...@@ -1219,6 +1219,30 @@ compile.optdb.register( # type: ignore
) )
def _rebuild_partial_2f1grad_loop(node, wrt):
    """Rebuild an Elemwise 2f1-grad loop that computes only the ``wrt`` gradients.

    Parameters
    ----------
    node
        The Apply node of the original ``Elemwise(Grad2F1Loop)`` whose inputs
        are ``(n_steps, *9 init carry vars, a, b, c, log_z, sign_z)``.
    wrt
        Indices (subset of ``[0, 1, 2]``) of the gradients the new loop
        should still compute.

    Returns
    -------
    The outputs of a fresh Elemwise loop node that carries only the state
    needed for the requested gradients (plus the convergence flag).
    """
    a, b, c, log_z, sign_z = node.inputs[-5:]
    z = exp(log_z) * sign_z

    # Fresh scalar placeholders of the matching types; used only to trace a
    # new inner ScalarLoop graph restricted to the requested `wrt` outputs.
    a_, b_, c_, z_ = (inp.type.to_scalar_type()() for inp in (a, b, c, z))
    rebuilt_scalar_op = _grad_2f1_loop(
        a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
    )[0].owner.op

    # Wrap the rebuilt scalar loop back into an Elemwise outer loop.
    elemwise_loop = Elemwise(scalar_op=rebuilt_scalar_op)

    n_steps, *rest = node.inputs
    init_grad_vars = rest[:9]
    other_inputs = rest[9:]

    n_wrt = len(wrt)
    # The 9 carried init vars are laid out as [*grads, *log_gs, *log_gs_signs]
    # (3 of each); keep the leading n_wrt entries of each group.
    subset_init = (
        init_grad_vars[:n_wrt]
        + init_grad_vars[3 : 3 + n_wrt]
        + init_grad_vars[6 : 6 + n_wrt]
    )
    return elemwise_loop(n_steps, *subset_init, *other_inputs)
@register_specialize @register_specialize
@node_rewriter([Elemwise]) @node_rewriter([Elemwise])
def local_useless_2f1grad_loop(fgraph, node): def local_useless_2f1grad_loop(fgraph, node):
...@@ -1240,34 +1264,16 @@ def local_useless_2f1grad_loop(fgraph, node): ...@@ -1240,34 +1264,16 @@ def local_useless_2f1grad_loop(fgraph, node):
if sum(grad_var_is_used) == 3: if sum(grad_var_is_used) == 3:
return None return None
# Check that None of the remaining vars is used anywhere *other_vars, converges = node.outputs[3:]
if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]):
return None
a, b, c, log_z, sign_z = node.inputs[-5:] # Check that None of the remaining vars (except the converge flag) is used anywhere
z = exp(log_z) * sign_z if any(bool(fgraph.clients.get(v)) for v in other_vars):
return None
# Reconstruct scalar loop with relevant outputs
a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z))
wrt = [i for i, used in enumerate(grad_var_is_used) if used] wrt = [i for i, used in enumerate(grad_var_is_used) if used]
new_loop_op = _grad_2f1_loop( *new_outs, new_converges = _rebuild_partial_2f1grad_loop(node, wrt=wrt)
a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype
)[0].owner.op
# Reconstruct elemwise loop replacements = {converges: new_converges}
new_elemwise_op = Elemwise(scalar_op=new_loop_op)
n_steps = node.inputs[0]
init_grad_vars = node.inputs[1:10]
other_inputs = node.inputs[10:]
init_grads = init_grad_vars[: len(wrt)]
init_gs = init_grad_vars[3 : 3 + len(wrt)]
init_gs_signs = init_grad_vars[6 : 6 + len(wrt)]
subset_init_grad_vars = init_grads + init_gs + init_gs_signs
new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs)
replacements = {}
i = 0 i = 0
for grad_var, is_used in zip(grad_vars, grad_var_is_used): for grad_var, is_used in zip(grad_vars, grad_var_is_used):
if not is_used: if not is_used:
...@@ -1275,3 +1281,48 @@ def local_useless_2f1grad_loop(fgraph, node): ...@@ -1275,3 +1281,48 @@ def local_useless_2f1grad_loop(fgraph, node):
replacements[grad_var] = new_outs[i] replacements[grad_var] = new_outs[i]
i += 1 i += 1
return replacements return replacements
@node_rewriter([Elemwise])
def split_2f1grad_loop(fgraph, node):
    """Split a fused 2f1grad loop into two smaller loops for python mode.

    The fused 2f1grad loop has too many operands for the NumPy ``frompyfunc``
    code used by Elemwise nodes in python mode. This rewrite splits it across
    two loop operations (wrt=[0, 1] and wrt=[2]). It is not needed if
    `local_useless_2f1grad_loop` was already applied.

    Returns a replacement dict mapping old outputs to new ones, or ``None``
    when the rewrite does not apply.
    """
    loop_op = node.op.scalar_op
    # Only applies to the dedicated hyp2f1 gradient loop Op.
    if not isinstance(loop_op, Grad2F1Loop):
        return None

    # Trailing 4 outputs are non-gradient state; the rest are grad-related.
    grad_related_vars = node.outputs[:-4]
    # local_useless_2f1grad_loop was used, we should be safe
    if len(grad_related_vars) // 3 != 3:
        return None

    grad_vars = grad_related_vars[:3]
    *other_vars, converges = node.outputs[3:]

    # Check that None of the remaining vars is used anywhere
    if any(bool(fgraph.clients.get(v)) for v in other_vars):
        return None

    # Rebuild the loop twice: one loop for grads 0 and 1, one for grad 2.
    new_grad0, new_grad1, *_, new_converges01 = _rebuild_partial_2f1grad_loop(
        node, wrt=[0, 1]
    )
    new_grad2, *_, new_converges2 = _rebuild_partial_2f1grad_loop(node, wrt=[2])

    # The combined result converges only if both partial loops converge.
    replacements = {
        converges: new_converges01 & new_converges2,
        grad_vars[0]: new_grad0,
        grad_vars[1]: new_grad1,
        grad_vars[2]: new_grad2,
    }
    return replacements
# Register the split rewrite only for python-mode ("py_only") compilation,
# where the Elemwise frompyfunc operand limit applies.
compile.optdb["py_only"].register(  # type: ignore
    "split_2f1grad_loop",
    split_2f1grad_loop,
    "fast_compile",
)
...@@ -103,41 +103,11 @@ def test_until(mode): ...@@ -103,41 +103,11 @@ def test_until(mode):
x = x0 + 1 x = x0 + 1
until = x >= 10 until = x >= 10
op = ScalarLoop(init=[x0], update=[x], until=until, until_condition_failed="ignore") op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode) fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10) np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10) np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
np.testing.assert_allclose(fn(n_steps=5, x0=1), 6) np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
op = ScalarLoop(
init=[x0],
update=[x],
until=until,
until_condition_failed="warn",
name="TestLoop",
)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10)
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10)
with pytest.warns(
RuntimeWarning, match="Until condition in ScalarLoop TestLoop not reached!"
):
np.testing.assert_allclose(fn(n_steps=5, x0=1), 6)
op = ScalarLoop(
init=[x0],
update=[x],
until=until,
until_condition_failed="raise",
name="TestLoop",
)
fn = function([n_steps, x0], op(n_steps, x0), mode=mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), 10)
np.testing.assert_allclose(fn(n_steps=20, x0=1), 10)
with pytest.raises(
RuntimeError, match="Until condition in ScalarLoop TestLoop not reached!"
):
fn(n_steps=5, x0=1)
def test_update_missing_error(): def test_update_missing_error():
......
...@@ -1075,6 +1075,17 @@ class TestHyp2F1Grad: ...@@ -1075,6 +1075,17 @@ class TestHyp2F1Grad:
mode = get_default_mode().including("local_useless_2f1grad_loop") mode = get_default_mode().including("local_useless_2f1grad_loop")
f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode) f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode)
if len(wrt) == 3 and config.mode == "FAST_COMPILE" or not config.cxx:
# In this case we actually get two scalar_loops, because the merged one can't be executed in Python
[scalar_loop_op1, scalar_loop_op2] = [
node.op.scalar_op
for node in f_grad.maker.fgraph.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ScalarLoop)
]
assert scalar_loop_op1.nin == 10 + 3 * 2 # wrt=[0, 1]
assert scalar_loop_op2.nin == 10 + 3 * 1 # wrt=[2]
else:
[scalar_loop_op] = [ [scalar_loop_op] = [
node.op.scalar_op node.op.scalar_op
for node in f_grad.maker.fgraph.apply_nodes for node in f_grad.maker.fgraph.apply_nodes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论