提交 b8e6b3ea authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor and fix static shape issues in IfElse

上级 01c4a55f
...@@ -11,61 +11,53 @@ it picks each entry of a matrix according to the condition) while `ifelse` ...@@ -11,61 +11,53 @@ it picks each entry of a matrix according to the condition) while `ifelse`
is a global operation with a scalar condition. is a global operation with a scalar condition.
""" """
import logging
from copy import deepcopy from copy import deepcopy
from typing import List, Sequence, Union from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
import numpy as np import numpy as np
import aesara.tensor as at import aesara.tensor as at
from aesara import as_symbolic
from aesara.compile import optdb from aesara.compile import optdb
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.type import HasDataType, HasShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
__docformat__ = "restructedtext en" if TYPE_CHECKING:
__authors__ = ( from aesara.tensor import TensorLike
"Razvan Pascanu "
"James Bergstra "
"Dumitru Erhan "
"David Warde-Farley"
"PyMC Developers"
"Aesara Developers"
)
__copyright__ = "(c) 2010, Universite de Montreal"
_logger = logging.getLogger("aesara.ifelse")
class IfElse(_NoPythonOp): class IfElse(_NoPythonOp):
""" r"""An `Op` that provides conditional graph evaluation.
Op that provides conditional graph evaluation if used with the CVM/VM
linkers. Note that there exist a helpful function `ifelse` that should
be used to instantiate the op!
According to a scalar condition `condition` the op evaluates and then According to a scalar condition, this `Op` evaluates and then
returns all the tensors provided on the `then` branch, otherwise it returns all the tensors provided on the "then"-branch, otherwise it
evaluates and returns the tensors provided on the `else` branch. The op evaluates and returns the tensors provided on the "else"-branch. The `Op`
supports multiple tensors on each branch, with the condition that the same supports multiple tensors on each branch, with the condition that the same
number of tensors are on the `then` as on the `else` and there is a one number of tensors are on the "then"-branch as on the "else"-branch and
to one correspondence between them (shape and dtype wise). there is a one to one correspondence between their dtypes and numbers of
dimensions.
The `then` branch is defined as the first N tensors (after the The "then"-branch is defined as the first ``N`` tensors (after the
condition), while the `else` branch is defined as the last N tensors. condition), while the "else"-branch is defined as the last ``N`` tensors.
Example usage: Example usage:
``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN, .. code-block::
rval_if_false1, rval_if_false2, .., rval_if_falseN)``
rval = ifelse(condition,
rval_if_true_1, ..., rval_if_true_N,
rval_if_false_1, ..., rval_if_false_N)
.. note: .. note::
Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and `Linker`\s other than `CVM`, and some other `VM` subclasses, are
will ignore its lazy characteristic, computing both the True and incompatible with this `Op`, and will ignore its lazy characteristic,
False branch before picking one. computing both the true and false branches before returning one.
""" """
...@@ -158,86 +150,137 @@ class IfElse(_NoPythonOp): ...@@ -158,86 +150,137 @@ class IfElse(_NoPythonOp):
return out_shapes return out_shapes
def make_node(self, c, *args): def make_node(self, condition: "TensorLike", *true_false_branches: Any):
if len(args) != 2 * self.n_outs: if len(true_false_branches) != 2 * self.n_outs:
raise ValueError( raise ValueError(
f"Wrong number of arguments to make_node: expected " f"Wrong number of arguments: expected "
f"{int(2 * self.n_outs)}, got {len(args)}" f"{int(2 * self.n_outs)}, got {len(true_false_branches)}"
) )
c = at.basic.as_tensor_variable(c)
nw_args = []
for x in args:
if isinstance(x, Variable):
nw_args.append(x)
else:
nw_args.append(at.as_tensor_variable(x))
args = nw_args
aes = args[: self.n_outs]
fs = args[self.n_outs :]
for t, f in zip(aes, fs): condition = at.basic.as_tensor_variable(condition)
# TODO: Attempt to convert types so that they match?
# new_f = t.type.filter_variable(f) if condition.type.ndim > 0:
raise TypeError("The condition argument must be a truthy scalar value")
inputs_true_branch = true_false_branches[: self.n_outs]
inputs_false_branch = true_false_branches[self.n_outs :]
output_vars = []
new_inputs_true_branch = []
new_inputs_false_branch = []
for input_t, input_f in zip(inputs_true_branch, inputs_false_branch):
if not isinstance(input_t, Variable):
input_t = as_symbolic(input_t)
if not isinstance(input_f, Variable):
input_f = as_symbolic(input_f)
if not t.type.is_super(f.type): if isinstance(input_t.type, HasDataType) and isinstance(
input_f.type, HasDataType
):
# TODO: Be smarter about dtype casting.
# up_dtype = aes.upcast(input_t.type.dtype, input_f.type.dtype)
if input_t.type.dtype != input_f.type.dtype:
raise TypeError( raise TypeError(
"IfElse requires compatible types for true and false return values: " "IfElse requires compatible dtypes for both branches: got "
f"true_branch={t.type}, false_branch={f.type}" f"true_branch={input_t.type.dtype}, false_branch={input_f.type.dtype}"
) )
if c.ndim > 0:
if isinstance(input_t.type, HasShape) and isinstance(
input_f.type, HasShape
):
if input_t.type.ndim != input_f.type.ndim:
raise TypeError( raise TypeError(
"Condition given to the op has to be a scalar " "IfElse requires compatible ndim values for both branches: got "
"with 0 standing for False, anything else " f"true_branch={input_t.type.ndim}, false_branch={input_f.type.ndim}"
"for True" )
# We can only use static shape information that corresponds
# in both branches, because the outputs of this `Op` are
# allowed to have distinct shapes from either branch
new_shape = tuple(
s_t if s_t == s_f else None
for s_t, s_f in zip(input_t.type.shape, input_f.type.shape)
)
# TODO FIXME: The presence of this keyword is a strong
# assumption. Find something that's guaranteed by the/a
# confirmed interface.
output_type_t = input_t.type.clone(shape=new_shape)()
output_type_f = input_f.type.clone(shape=new_shape)()
else:
output_type_t = input_t.type()
output_type_f = input_f.type()
input_t = output_type_f.type.convert_variable(input_t)
input_f = output_type_t.type.convert_variable(input_f)
new_inputs_true_branch.append(input_t)
new_inputs_false_branch.append(input_f)
output_vars.append(output_type_t)
return Apply(
self,
[condition] + new_inputs_true_branch + new_inputs_false_branch,
output_vars,
) )
return Apply(self, [c] + list(args), [t.type() for t in aes])
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return self(inputs[0], *eval_points[1:], return_list=True) return self(inputs[0], *eval_points[1:], return_list=True)
def grad(self, ins, grads): def grad(self, ins, grads):
aes = ins[1:][: self.n_outs]
fs = ins[1:][self.n_outs :] condition = ins[0]
inputs_true_branch = ins[1:][: self.n_outs]
inputs_false_branch = ins[1:][self.n_outs :]
if self.name is not None: if self.name is not None:
nw_name_t = self.name + "_grad_t" nw_name_t = self.name + "_grad_t"
nw_name_f = self.name + "_grad_f" nw_name_f = self.name + "_grad_f"
else: else:
nw_name_t = None nw_name_t = None
nw_name_f = None nw_name_f = None
if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t)
if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t)
if_false_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_f) if_false_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_f)
# The grads can have a different dtype then the inputs. # The `grads` can have different dtypes than the `inputs`.
# As inputs true/false pair must have the same dtype, # Since input true/false entries must have the same dtypes, we need to
# we must cast the zeros to the corresponding grad dtype # cast the zeros to the corresponding `grads` dtypes and not the input
# and not the input dtype. # dtypes.
if_true = ( inputs_true_grad = (
[ins[0]] [condition]
+ grads + grads
+ [at.basic.zeros_like(t, dtype=grads[i].dtype) for i, t in enumerate(aes)] + [
at.basic.zeros_like(t, dtype=grads[i].dtype)
for i, t in enumerate(inputs_true_branch)
]
) )
if_false = ( inputs_false_grad = (
[ins[0]] [condition]
+ [at.basic.zeros_like(f, dtype=grads[i].dtype) for i, f in enumerate(fs)] + [
at.basic.zeros_like(f, dtype=grads[i].dtype)
for i, f in enumerate(inputs_false_branch)
]
+ grads + grads
) )
condition = ins[0] # `condition` does affect the elements of the output so it is connected.
# condition does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that # For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition # condition + epsilon always triggers the same branch as condition
condition_grad = condition.zeros_like().astype(config.floatX) condition_grad = condition.zeros_like().astype(config.floatX)
return ( return (
[condition_grad] [condition_grad]
+ if_true_op(*if_true, return_list=True) + if_true_op(*inputs_true_grad, return_list=True)
+ if_false_op(*if_false, return_list=True) + if_false_op(*inputs_false_grad, return_list=True)
) )
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
cond = node.inputs[0] cond = node.inputs[0]
aes = node.inputs[1:][: self.n_outs] input_true_branch = node.inputs[1:][: self.n_outs]
fs = node.inputs[1:][self.n_outs :] inputs_false_branch = node.inputs[1:][self.n_outs :]
outputs = node.outputs outputs = node.outputs
def thunk(): def thunk():
...@@ -249,12 +292,12 @@ class IfElse(_NoPythonOp): ...@@ -249,12 +292,12 @@ class IfElse(_NoPythonOp):
ls = [ ls = [
idx + 1 idx + 1
for idx in range(self.n_outs) for idx in range(self.n_outs)
if not compute_map[aes[idx]][0] if not compute_map[input_true_branch[idx]][0]
] ]
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
for out, t in zip(outputs, aes): for out, t in zip(outputs, input_true_branch):
compute_map[out][0] = 1 compute_map[out][0] = 1
val = storage_map[t][0] val = storage_map[t][0]
if self.as_view: if self.as_view:
...@@ -269,12 +312,12 @@ class IfElse(_NoPythonOp): ...@@ -269,12 +312,12 @@ class IfElse(_NoPythonOp):
ls = [ ls = [
1 + idx + self.n_outs 1 + idx + self.n_outs
for idx in range(self.n_outs) for idx in range(self.n_outs)
if not compute_map[fs[idx]][0] if not compute_map[inputs_false_branch[idx]][0]
] ]
if len(ls) > 0: if len(ls) > 0:
return ls return ls
else: else:
for out, f in zip(outputs, fs): for out, f in zip(outputs, inputs_false_branch):
compute_map[out][0] = 1 compute_map[out][0] = 1
# can't view both outputs unless destroyhandler # can't view both outputs unless destroyhandler
# improves # improves
...@@ -293,46 +336,42 @@ class IfElse(_NoPythonOp): ...@@ -293,46 +336,42 @@ class IfElse(_NoPythonOp):
def ifelse( def ifelse(
condition: Variable, condition: "TensorLike",
then_branch: Union[Variable, List[Variable]], then_branch: Union[Any, Sequence[Any]],
else_branch: Union[Variable, List[Variable]], else_branch: Union[Any, Sequence[Any]],
name: str = None, name: Optional[str] = None,
) -> Union[Variable, Sequence[Variable]]: ) -> Union[Variable, Sequence[Variable]]:
""" """Construct a graph for an ``if`` statement.
This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or
inputs in the ``else_branch`` if ``condition`` evaluates to False.
Parameters Parameters
========== ----------
condition condition
``condition`` should be a tensor scalar representing the condition. `condition` should be a tensor scalar representing the condition.
If it evaluates to 0 it corresponds to False, anything else stands If it evaluates to ``0`` it corresponds to ``False``, anything else
for True. stands for ``True``.
then_branch then_branch
A single aesara variable or a list of aesara variables that the A single variable or a list of variables that the
function should return as the output if ``condition`` evaluates to function should return as the output if `condition` evaluates to
true. The number of variables should match those in the true. The number of variables should match those in the
``else_branch``, and there should be a one to one correspondence `else_branch`, as well as the dtypes and numbers of dimensions of each
(type wise) with the tensors provided in the else branch tensor.
else_branch else_branch
A single aesara variable or a list of aesara variables that the A single variable or a list of variables that the function should
function should return as the output if ``condition`` evaluates to return as the output if `condition` evaluates to false. The number of
false. The number of variables should match those in the then branch, variables should match those in `then_branch`, as well as the dtypes
and there should be a one to one correspondence (type wise) with the and numbers of dimensions of each tensor.
tensors provided in the then branch.
Returns Returns
======= -------
A sequence of aesara variables or a single variable (depending on the A sequence of variables or a single variable, depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if nature of `then_branch` and `else_branch`. More exactly, if
``then_branch`` and ``else_branch`` is a tensor, then `then_branch` and `else_branch` are single variables, then
the return variable will be just a single variable, otherwise a the return variable will also be a single variable; otherwise, it will
sequence. The value returns correspond either to the values in the be a sequence. The values returned correspond to either the values in
``then_branch`` or in the ``else_branch`` depending on the value of the `then_branch` or in the `else_branch` depending on the value of
``condition``. `condition`.
""" """
rval_type = None rval_type = None
...@@ -344,35 +383,17 @@ def ifelse( ...@@ -344,35 +383,17 @@ def ifelse(
if not isinstance(else_branch, (list, tuple)): if not isinstance(else_branch, (list, tuple)):
else_branch = [else_branch] else_branch = [else_branch]
# Some of the elements might be converted into another type,
# we will store them in these new_... lists.
new_then_branch = []
new_else_branch = []
for then_branch_elem, else_branch_elem in zip(then_branch, else_branch):
if not isinstance(then_branch_elem, Variable):
then_branch_elem = at.basic.as_tensor_variable(then_branch_elem)
if not isinstance(else_branch_elem, Variable):
else_branch_elem = at.basic.as_tensor_variable(else_branch_elem)
# Make sure the types are compatible
else_branch_elem = then_branch_elem.type.filter_variable(else_branch_elem)
then_branch_elem = else_branch_elem.type.filter_variable(then_branch_elem)
new_then_branch.append(then_branch_elem)
new_else_branch.append(else_branch_elem)
if len(then_branch) != len(else_branch): if len(then_branch) != len(else_branch):
raise ValueError( raise ValueError(
"The number of values on the `then` branch" "The number of values on the `then` branch "
" should have the same number of variables as " "must match the `else` branch: got "
"the `else` branch : (variables on `then` " f"{len(then_branch)} for `then`, and "
f"{len(then_branch)}, variables on `else` " f"{len(else_branch)} for `else`."
f"{len(else_branch)})"
) )
new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name) new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name)
ins = [condition] + list(new_then_branch) + list(new_else_branch) ins = [condition] + list(then_branch) + list(else_branch)
rval = new_ifelse(*ins, return_list=True) rval = new_ifelse(*ins, return_list=True)
if rval_type is None: if rval_type is None:
......
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import aesara import aesara
import aesara.ifelse import aesara.ifelse
import aesara.sparse
import aesara.tensor.basic as at import aesara.tensor.basic as at
from aesara import function from aesara import function
from aesara.compile.mode import Mode, get_mode from aesara.compile.mode import Mode, get_mode
...@@ -14,15 +15,19 @@ from aesara.graph.op import Op ...@@ -14,15 +15,19 @@ from aesara.graph.op import Op
from aesara.ifelse import IfElse, ifelse from aesara.ifelse import IfElse, ifelse
from aesara.link.c.type import generic from aesara.link.c.type import generic
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.type import col, iscalar, matrix, row, scalar, tensor3, vector from aesara.tensor.type import (
col,
iscalar,
ivector,
matrix,
row,
scalar,
tensor3,
vector,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
__docformat__ = "restructedtext en"
__authors__ = "Razvan Pascanu " "PyMC Development Team " "Aesara Developers "
__copyright__ = "(c) 2010, Universite de Montreal"
class TestIfelse(utt.OptimizationTestMixin): class TestIfelse(utt.OptimizationTestMixin):
mode = None mode = None
dtype = aesara.config.floatX dtype = aesara.config.floatX
...@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
with pytest.raises(ValueError): with pytest.raises(ValueError):
IfElse(0)(c, x, x) IfElse(0)(c, x, x)
def test_const_Op_argument(self): def test_const_false_branch(self):
x = vector("x", dtype=self.dtype) x = vector("x", dtype=self.dtype)
y = np.array([2.0, 3.0], dtype=self.dtype) y = np.array([2.0, 3.0], dtype=self.dtype)
c = iscalar("c") c = iscalar("c")
...@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
ifelse(cond, y, x) ifelse(cond, y, x)
def test_sparse_tensor_error(self): def test_sparse_tensor_error(self):
pytest.importorskip("scipy", minversion="0.7.0")
import aesara.sparse
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
data = rng.random((2, 3)).astype(self.dtype) data = rng.random((2, 3)).astype(self.dtype)
...@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
res.owner.op.as_view = True res.owner.op.as_view = True
assert str(res.owner).startswith("if{name,inplace}") assert str(res.owner).startswith("if{name,inplace}")
@pytest.mark.parametrize(
"x_shape, y_shape, x_val, y_val, exp_shape",
[
((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)),
((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)),
((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)),
],
)
def test_static_branch_shapes(self, x_shape, y_shape, x_val, y_val, exp_shape):
x = at.tensor(dtype=self.dtype, shape=x_shape, name="x")
y = at.tensor(dtype=self.dtype, shape=y_shape, name="y")
c = iscalar("c")
z = IfElse(1)(c, x, y)
assert z.type.shape == exp_shape
f = function([c, x, y], z, mode=self.mode)
x_val = x_val.astype(self.dtype)
y_val = y_val.astype(self.dtype)
val = f(0, x_val, y_val)
assert np.array_equal(val, y_val)
def test_nonscalar_condition(self):
x = vector("x")
y = vector("y")
c = ivector("c")
with pytest.raises(TypeError, match="The condition argument"):
IfElse(1)(c, x, y)
class IfElseIfElseIf(Op): class IfElseIfElseIf(Op):
def __init__(self, inplace=False): def __init__(self, inplace=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论