提交 44066869 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix docstring and type checks in aesara.ifelse

上级 5a6d92c3
......@@ -13,6 +13,7 @@ is a global operation with a scalar condition.
import logging
from copy import deepcopy
from typing import List, Union
import numpy as np
......@@ -62,13 +63,16 @@ class IfElse(_NoPythonOp):
``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN,
rval_if_false1, rval_if_false2, .., rval_if_falseN)``
:note:
.. note:
Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and
will ignore its lazy characteristic, computing both the True and
False branch before picking one.
"""
__props__ = ("as_view", "gpu", "n_outs")
def __init__(self, n_outs, as_view=False, gpu=False, name=None):
if as_view:
# check destroyhandler and others to ensure that a view_map with
......@@ -83,21 +87,18 @@ class IfElse(_NoPythonOp):
self.name = name
def __eq__(self, other):
if not type(self) == type(other):
if type(self) != type(other):
return False
if not self.as_view == other.as_view:
if self.as_view != other.as_view:
return False
if not self.gpu == other.gpu:
if self.gpu != other.gpu:
return False
if not self.n_outs == other.n_outs:
if self.n_outs != other.n_outs:
return False
return True
def __hash__(self):
rval = (
hash(type(self)) ^ hash(self.as_view) ^ hash(self.gpu) ^ hash(self.n_outs)
)
return rval
return hash((type(self), self.as_view, self.gpu, self.n_outs))
def __str__(self):
args = []
......@@ -274,7 +275,7 @@ class IfElse(_NoPythonOp):
if self.as_view:
storage_map[out][0] = val
# Work around broken numpy deepcopy
elif type(val) in (np.ndarray, np.memmap):
elif isinstance(val, (np.ndarray, np.memmap)):
storage_map[out][0] = val.copy()
else:
storage_map[out][0] = deepcopy(val)
......@@ -294,7 +295,7 @@ class IfElse(_NoPythonOp):
# improves
# Work around broken numpy deepcopy
val = storage_map[f][0]
if type(val) in (np.ndarray, np.memmap):
if isinstance(val, (np.ndarray, np.memmap)):
storage_map[out][0] = val.copy()
else:
storage_map[out][0] = deepcopy(val)
......@@ -306,35 +307,40 @@ class IfElse(_NoPythonOp):
return thunk
def ifelse(condition, then_branch, else_branch, name=None):
def ifelse(
condition: Variable,
then_branch: Union[Variable, List[Variable]],
else_branch: Union[Variable, List[Variable]],
name: str = None,
) -> Union[Variable, List[Variable]]:
"""
This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or
inputs in the ``else_branch`` if ``condition`` evaluates to False.
:type condition: scalar like
:param condition:
Parameters
==========
condition
``condition`` should be a tensor scalar representing the condition.
If it evaluates to 0 it corresponds to False, anything else stands
for True.
:type then_branch: list of aesara expressions/ aesara expression
:param then_branch:
then_branch
A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to
true. The number of variables should match those in the
``else_branch``, and there should be a one to one correspondence
(type wise) with the tensors provided in the else branch
:type else_branch: list of aesara expressions/ aesara expressions
:param else_branch:
else_branch
A single aesara variable or a list of aesara variables that the
function should return as the output if ``condition`` evaluates to
false. The number of variables should match those in the then branch,
and there should be a one to one correspondace (type wise) with the
and there should be a one to one correspondence (type wise) with the
tensors provided in the then branch.
:return:
Returns
=======
A list of aesara variables or a single variable (depending on the
nature of the ``then_branch`` and ``else_branch``). More exactly if
``then_branch`` and ``else_branch`` is a tensor, then
......@@ -637,11 +643,11 @@ class CondMerge(GlobalOptimizer):
new_outs = new_ifelse(*new_ins, return_list=True)
new_outs = [clone_replace(x) for x in new_outs]
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
if not isinstance(merging_node.outputs, (list, tuple)):
old_outs += [merging_node.outputs]
else:
old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple):
if not isinstance(proposal.outputs, (list, tuple)):
old_outs += [proposal.outputs]
else:
old_outs += proposal.outputs
......@@ -737,11 +743,11 @@ def cond_merge_random_op(fgraph, main_node):
)
new_outs = new_ifelse(*new_ins, return_list=True)
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
if not isinstance(merging_node.outputs, (list, tuple)):
old_outs += [merging_node.outputs]
else:
old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple):
if not isinstance(proposal.outputs, (list, tuple)):
old_outs += [proposal.outputs]
else:
old_outs += proposal.outputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论