提交 b8e6b3ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor and fix static shape issues in IfElse

上级 01c4a55f
......@@ -11,61 +11,53 @@ it picks each entry of a matrix according to the condition) while `ifelse`
is a global operation with a scalar condition.
"""
import logging
from copy import deepcopy
from typing import List, Sequence, Union
from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
import numpy as np
import aesara.tensor as at
from aesara import as_symbolic
from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
__docformat__ = "restructedtext en"
__authors__ = (
"Razvan Pascanu "
"James Bergstra "
"Dumitru Erhan "
"David Warde-Farley"
"PyMC Developers"
"Aesara Developers"
)
__copyright__ = "(c) 2010, Universite de Montreal"
_logger = logging.getLogger("aesara.ifelse")
if TYPE_CHECKING:
from aesara.tensor import TensorLike
class IfElse(_NoPythonOp):
"""
Op that provides conditional graph evaluation if used with the CVM/VM
linkers. Note that there exist a helpful function `ifelse` that should
be used to instantiate the op!
r"""An `Op` that provides conditional graph evaluation.
According to a scalar condition `condition` the op evaluates and then
returns all the tensors provided on the `then` branch, otherwise it
evaluates and returns the tensors provided on the `else` branch. The op
According to a scalar condition, this `Op` evaluates and then
returns all the tensors provided on the "then"-branch, otherwise it
evaluates and returns the tensors provided on the "else"-branch. The `Op`
supports multiple tensors on each branch, with the condition that the same
number of tensors are on the `then` as on the `else` and there is a one
to one correspondence between them (shape and dtype wise).
number of tensors are on the "then"-branch as on the "else"-branch and
there is a one to one correspondence between their dtypes and numbers of
dimensions.
The `then` branch is defined as the first N tensors (after the
condition), while the `else` branch is defined as the last N tensors.
The "then"-branch is defined as the first ``N`` tensors (after the
condition), while the "else"-branch is defined as the last ``N`` tensors.
Example usage:
``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN,
rval_if_false1, rval_if_false2, .., rval_if_falseN)``
.. code-block::
rval = ifelse(condition,
rval_if_true_1, ..., rval_if_true_N,
rval_if_false_1, ..., rval_if_false_N)
.. note::
Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and
will ignore its lazy characteristic, computing both the True and
False branch before picking one.
`Linker`\s other than `CVM`, and some other `VM` subclasses, are
incompatible with this `Op`, and will ignore its lazy characteristic,
computing both the true and false branches before returning one.
"""
......@@ -158,86 +150,137 @@ class IfElse(_NoPythonOp):
return out_shapes
def make_node(self, c, *args):
if len(args) != 2 * self.n_outs:
def make_node(self, condition: "TensorLike", *true_false_branches: Any):
if len(true_false_branches) != 2 * self.n_outs:
raise ValueError(
f"Wrong number of arguments to make_node: expected "
f"{int(2 * self.n_outs)}, got {len(args)}"
f"Wrong number of arguments: expected "
f"{int(2 * self.n_outs)}, got {len(true_false_branches)}"
)
c = at.basic.as_tensor_variable(c)
nw_args = []
for x in args:
if isinstance(x, Variable):
nw_args.append(x)
else:
nw_args.append(at.as_tensor_variable(x))
args = nw_args
aes = args[: self.n_outs]
fs = args[self.n_outs :]
for t, f in zip(aes, fs):
# TODO: Attempt to convert types so that they match?
# new_f = t.type.filter_variable(f)
if not t.type.is_super(f.type):
raise TypeError(
"IfElse requires compatible types for true and false return values: "
f"true_branch={t.type}, false_branch={f.type}"
condition = at.basic.as_tensor_variable(condition)
if condition.type.ndim > 0:
raise TypeError("The condition argument must be a truthy scalar value")
inputs_true_branch = true_false_branches[: self.n_outs]
inputs_false_branch = true_false_branches[self.n_outs :]
output_vars = []
new_inputs_true_branch = []
new_inputs_false_branch = []
for input_t, input_f in zip(inputs_true_branch, inputs_false_branch):
if not isinstance(input_t, Variable):
input_t = as_symbolic(input_t)
if not isinstance(input_f, Variable):
input_f = as_symbolic(input_f)
if isinstance(input_t.type, HasDataType) and isinstance(
input_f.type, HasDataType
):
# TODO: Be smarter about dtype casting.
# up_dtype = aes.upcast(input_t.type.dtype, input_f.type.dtype)
if input_t.type.dtype != input_f.type.dtype:
raise TypeError(
"IfElse requires compatible dtypes for both branches: got "
f"true_branch={input_t.type.dtype}, false_branch={input_f.type.dtype}"
)
if isinstance(input_t.type, HasShape) and isinstance(
input_f.type, HasShape
):
if input_t.type.ndim != input_f.type.ndim:
raise TypeError(
"IfElse requires compatible ndim values for both branches: got "
f"true_branch={input_t.type.ndim}, false_branch={input_f.type.ndim}"
)
# We can only use static shape information that corresponds
# in both branches, because the outputs of this `Op` are
# allowed to have distinct shapes from either branch
new_shape = tuple(
s_t if s_t == s_f else None
for s_t, s_f in zip(input_t.type.shape, input_f.type.shape)
)
if c.ndim > 0:
raise TypeError(
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"for True"
)
return Apply(self, [c] + list(args), [t.type() for t in aes])
# TODO FIXME: The presence of this keyword is a strong
# assumption. Find something that's guaranteed by the/a
# confirmed interface.
output_type_t = input_t.type.clone(shape=new_shape)()
output_type_f = input_f.type.clone(shape=new_shape)()
else:
output_type_t = input_t.type()
output_type_f = input_f.type()
input_t = output_type_f.type.convert_variable(input_t)
input_f = output_type_t.type.convert_variable(input_f)
new_inputs_true_branch.append(input_t)
new_inputs_false_branch.append(input_f)
output_vars.append(output_type_t)
return Apply(
self,
[condition] + new_inputs_true_branch + new_inputs_false_branch,
output_vars,
)
def R_op(self, inputs, eval_points):
return self(inputs[0], *eval_points[1:], return_list=True)
def grad(self, ins, grads):
aes = ins[1:][: self.n_outs]
fs = ins[1:][self.n_outs :]
condition = ins[0]
inputs_true_branch = ins[1:][: self.n_outs]
inputs_false_branch = ins[1:][self.n_outs :]
if self.name is not None:
nw_name_t = self.name + "_grad_t"
nw_name_f = self.name + "_grad_f"
else:
nw_name_t = None
nw_name_f = None
if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t)
if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t)
if_false_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_f)
# The grads can have a different dtype then the inputs.
# As inputs true/false pair must have the same dtype,
# we must cast the zeros to the corresponding grad dtype
# and not the input dtype.
if_true = (
[ins[0]]
# The `grads` can have different dtypes than the `inputs`.
# Since input true/false entries must have the same dtypes, we need to
# cast the zeros to the corresponding `grads` dtypes and not the input
# dtypes.
inputs_true_grad = (
[condition]
+ grads
+ [at.basic.zeros_like(t, dtype=grads[i].dtype) for i, t in enumerate(aes)]
+ [
at.basic.zeros_like(t, dtype=grads[i].dtype)
for i, t in enumerate(inputs_true_branch)
]
)
if_false = (
[ins[0]]
+ [at.basic.zeros_like(f, dtype=grads[i].dtype) for i, f in enumerate(fs)]
inputs_false_grad = (
[condition]
+ [
at.basic.zeros_like(f, dtype=grads[i].dtype)
for i, f in enumerate(inputs_false_branch)
]
+ grads
)
condition = ins[0]
# condition does affect the elements of the output so it is connected.
# `condition` does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(config.floatX)
return (
[condition_grad]
+ if_true_op(*if_true, return_list=True)
+ if_false_op(*if_false, return_list=True)
+ if_true_op(*inputs_true_grad, return_list=True)
+ if_false_op(*inputs_false_grad, return_list=True)
)
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
cond = node.inputs[0]
aes = node.inputs[1:][: self.n_outs]
fs = node.inputs[1:][self.n_outs :]
input_true_branch = node.inputs[1:][: self.n_outs]
inputs_false_branch = node.inputs[1:][self.n_outs :]
outputs = node.outputs
def thunk():
......@@ -249,12 +292,12 @@ class IfElse(_NoPythonOp):
ls = [
idx + 1
for idx in range(self.n_outs)
if not compute_map[aes[idx]][0]
if not compute_map[input_true_branch[idx]][0]
]
if len(ls) > 0:
return ls
else:
for out, t in zip(outputs, aes):
for out, t in zip(outputs, input_true_branch):
compute_map[out][0] = 1
val = storage_map[t][0]
if self.as_view:
......@@ -269,12 +312,12 @@ class IfElse(_NoPythonOp):
ls = [
1 + idx + self.n_outs
for idx in range(self.n_outs)
if not compute_map[fs[idx]][0]
if not compute_map[inputs_false_branch[idx]][0]
]
if len(ls) > 0:
return ls
else:
for out, f in zip(outputs, fs):
for out, f in zip(outputs, inputs_false_branch):
compute_map[out][0] = 1
# can't view both outputs unless destroyhandler
# improves
......@@ -293,46 +336,42 @@ class IfElse(_NoPythonOp):
def ifelse(
condition: Variable,
then_branch: Union[Variable, List[Variable]],
else_branch: Union[Variable, List[Variable]],
name: str = None,
condition: "TensorLike",
then_branch: Union[Any, Sequence[Any]],
else_branch: Union[Any, Sequence[Any]],
name: Optional[str] = None,
) -> Union[Variable, Sequence[Variable]]:
"""
This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or
inputs in the ``else_branch`` if ``condition`` evaluates to False.
"""Construct a graph for an ``if`` statement.
Parameters
==========
----------
condition
``condition`` should be a tensor scalar representing the condition.
If it evaluates to 0 it corresponds to False, anything else stands
for True.
`condition` should be a tensor scalar representing the condition.
If it evaluates to ``0`` it corresponds to ``False``, anything else
stands for ``True``.
then_branch
A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to
A single variable or a list of variables that the
function should return as the output if `condition` evaluates to
true. The number of variables should match those in the
``else_branch``, and there should be a one to one correspondence
(type wise) with the tensors provided in the else branch
`else_branch`, and each variable's dtype and number of dimensions
should match those of its counterpart in `else_branch`.
else_branch
A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to
false. The number of variables should match those in the then branch,
and there should be a one to one correspondence (type wise) with the
tensors provided in the then branch.
A single variable or a list of variables that the function should
return as the output if `condition` evaluates to false. The number of
variables should match those in `then_branch`, and each variable's dtype
and number of dimensions should match those of its counterpart there.
Returns
=======
A sequence of aesara variables or a single variable (depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if
``then_branch`` and ``else_branch`` is a tensor, then
the return variable will be just a single variable, otherwise a
sequence. The value returns correspond either to the values in the
``then_branch`` or in the ``else_branch`` depending on the value of
``condition``.
-------
A sequence of variables or a single variable, depending on the
nature of `then_branch` and `else_branch`. More exactly, if
`then_branch` and `else_branch` are single variables, then
the return variable will also be a single variable; otherwise, it will
be a sequence. The values returned correspond to either the values in
`then_branch` or in `else_branch`, depending on the value of
`condition`.
"""
rval_type = None
......@@ -344,35 +383,17 @@ def ifelse(
if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch]
# Some of the elements might be converted into another type,
# we will store them in these new_... lists.
new_then_branch = []
new_else_branch = []
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
if not isinstance(then_branch_elem, Variable):
then_branch_elem = at.basic.as_tensor_variable(then_branch_elem)
if not isinstance(else_branch_elem, Variable):
else_branch_elem = at.basic.as_tensor_variable(else_branch_elem)
# Make sure the types are compatible
else_branch_elem = then_branch_elem.type.filter_variable(else_branch_elem)
then_branch_elem = else_branch_elem.type.filter_variable(then_branch_elem)
new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem)
if len(then_branch) != len(else_branch):
raise ValueError(
"The number of values on the `then` branch"
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
f"{len(then_branch)}, variables on `else` "
f"{len(else_branch)})"
"The number of values on the `then` branch "
"must match the `else` branch: got "
f"{len(then_branch)} for `then`, and "
f"{len(else_branch)} for `else`."
)
new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch)
ins = [condition] + list(then_branch) + list(else_branch)
rval = new_ifelse(*ins, return_list=True)
if rval_type is None:
......
......@@ -6,6 +6,7 @@ import pytest
import aesara
import aesara.ifelse
import aesara.sparse
import aesara.tensor.basic as at
from aesara import function
from aesara.compile.mode import Mode, get_mode
......@@ -14,15 +15,19 @@ from aesara.graph.op import Op
from aesara.ifelse import IfElse, ifelse
from aesara.link.c.type import generic
from aesara.tensor.math import eq
from aesara.tensor.type import col, iscalar, matrix, row, scalar, tensor3, vector
from aesara.tensor.type import (
col,
iscalar,
ivector,
matrix,
row,
scalar,
tensor3,
vector,
)
from tests import unittest_tools as utt
__docformat__ = "restructedtext en"
__authors__ = "Razvan Pascanu " "PyMC Development Team " "Aesara Developers "
__copyright__ = "(c) 2010, Universite de Montreal"
class TestIfelse(utt.OptimizationTestMixin):
mode = None
dtype = aesara.config.floatX
......@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
with pytest.raises(ValueError):
IfElse(0)(c, x, x)
def test_const_Op_argument(self):
def test_const_false_branch(self):
x = vector("x", dtype=self.dtype)
y = np.array([2.0, 3.0], dtype=self.dtype)
c = iscalar("c")
......@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
ifelse(cond, y, x)
def test_sparse_tensor_error(self):
pytest.importorskip("scipy", minversion="0.7.0")
import aesara.sparse
rng = np.random.default_rng(utt.fetch_seed())
data = rng.random((2, 3)).astype(self.dtype)
......@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
res.owner.op.as_view = True
assert str(res.owner).startswith("if{name,inplace}")
@pytest.mark.parametrize(
    "x_shape, y_shape, x_val, y_val, exp_shape",
    [
        ((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
        ((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
        ((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)),
        ((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)),
    ],
)
def test_static_branch_shapes(self, x_shape, y_shape, x_val, y_val, exp_shape):
    """Check the static output shape of `IfElse` and the values of both branches.

    Parameters
    ----------
    x_shape, y_shape
        Static shapes given to the "then" (``x``) and "else" (``y``) inputs.
    x_val, y_val
        Concrete values fed to the compiled function for ``x`` and ``y``.
    exp_shape
        The static shape expected on the `IfElse` output: a dimension is
        kept only when both branches agree on it, otherwise it is ``None``.
    """
    x = at.tensor(dtype=self.dtype, shape=x_shape, name="x")
    y = at.tensor(dtype=self.dtype, shape=y_shape, name="y")
    c = iscalar("c")

    z = IfElse(1)(c, x, y)
    assert z.type.shape == exp_shape

    f = function([c, x, y], z, mode=self.mode)

    x_val = x_val.astype(self.dtype)
    y_val = y_val.astype(self.dtype)

    # A zero condition selects the "else"-branch input ...
    val = f(0, x_val, y_val)
    assert np.array_equal(val, y_val)

    # ... and a non-zero condition selects the "then"-branch input. Both
    # directions are exercised because the branches may differ in shape
    # from each other and from the (merged) static output shape.
    val = f(1, x_val, y_val)
    assert np.array_equal(val, x_val)
def test_nonscalar_condition(self):
    """Building an `IfElse` node with a vector condition raises `TypeError`."""
    true_branch = vector("x")
    false_branch = vector("y")
    bad_condition = ivector("c")

    expected_message = "The condition argument"
    with pytest.raises(TypeError, match=expected_message):
        IfElse(1)(bad_condition, true_branch, false_branch)
class IfElseIfElseIf(Op):
def __init__(self, inplace=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论