提交 44066869 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix docstring and type checks in aesara.ifelse

上级 5a6d92c3
...@@ -13,6 +13,7 @@ is a global operation with a scalar condition. ...@@ -13,6 +13,7 @@ is a global operation with a scalar condition.
import logging import logging
from copy import deepcopy from copy import deepcopy
from typing import List, Union
import numpy as np import numpy as np
...@@ -62,13 +63,16 @@ class IfElse(_NoPythonOp): ...@@ -62,13 +63,16 @@ class IfElse(_NoPythonOp):
``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN, ``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN,
rval_if_false1, rval_if_false2, .., rval_if_falseN)`` rval_if_false1, rval_if_false2, .., rval_if_falseN)``
:note: .. note:
Other Linkers than CVM and VM are INCOMPATIBLE with this Op, and Other Linkers than CVM and VM are INCOMPATIBLE with this Op, and
will ignore its lazy characteristic, computing both the True and will ignore its lazy characteristic, computing both the True and
False branch before picking one. False branch before picking one.
""" """
__props__ = ("as_view", "gpu", "n_outs")
def __init__(self, n_outs, as_view=False, gpu=False, name=None): def __init__(self, n_outs, as_view=False, gpu=False, name=None):
if as_view: if as_view:
# check destroyhandler and others to ensure that a view_map with # check destroyhandler and others to ensure that a view_map with
...@@ -83,21 +87,18 @@ class IfElse(_NoPythonOp): ...@@ -83,21 +87,18 @@ class IfElse(_NoPythonOp):
self.name = name self.name = name
def __eq__(self, other): def __eq__(self, other):
if not type(self) == type(other): if type(self) != type(other):
return False return False
if not self.as_view == other.as_view: if self.as_view != other.as_view:
return False return False
if not self.gpu == other.gpu: if self.gpu != other.gpu:
return False return False
if not self.n_outs == other.n_outs: if self.n_outs != other.n_outs:
return False return False
return True return True
def __hash__(self): def __hash__(self):
rval = ( return hash((type(self), self.as_view, self.gpu, self.n_outs))
hash(type(self)) ^ hash(self.as_view) ^ hash(self.gpu) ^ hash(self.n_outs)
)
return rval
def __str__(self): def __str__(self):
args = [] args = []
...@@ -274,7 +275,7 @@ class IfElse(_NoPythonOp): ...@@ -274,7 +275,7 @@ class IfElse(_NoPythonOp):
if self.as_view: if self.as_view:
storage_map[out][0] = val storage_map[out][0] = val
# Work around broken numpy deepcopy # Work around broken numpy deepcopy
elif type(val) in (np.ndarray, np.memmap): elif isinstance(val, (np.ndarray, np.memmap)):
storage_map[out][0] = val.copy() storage_map[out][0] = val.copy()
else: else:
storage_map[out][0] = deepcopy(val) storage_map[out][0] = deepcopy(val)
...@@ -294,7 +295,7 @@ class IfElse(_NoPythonOp): ...@@ -294,7 +295,7 @@ class IfElse(_NoPythonOp):
# improves # improves
# Work around broken numpy deepcopy # Work around broken numpy deepcopy
val = storage_map[f][0] val = storage_map[f][0]
if type(val) in (np.ndarray, np.memmap): if isinstance(val, (np.ndarray, np.memmap)):
storage_map[out][0] = val.copy() storage_map[out][0] = val.copy()
else: else:
storage_map[out][0] = deepcopy(val) storage_map[out][0] = deepcopy(val)
...@@ -306,35 +307,40 @@ class IfElse(_NoPythonOp): ...@@ -306,35 +307,40 @@ class IfElse(_NoPythonOp):
return thunk return thunk
def ifelse(condition, then_branch, else_branch, name=None): def ifelse(
condition: Variable,
then_branch: Union[Variable, List[Variable]],
else_branch: Union[Variable, List[Variable]],
name: str = None,
) -> Union[Variable, List[Variable]]:
""" """
This function corresponds to an if statement, returning (and evaluating) This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or inputs in the ``then_branch`` if ``condition`` evaluates to True or
inputs in the ``else_branch`` if ``condition`` evaluates to False. inputs in the ``else_branch`` if ``condition`` evaluates to False.
:type condition: scalar like Parameters
:param condition: ==========
condition
``condition`` should be a tensor scalar representing the condition. ``condition`` should be a tensor scalar representing the condition.
If it evaluates to 0 it corresponds to False, anything else stands If it evaluates to 0 it corresponds to False, anything else stands
for True. for True.
:type then_branch: list of aesara expressions/ aesara expression then_branch
:param then_branch:
A single aesara variable or a list of aesara variables that the A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to function should return as the output if ``condition`` evaluates to
true. The number of variables should match those in the true. The number of variables should match those in the
``else_branch``, and there should be a one to one correspondence ``else_branch``, and there should be a one to one correspondence
(type wise) with the tensors provided in the else branch (type wise) with the tensors provided in the else branch
:type else_branch: list of aesara expressions/ aesara expressions else_branch
:param else_branch:
A single aesara variable or a list of aesara variables that the A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to function should return as the output if ``condition`` evaluates to
false. The number of variables should match those in the then branch, false. The number of variables should match those in the then branch,
and there should be a one to one correspondace (type wise) with the and there should be a one to one correspondence (type wise) with the
tensors provided in the then branch. tensors provided in the then branch.
:return: Returns
=======
A list of aesara variables or a single variable (depending on the A list of aesara variables or a single variable (depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if nature of the ``then_branch`` and ``else_branch``). More exactly if
``then_branch`` and ``else_branch`` is a tensor, then ``then_branch`` and ``else_branch`` is a tensor, then
...@@ -637,11 +643,11 @@ class CondMerge(GlobalOptimizer): ...@@ -637,11 +643,11 @@ class CondMerge(GlobalOptimizer):
new_outs = new_ifelse(*new_ins, return_list=True) new_outs = new_ifelse(*new_ins, return_list=True)
new_outs = [clone_replace(x) for x in new_outs] new_outs = [clone_replace(x) for x in new_outs]
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if not isinstance(merging_node.outputs, (list, tuple)):
old_outs += [merging_node.outputs] old_outs += [merging_node.outputs]
else: else:
old_outs += merging_node.outputs old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple): if not isinstance(proposal.outputs, (list, tuple)):
old_outs += [proposal.outputs] old_outs += [proposal.outputs]
else: else:
old_outs += proposal.outputs old_outs += proposal.outputs
...@@ -737,11 +743,11 @@ def cond_merge_random_op(fgraph, main_node): ...@@ -737,11 +743,11 @@ def cond_merge_random_op(fgraph, main_node):
) )
new_outs = new_ifelse(*new_ins, return_list=True) new_outs = new_ifelse(*new_ins, return_list=True)
old_outs = [] old_outs = []
if type(merging_node.outputs) not in (list, tuple): if not isinstance(merging_node.outputs, (list, tuple)):
old_outs += [merging_node.outputs] old_outs += [merging_node.outputs]
else: else:
old_outs += merging_node.outputs old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple): if not isinstance(proposal.outputs, (list, tuple)):
old_outs += [proposal.outputs] old_outs += [proposal.outputs]
else: else:
old_outs += proposal.outputs old_outs += proposal.outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论