Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b8e6b3ea
提交
b8e6b3ea
authored
8月 04, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
8月 04, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor and fix static shape issues in IfElse
上级
01c4a55f
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
199 行增加
和
145 行删除
+199
-145
ifelse.py
aesara/ifelse.py
+156
-135
test_ifelse.py
tests/test_ifelse.py
+43
-10
没有找到文件。
aesara/ifelse.py
浏览文件 @
b8e6b3ea
...
@@ -11,61 +11,53 @@ it picks each entry of a matrix according to the condition) while `ifelse`
...
@@ -11,61 +11,53 @@ it picks each entry of a matrix according to the condition) while `ifelse`
is a global operation with a scalar condition.
is a global operation with a scalar condition.
"""
"""
import
logging
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
List
,
Sequence
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Sequence
,
Union
import
numpy
as
np
import
numpy
as
np
import
aesara.tensor
as
at
import
aesara.tensor
as
at
from
aesara
import
as_symbolic
from
aesara.compile
import
optdb
from
aesara.compile
import
optdb
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.basic
import
Apply
,
Variable
,
clone_replace
,
is_in_ancestors
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.op
import
_NoPythonOp
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.opt
import
GlobalOptimizer
,
in2out
,
local_optimizer
from
aesara.graph.type
import
HasDataType
,
HasShape
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
from
aesara.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
__docformat__
=
"restructedtext en"
if
TYPE_CHECKING
:
__authors__
=
(
from
aesara.tensor
import
TensorLike
"Razvan Pascanu "
"James Bergstra "
"Dumitru Erhan "
"David Warde-Farley"
"PyMC Developers"
"Aesara Developers"
)
__copyright__
=
"(c) 2010, Universite de Montreal"
_logger
=
logging
.
getLogger
(
"aesara.ifelse"
)
class
IfElse
(
_NoPythonOp
):
class
IfElse
(
_NoPythonOp
):
"""
r"""An `Op` that provides conditional graph evaluation.
Op that provides conditional graph evaluation if used with the CVM/VM
linkers. Note that there exist a helpful function `ifelse` that should
be used to instantiate the op!
According to a scalar condition
`condition` the op
evaluates and then
According to a scalar condition
, this `Op`
evaluates and then
returns all the tensors provided on the
`then`
branch, otherwise it
returns all the tensors provided on the
"then"-
branch, otherwise it
evaluates and returns the tensors provided on the
`else` branch. The op
evaluates and returns the tensors provided on the
"else"-branch. The `Op`
supports multiple tensors on each branch, with the condition that the same
supports multiple tensors on each branch, with the condition that the same
number of tensors are on the `then` as on the `else` and there is a one
number of tensors are on the "then"-branch as on the "else"-branch and
to one correspondence between them (shape and dtype wise).
there is a one to one correspondence between their dtypes and numbers of
dimensions.
The
`then` branch is defined as the first N
tensors (after the
The
"then"-branch is defined as the first ``N``
tensors (after the
condition), while the
`else` branch is defined as the last N
tensors.
condition), while the
"else"-branch is defined as the last ``N``
tensors.
Example usage:
Example usage:
``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN,
.. code-block::
rval_if_false1, rval_if_false2, .., rval_if_falseN)``
rval = ifelse(condition,
rval_if_true_1, ..., rval_if_true_N,
rval_if_false_1, ..., rval_if_false_N)
.. note:
.. note:
Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and
`Linker`\s other than `CVM`, and some other `VM` subclasses, are
will ignore its lazy characteristic, computing both the True and
incompatible with this `Op`, and will ignore its lazy characteristic,
False branch before pick
ing one.
computing both the true and false branches before return
ing one.
"""
"""
...
@@ -158,86 +150,137 @@ class IfElse(_NoPythonOp):
...
@@ -158,86 +150,137 @@ class IfElse(_NoPythonOp):
return
out_shapes
return
out_shapes
def
make_node
(
self
,
c
,
*
args
):
def
make_node
(
self
,
c
ondition
:
"TensorLike"
,
*
true_false_branches
:
Any
):
if
len
(
arg
s
)
!=
2
*
self
.
n_outs
:
if
len
(
true_false_branche
s
)
!=
2
*
self
.
n_outs
:
raise
ValueError
(
raise
ValueError
(
f
"Wrong number of arguments
to make_node
: expected "
f
"Wrong number of arguments: expected "
f
"{int(2 * self.n_outs)}, got {len(
arg
s)}"
f
"{int(2 * self.n_outs)}, got {len(
true_false_branche
s)}"
)
)
c
=
at
.
basic
.
as_tensor_variable
(
c
)
nw_args
=
[]
condition
=
at
.
basic
.
as_tensor_variable
(
condition
)
for
x
in
args
:
if
isinstance
(
x
,
Variable
):
if
condition
.
type
.
ndim
>
0
:
nw_args
.
append
(
x
)
raise
TypeError
(
"The condition argument must be a truthy scalar value"
)
else
:
nw_args
.
append
(
at
.
as_tensor_variable
(
x
))
inputs_true_branch
=
true_false_branches
[:
self
.
n_outs
]
args
=
nw_args
inputs_false_branch
=
true_false_branches
[
self
.
n_outs
:]
aes
=
args
[:
self
.
n_outs
]
fs
=
args
[
self
.
n_outs
:]
output_vars
=
[]
new_inputs_true_branch
=
[]
for
t
,
f
in
zip
(
aes
,
fs
):
new_inputs_false_branch
=
[]
# TODO: Attempt to convert types so that they match?
for
input_t
,
input_f
in
zip
(
inputs_true_branch
,
inputs_false_branch
):
# new_f = t.type.filter_variable(f)
if
not
isinstance
(
input_t
,
Variable
):
if
not
t
.
type
.
is_super
(
f
.
type
):
input_t
=
as_symbolic
(
input_t
)
raise
TypeError
(
if
not
isinstance
(
input_f
,
Variable
):
"IfElse requires compatible types for true and false return values: "
input_f
=
as_symbolic
(
input_f
)
f
"true_branch={t.type}, false_branch={f.type}"
if
isinstance
(
input_t
.
type
,
HasDataType
)
and
isinstance
(
input_f
.
type
,
HasDataType
):
# TODO: Be smarter about dtype casting.
# up_dtype = aes.upcast(input_t.type.dtype, input_f.type.dtype)
if
input_t
.
type
.
dtype
!=
input_f
.
type
.
dtype
:
raise
TypeError
(
"IfElse requires compatible dtypes for both branches: got "
f
"true_branch={input_t.type.dtype}, false_branch={input_f.type.dtype}"
)
if
isinstance
(
input_t
.
type
,
HasShape
)
and
isinstance
(
input_f
.
type
,
HasShape
):
if
input_t
.
type
.
ndim
!=
input_f
.
type
.
ndim
:
raise
TypeError
(
"IfElse requires compatible ndim values for both branches: got "
f
"true_branch={input_t.type.ndim}, false_branch={input_f.type.ndim}"
)
# We can only use static shape information that corresponds
# in both branches, because the outputs of this `Op` are
# allowed to have distinct shapes from either branch
new_shape
=
tuple
(
s_t
if
s_t
==
s_f
else
None
for
s_t
,
s_f
in
zip
(
input_t
.
type
.
shape
,
input_f
.
type
.
shape
)
)
)
if
c
.
ndim
>
0
:
# TODO FIXME: The presence of this keyword is a strong
raise
TypeError
(
# assumption. Find something that's guaranteed by the/a
"Condition given to the op has to be a scalar "
# confirmed interface.
"with 0 standing for False, anything else "
output_type_t
=
input_t
.
type
.
clone
(
shape
=
new_shape
)()
"for True"
output_type_f
=
input_f
.
type
.
clone
(
shape
=
new_shape
)()
)
else
:
return
Apply
(
self
,
[
c
]
+
list
(
args
),
[
t
.
type
()
for
t
in
aes
])
output_type_t
=
input_t
.
type
()
output_type_f
=
input_f
.
type
()
input_t
=
output_type_f
.
type
.
convert_variable
(
input_t
)
input_f
=
output_type_t
.
type
.
convert_variable
(
input_f
)
new_inputs_true_branch
.
append
(
input_t
)
new_inputs_false_branch
.
append
(
input_f
)
output_vars
.
append
(
output_type_t
)
return
Apply
(
self
,
[
condition
]
+
new_inputs_true_branch
+
new_inputs_false_branch
,
output_vars
,
)
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
return
self
(
inputs
[
0
],
*
eval_points
[
1
:],
return_list
=
True
)
return
self
(
inputs
[
0
],
*
eval_points
[
1
:],
return_list
=
True
)
def
grad
(
self
,
ins
,
grads
):
def
grad
(
self
,
ins
,
grads
):
aes
=
ins
[
1
:][:
self
.
n_outs
]
fs
=
ins
[
1
:][
self
.
n_outs
:]
condition
=
ins
[
0
]
inputs_true_branch
=
ins
[
1
:][:
self
.
n_outs
]
inputs_false_branch
=
ins
[
1
:][
self
.
n_outs
:]
if
self
.
name
is
not
None
:
if
self
.
name
is
not
None
:
nw_name_t
=
self
.
name
+
"_grad_t"
nw_name_t
=
self
.
name
+
"_grad_t"
nw_name_f
=
self
.
name
+
"_grad_f"
nw_name_f
=
self
.
name
+
"_grad_f"
else
:
else
:
nw_name_t
=
None
nw_name_t
=
None
nw_name_f
=
None
nw_name_f
=
None
if_true_op
=
IfElse
(
n_outs
=
self
.
n_outs
,
as_view
=
self
.
as_view
,
name
=
nw_name_t
)
if_true_op
=
IfElse
(
n_outs
=
self
.
n_outs
,
as_view
=
self
.
as_view
,
name
=
nw_name_t
)
if_false_op
=
IfElse
(
n_outs
=
self
.
n_outs
,
as_view
=
self
.
as_view
,
name
=
nw_name_f
)
if_false_op
=
IfElse
(
n_outs
=
self
.
n_outs
,
as_view
=
self
.
as_view
,
name
=
nw_name_f
)
# The
grads can have a different dtype then the inputs
.
# The
`grads` can have different dtypes than the `inputs`
.
#
As inputs true/false pair must have the same dtype,
#
Since input true/false entries must have the same dtypes, we need to
#
we must cast the zeros to the corresponding grad dtype
#
cast the zeros to the corresponding `grads` dtypes and not the input
#
and not the input dtype
.
#
dtypes
.
i
f_true
=
(
i
nputs_true_grad
=
(
[
ins
[
0
]
]
[
condition
]
+
grads
+
grads
+
[
at
.
basic
.
zeros_like
(
t
,
dtype
=
grads
[
i
]
.
dtype
)
for
i
,
t
in
enumerate
(
aes
)]
+
[
at
.
basic
.
zeros_like
(
t
,
dtype
=
grads
[
i
]
.
dtype
)
for
i
,
t
in
enumerate
(
inputs_true_branch
)
]
)
)
if_false
=
(
inputs_false_grad
=
(
[
ins
[
0
]]
[
condition
]
+
[
at
.
basic
.
zeros_like
(
f
,
dtype
=
grads
[
i
]
.
dtype
)
for
i
,
f
in
enumerate
(
fs
)]
+
[
at
.
basic
.
zeros_like
(
f
,
dtype
=
grads
[
i
]
.
dtype
)
for
i
,
f
in
enumerate
(
inputs_false_branch
)
]
+
grads
+
grads
)
)
condition
=
ins
[
0
]
# `condition` does affect the elements of the output so it is connected.
# condition does affect the elements of the output so it is connected.
# For the sake of making the gradient convenient we assume that
# For the sake of making the gradient convenient we assume that
# condition + epsilon always triggers the same branch as condition
# condition + epsilon always triggers the same branch as condition
condition_grad
=
condition
.
zeros_like
()
.
astype
(
config
.
floatX
)
condition_grad
=
condition
.
zeros_like
()
.
astype
(
config
.
floatX
)
return
(
return
(
[
condition_grad
]
[
condition_grad
]
+
if_true_op
(
*
i
f_true
,
return_list
=
True
)
+
if_true_op
(
*
i
nputs_true_grad
,
return_list
=
True
)
+
if_false_op
(
*
i
f_false
,
return_list
=
True
)
+
if_false_op
(
*
i
nputs_false_grad
,
return_list
=
True
)
)
)
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
impl
=
None
):
def
make_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
impl
=
None
):
cond
=
node
.
inputs
[
0
]
cond
=
node
.
inputs
[
0
]
aes
=
node
.
inputs
[
1
:][:
self
.
n_outs
]
input_true_branch
=
node
.
inputs
[
1
:][:
self
.
n_outs
]
fs
=
node
.
inputs
[
1
:][
self
.
n_outs
:]
inputs_false_branch
=
node
.
inputs
[
1
:][
self
.
n_outs
:]
outputs
=
node
.
outputs
outputs
=
node
.
outputs
def
thunk
():
def
thunk
():
...
@@ -249,12 +292,12 @@ class IfElse(_NoPythonOp):
...
@@ -249,12 +292,12 @@ class IfElse(_NoPythonOp):
ls
=
[
ls
=
[
idx
+
1
idx
+
1
for
idx
in
range
(
self
.
n_outs
)
for
idx
in
range
(
self
.
n_outs
)
if
not
compute_map
[
aes
[
idx
]][
0
]
if
not
compute_map
[
input_true_branch
[
idx
]][
0
]
]
]
if
len
(
ls
)
>
0
:
if
len
(
ls
)
>
0
:
return
ls
return
ls
else
:
else
:
for
out
,
t
in
zip
(
outputs
,
aes
):
for
out
,
t
in
zip
(
outputs
,
input_true_branch
):
compute_map
[
out
][
0
]
=
1
compute_map
[
out
][
0
]
=
1
val
=
storage_map
[
t
][
0
]
val
=
storage_map
[
t
][
0
]
if
self
.
as_view
:
if
self
.
as_view
:
...
@@ -269,12 +312,12 @@ class IfElse(_NoPythonOp):
...
@@ -269,12 +312,12 @@ class IfElse(_NoPythonOp):
ls
=
[
ls
=
[
1
+
idx
+
self
.
n_outs
1
+
idx
+
self
.
n_outs
for
idx
in
range
(
self
.
n_outs
)
for
idx
in
range
(
self
.
n_outs
)
if
not
compute_map
[
fs
[
idx
]][
0
]
if
not
compute_map
[
inputs_false_branch
[
idx
]][
0
]
]
]
if
len
(
ls
)
>
0
:
if
len
(
ls
)
>
0
:
return
ls
return
ls
else
:
else
:
for
out
,
f
in
zip
(
outputs
,
fs
):
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
):
compute_map
[
out
][
0
]
=
1
compute_map
[
out
][
0
]
=
1
# can't view both outputs unless destroyhandler
# can't view both outputs unless destroyhandler
# improves
# improves
...
@@ -293,46 +336,42 @@ class IfElse(_NoPythonOp):
...
@@ -293,46 +336,42 @@ class IfElse(_NoPythonOp):
def
ifelse
(
def
ifelse
(
condition
:
Variable
,
condition
:
"TensorLike"
,
then_branch
:
Union
[
Variable
,
List
[
Variable
]],
then_branch
:
Union
[
Any
,
Sequence
[
Any
]],
else_branch
:
Union
[
Variable
,
List
[
Variable
]],
else_branch
:
Union
[
Any
,
Sequence
[
Any
]],
name
:
str
=
None
,
name
:
Optional
[
str
]
=
None
,
)
->
Union
[
Variable
,
Sequence
[
Variable
]]:
)
->
Union
[
Variable
,
Sequence
[
Variable
]]:
"""
"""Construct a graph for an ``if`` statement.
This function corresponds to an if statement, returning (and evaluating)
inputs in the ``then_branch`` if ``condition`` evaluates to True or
inputs in the ``else_branch`` if ``condition`` evaluates to False.
Parameters
Parameters
==========
----------
condition
condition
`
`condition`
` should be a tensor scalar representing the condition.
`
condition
` should be a tensor scalar representing the condition.
If it evaluates to
0 it corresponds to False, anything else stands
If it evaluates to
``0`` it corresponds to ``False``, anything else
for True
.
stands for ``True``
.
then_branch
then_branch
A single
aesara variable or a list of aesara
variables that the
A single
variable or a list of
variables that the
function should return as the output if `
`condition`
` evaluates to
function should return as the output if `
condition
` evaluates to
true. The number of variables should match those in the
true. The number of variables should match those in the
`
`else_branch``, and there should be a one to one correspondence
`
else_branch`, as well as the dtypes and numbers of dimensions of each
(type wise) with the tensors provided in the else branch
tensor.
else_branch
else_branch
A single aesara variable or a list of aesara variables that the
A single variable or a list of variables that the function should
function should return as the output if ``condition`` evaluates to
return as the output if `condition` evaluates to false. The number of
false. The number of variables should match those in the then branch,
variables should match those in `then_branch`, as well as the dtypes
and there should be a one to one correspondence (type wise) with the
and numbers of dimensions of each tensor.
tensors provided in the then branch.
Returns
Returns
=======
-------
A sequence of
aesara variables or a single variable (
depending on the
A sequence of
variables or a single variable,
depending on the
nature of
the ``then_branch`` and ``else_branch``). More exactly
if
nature of
`then_branch` and `else_branch`. More exactly,
if
`
`then_branch`` and ``else_branch`` is a tensor
, then
`
then_branch` and `else_branch` is are single variables
, then
the return variable will
be just a single variable, otherwise a
the return variable will
also be a single variable; otherwise, it will
sequence. The value returns correspond either to the values in the
be a sequence. The value returned correspond to either the values in
``then_branch`` or in the ``else_branch`
` depending on the value of
the `then_branch` or in the `else_branch
` depending on the value of
`
`condition`
`.
`
condition
`.
"""
"""
rval_type
=
None
rval_type
=
None
...
@@ -344,35 +383,17 @@ def ifelse(
...
@@ -344,35 +383,17 @@ def ifelse(
if
not
isinstance
(
else_branch
,
(
list
,
tuple
)):
if
not
isinstance
(
else_branch
,
(
list
,
tuple
)):
else_branch
=
[
else_branch
]
else_branch
=
[
else_branch
]
# Some of the elements might be converted into another type,
# we will store them in these new_... lists.
new_then_branch
=
[]
new_else_branch
=
[]
for
then_branch_elem
,
else_branch_elem
in
zip
(
then_branch
,
else_branch
):
if
not
isinstance
(
then_branch_elem
,
Variable
):
then_branch_elem
=
at
.
basic
.
as_tensor_variable
(
then_branch_elem
)
if
not
isinstance
(
else_branch_elem
,
Variable
):
else_branch_elem
=
at
.
basic
.
as_tensor_variable
(
else_branch_elem
)
# Make sure the types are compatible
else_branch_elem
=
then_branch_elem
.
type
.
filter_variable
(
else_branch_elem
)
then_branch_elem
=
else_branch_elem
.
type
.
filter_variable
(
then_branch_elem
)
new_then_branch
.
append
(
then_branch_elem
)
new_else_branch
.
append
(
else_branch_elem
)
if
len
(
then_branch
)
!=
len
(
else_branch
):
if
len
(
then_branch
)
!=
len
(
else_branch
):
raise
ValueError
(
raise
ValueError
(
"The number of values on the `then` branch"
"The number of values on the `then` branch "
" should have the same number of variables as "
"must match the `else` branch: got "
"the `else` branch : (variables on `then` "
f
"{len(then_branch)} for `then`, and "
f
"{len(then_branch)}, variables on `else` "
f
"{len(else_branch)} for `else`."
f
"{len(else_branch)})"
)
)
new_ifelse
=
IfElse
(
n_outs
=
len
(
then_branch
),
as_view
=
False
,
name
=
name
)
new_ifelse
=
IfElse
(
n_outs
=
len
(
then_branch
),
as_view
=
False
,
name
=
name
)
ins
=
[
condition
]
+
list
(
new_then_branch
)
+
list
(
new_
else_branch
)
ins
=
[
condition
]
+
list
(
then_branch
)
+
list
(
else_branch
)
rval
=
new_ifelse
(
*
ins
,
return_list
=
True
)
rval
=
new_ifelse
(
*
ins
,
return_list
=
True
)
if
rval_type
is
None
:
if
rval_type
is
None
:
...
...
tests/test_ifelse.py
浏览文件 @
b8e6b3ea
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
aesara
import
aesara
import
aesara.ifelse
import
aesara.ifelse
import
aesara.sparse
import
aesara.tensor.basic
as
at
import
aesara.tensor.basic
as
at
from
aesara
import
function
from
aesara
import
function
from
aesara.compile.mode
import
Mode
,
get_mode
from
aesara.compile.mode
import
Mode
,
get_mode
...
@@ -14,15 +15,19 @@ from aesara.graph.op import Op
...
@@ -14,15 +15,19 @@ from aesara.graph.op import Op
from
aesara.ifelse
import
IfElse
,
ifelse
from
aesara.ifelse
import
IfElse
,
ifelse
from
aesara.link.c.type
import
generic
from
aesara.link.c.type
import
generic
from
aesara.tensor.math
import
eq
from
aesara.tensor.math
import
eq
from
aesara.tensor.type
import
col
,
iscalar
,
matrix
,
row
,
scalar
,
tensor3
,
vector
from
aesara.tensor.type
import
(
col
,
iscalar
,
ivector
,
matrix
,
row
,
scalar
,
tensor3
,
vector
,
)
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
__docformat__
=
"restructedtext en"
__authors__
=
"Razvan Pascanu "
"PyMC Development Team "
"Aesara Developers "
__copyright__
=
"(c) 2010, Universite de Montreal"
class
TestIfelse
(
utt
.
OptimizationTestMixin
):
class
TestIfelse
(
utt
.
OptimizationTestMixin
):
mode
=
None
mode
=
None
dtype
=
aesara
.
config
.
floatX
dtype
=
aesara
.
config
.
floatX
...
@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
...
@@ -41,7 +46,7 @@ class TestIfelse(utt.OptimizationTestMixin):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
IfElse
(
0
)(
c
,
x
,
x
)
IfElse
(
0
)(
c
,
x
,
x
)
def
test_const_
Op_argument
(
self
):
def
test_const_
false_branch
(
self
):
x
=
vector
(
"x"
,
dtype
=
self
.
dtype
)
x
=
vector
(
"x"
,
dtype
=
self
.
dtype
)
y
=
np
.
array
([
2.0
,
3.0
],
dtype
=
self
.
dtype
)
y
=
np
.
array
([
2.0
,
3.0
],
dtype
=
self
.
dtype
)
c
=
iscalar
(
"c"
)
c
=
iscalar
(
"c"
)
...
@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
...
@@ -321,9 +326,6 @@ class TestIfelse(utt.OptimizationTestMixin):
ifelse
(
cond
,
y
,
x
)
ifelse
(
cond
,
y
,
x
)
def
test_sparse_tensor_error
(
self
):
def
test_sparse_tensor_error
(
self
):
pytest
.
importorskip
(
"scipy"
,
minversion
=
"0.7.0"
)
import
aesara.sparse
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
data
=
rng
.
random
((
2
,
3
))
.
astype
(
self
.
dtype
)
data
=
rng
.
random
((
2
,
3
))
.
astype
(
self
.
dtype
)
...
@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
...
@@ -527,6 +529,37 @@ class TestIfelse(utt.OptimizationTestMixin):
res
.
owner
.
op
.
as_view
=
True
res
.
owner
.
op
.
as_view
=
True
assert
str
(
res
.
owner
)
.
startswith
(
"if{name,inplace}"
)
assert
str
(
res
.
owner
)
.
startswith
(
"if{name,inplace}"
)
@pytest.mark.parametrize
(
"x_shape, y_shape, x_val, y_val, exp_shape"
,
[
((
2
,),
(
3
,),
np
.
r_
[
1.0
,
2.0
],
np
.
r_
[
1.0
,
2.0
,
3.0
],
(
None
,)),
((
None
,),
(
3
,),
np
.
r_
[
1.0
,
2.0
],
np
.
r_
[
1.0
,
2.0
,
3.0
],
(
None
,)),
((
3
,),
(
3
,),
np
.
r_
[
1.0
,
2.0
,
3.0
],
np
.
r_
[
1.0
,
2.0
,
3.0
],
(
3
,)),
((
1
,),
(
3
,),
np
.
r_
[
1.0
],
np
.
r_
[
1.0
,
2.0
,
3.0
],
(
None
,)),
],
)
def
test_static_branch_shapes
(
self
,
x_shape
,
y_shape
,
x_val
,
y_val
,
exp_shape
):
x
=
at
.
tensor
(
dtype
=
self
.
dtype
,
shape
=
x_shape
,
name
=
"x"
)
y
=
at
.
tensor
(
dtype
=
self
.
dtype
,
shape
=
y_shape
,
name
=
"y"
)
c
=
iscalar
(
"c"
)
z
=
IfElse
(
1
)(
c
,
x
,
y
)
assert
z
.
type
.
shape
==
exp_shape
f
=
function
([
c
,
x
,
y
],
z
,
mode
=
self
.
mode
)
x_val
=
x_val
.
astype
(
self
.
dtype
)
y_val
=
y_val
.
astype
(
self
.
dtype
)
val
=
f
(
0
,
x_val
,
y_val
)
assert
np
.
array_equal
(
val
,
y_val
)
def
test_nonscalar_condition
(
self
):
x
=
vector
(
"x"
)
y
=
vector
(
"y"
)
c
=
ivector
(
"c"
)
with
pytest
.
raises
(
TypeError
,
match
=
"The condition argument"
):
IfElse
(
1
)(
c
,
x
,
y
)
class
IfElseIfElseIf
(
Op
):
class
IfElseIfElseIf
(
Op
):
def
__init__
(
self
,
inplace
=
False
):
def
__init__
(
self
,
inplace
=
False
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论